whispercpp 1.2.0.2 → 1.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (135) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +5 -0
  3. data/LICENSE +1 -1
  4. data/README.md +165 -434
  5. data/Rakefile +46 -86
  6. data/ext/.gitignore +13 -0
  7. data/ext/cpu.mk +9 -0
  8. data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
  9. data/ext/extconf.rb +185 -7
  10. data/ext/ggml/include/ggml-alloc.h +76 -0
  11. data/ext/ggml/include/ggml-backend.h +352 -0
  12. data/ext/ggml/include/ggml-blas.h +25 -0
  13. data/ext/ggml/include/ggml-cann.h +123 -0
  14. data/ext/ggml/include/ggml-cpp.h +38 -0
  15. data/ext/ggml/include/ggml-cpu.h +135 -0
  16. data/ext/ggml/include/ggml-cuda.h +47 -0
  17. data/ext/ggml/include/ggml-kompute.h +50 -0
  18. data/ext/ggml/include/ggml-metal.h +66 -0
  19. data/ext/ggml/include/ggml-opencl.h +26 -0
  20. data/ext/ggml/include/ggml-opt.h +216 -0
  21. data/ext/ggml/include/ggml-rpc.h +28 -0
  22. data/ext/ggml/include/ggml-sycl.h +49 -0
  23. data/ext/ggml/include/ggml-vulkan.h +31 -0
  24. data/ext/ggml/include/ggml.h +2285 -0
  25. data/ext/ggml/src/ggml-alloc.c +1037 -0
  26. data/ext/ggml/src/ggml-amx/common.h +94 -0
  27. data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  28. data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
  29. data/ext/ggml/src/ggml-amx/mmq.h +17 -0
  30. data/ext/ggml/src/ggml-backend-impl.h +256 -0
  31. data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
  32. data/ext/ggml/src/ggml-backend.cpp +1999 -0
  33. data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
  34. data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
  35. data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
  36. data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
  37. data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  38. data/ext/ggml/src/ggml-cann/common.h +286 -0
  39. data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
  40. data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
  41. data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
  42. data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
  43. data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
  44. data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
  45. data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  46. data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
  47. data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
  48. data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
  49. data/ext/ggml/src/ggml-common.h +1853 -0
  50. data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  51. data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
  52. data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
  53. data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  54. data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  55. data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  56. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
  57. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  58. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  59. data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  60. data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
  61. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  62. data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  63. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  64. data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  65. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  66. data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
  67. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  68. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  69. data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  70. data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
  71. data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  72. data/ext/ggml/src/ggml-impl.h +556 -0
  73. data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
  74. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  75. data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
  76. data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
  77. data/ext/ggml/src/ggml-opt.cpp +854 -0
  78. data/ext/ggml/src/ggml-quants.c +5238 -0
  79. data/ext/ggml/src/ggml-quants.h +100 -0
  80. data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
  81. data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
  82. data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
  83. data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
  84. data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
  85. data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  86. data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  87. data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
  88. data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
  89. data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  90. data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
  91. data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
  92. data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
  93. data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
  94. data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
  95. data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
  96. data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  97. data/ext/ggml/src/ggml-threading.cpp +12 -0
  98. data/ext/ggml/src/ggml-threading.h +14 -0
  99. data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
  100. data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  101. data/ext/ggml/src/ggml.c +7694 -0
  102. data/ext/include/whisper.h +672 -0
  103. data/ext/metal-embed.mk +17 -0
  104. data/ext/metal.mk +6 -0
  105. data/ext/ruby_whisper.cpp +1608 -159
  106. data/ext/ruby_whisper.h +10 -0
  107. data/ext/scripts/get-flags.mk +38 -0
  108. data/ext/src/coreml/whisper-decoder-impl.h +146 -0
  109. data/ext/src/coreml/whisper-decoder-impl.m +201 -0
  110. data/ext/src/coreml/whisper-encoder-impl.h +142 -0
  111. data/ext/src/coreml/whisper-encoder-impl.m +197 -0
  112. data/ext/src/coreml/whisper-encoder.h +26 -0
  113. data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
  114. data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
  115. data/ext/src/whisper.cpp +7393 -0
  116. data/extsources.rb +6 -0
  117. data/lib/whisper/model/uri.rb +157 -0
  118. data/lib/whisper.rb +2 -0
  119. data/tests/helper.rb +7 -0
  120. data/tests/jfk_reader/.gitignore +5 -0
  121. data/tests/jfk_reader/extconf.rb +3 -0
  122. data/tests/jfk_reader/jfk_reader.c +68 -0
  123. data/tests/test_callback.rb +160 -0
  124. data/tests/test_error.rb +20 -0
  125. data/tests/test_model.rb +71 -0
  126. data/tests/test_package.rb +31 -0
  127. data/tests/test_params.rb +160 -0
  128. data/tests/test_segment.rb +83 -0
  129. data/tests/test_whisper.rb +211 -123
  130. data/whispercpp.gemspec +36 -0
  131. metadata +137 -11
  132. data/ext/ggml.c +0 -8616
  133. data/ext/ggml.h +0 -748
  134. data/ext/whisper.cpp +0 -4829
  135. data/ext/whisper.h +0 -402
@@ -0,0 +1,2188 @@
1
+ /*
2
+ * Copyright (c) 2023-2024 The ggml authors
3
+ *
4
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
5
+ * of this software and associated documentation files (the "Software"), to
6
+ * deal in the Software without restriction, including without limitation the
7
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
8
+ * sell copies of the Software, and to permit persons to whom the Software is
9
+ * furnished to do so, subject to the following conditions:
10
+ *
11
+ * The above copyright notice and this permission notice shall be included in
12
+ * all copies or substantial portions of the Software.
13
+ *
14
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
20
+ * IN THE SOFTWARE.
21
+ */
22
+
23
+ #include "ggml-cann.h"
24
+
25
+ #include <acl/acl.h>
26
+ #include <stdarg.h>
27
+
28
+ #include <cmath>
29
+ #include <cstdio>
30
+ #include <cstring>
31
+ #include <mutex>
32
+
33
+ #include "ggml-impl.h"
34
+ #include "ggml-backend-impl.h"
35
+ #include "ggml-cann/aclnn_ops.h"
36
+ #include "ggml-cann/common.h"
37
+
38
+ #define GGML_COMMON_DECL_C
39
+
40
+ #include "ggml-common.h"
41
+
42
+ #define GGML_CANN_NAME "CANN"
43
+
44
+ /**
45
+ * @brief Handles CANN errors by printing an error message and aborting.
46
+ *
47
+ * @param stmt The statement that caused the error.
48
+ * @param func The function in which the error occurred.
49
+ * @param file The file in which the error occurred.
50
+ * @param line The line number where the error occurred.
51
+ * @param msg The error message.
52
+ */
53
+ [[noreturn]] void ggml_cann_error(const char* stmt, const char* func,
54
+ const char* file, int line, const char* msg) {
55
+ int32_t id = -1;
56
+ aclrtGetDevice(&id);
57
+
58
+ GGML_LOG_ERROR("CANN error: %s\n", msg);
59
+ GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func,
60
+ file, line);
61
+ GGML_LOG_ERROR(" %s\n", stmt);
62
+ // abort with GGML_ABORT to get a stack trace
63
+ GGML_ABORT("CANN error");
64
+ }
65
+
66
+ /**
67
+ * @brief Sets the device to be used by CANN.
68
+ *
69
+ * @param device The device ID to set.
70
+ */
71
+ void ggml_cann_set_device(const int32_t device) {
72
+ // TODO: uncomment these lines after the empty-context issue has been fixed.
73
+ // int current_device;
74
+ // ACL_CHECK(aclrtGetDevice(&current_device));
75
+
76
+ // if (device == current_device) {
77
+ // return;
78
+ // }
79
+ ACL_CHECK(aclrtSetDevice(device));
80
+ }
81
+
82
+ /**
83
+ * @brief Retrieves the current device ID.
84
+ *
85
+ * @return The current device ID.
86
+ */
87
+ int32_t ggml_cann_get_device() {
88
+ int32_t id;
89
+ ACL_CHECK(aclrtGetDevice(&id));
90
+ return id;
91
+ }
92
+
93
+ /**
94
+ * @brief Initialize the CANN device information.
95
+ *
96
+ * This function initializes the CANN device information by obtaining the
97
+ * device count and setting the memory allocation granularity for each device.
98
+ *
99
+ * @return A structure containing the device information.
100
+ */
101
+ static ggml_cann_device_info ggml_cann_init() {
102
+ ggml_cann_device_info info = {};
103
+
104
+ aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count);
105
+
106
+ if (err != ACL_SUCCESS) {
107
+ GGML_LOG_ERROR("%s: failed to initialize CANN: %s\n",
108
+ __func__, aclGetRecentErrMsg());
109
+ return info;
110
+ }
111
+
112
+ GGML_ASSERT(info.device_count <= GGML_CANN_MAX_DEVICES);
113
+
114
+ for (int id = 0; id < info.device_count; ++id) {
115
+ aclrtPhysicalMemProp prop = {};
116
+ prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
117
+ prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
118
+ prop.memAttr = ACL_HBM_MEM_HUGE;
119
+ prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
120
+ prop.location.id = id;
121
+ prop.reserve = 0;
122
+ ACL_CHECK(aclrtMemGetAllocationGranularity(
123
+ &prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
124
+ &info.devices[id].vmm_granularity));
125
+
126
+ size_t free, total;
127
+ ggml_backend_cann_get_device_memory(id, &free, &total);
128
+ info.devices[id].total_vram = free;
129
+ }
130
+
131
+ // TODO: add more device info later.
132
+ return info;
133
+ }
134
+
135
+ /**
136
+ * @brief Retrieve the CANN device information.
137
+ *
138
+ * This function returns a reference to a structure containing the CANN device
139
+ * information. The device information is initialized once and reused on
140
+ * subsequent calls.
141
+ *
142
+ * @return A reference to the structure containing the device information.
143
+ */
144
+ const ggml_cann_device_info& ggml_cann_info() {
145
+ static ggml_cann_device_info info = ggml_cann_init();
146
+ return info;
147
+ }
148
+
149
+ //#define DEBUG_CANN_MALLOC
150
+ /**
151
+ * @brief A pool of CANN buffers (legacy).
152
+ *
153
+ * This class manages a pool of CANN buffers for a specific device.
154
+ */
155
+ struct ggml_cann_pool_leg : public ggml_cann_pool {
156
+ /**
157
+ * @brief The maximum number of buffers in the pool.
158
+ */
159
+ static const int MAX_BUFFERS = 256;
160
+
161
+ /**
162
+ * @brief The device ID associated with this buffer pool.
163
+ */
164
+ int device;
165
+
166
+ /**
167
+ * @brief Structure representing a CANN buffer.
168
+ */
169
+ struct ggml_cann_buffer {
170
+ void* ptr = nullptr; ///< Pointer to the buffer memory.
171
+ size_t size = 0; ///< Size of the buffer.
172
+ };
173
+
174
+ /**
175
+ * @brief Array of CANN buffers in the pool.
176
+ */
177
+ ggml_cann_buffer buffer_pool[MAX_BUFFERS] = {};
178
+
179
+ /**
180
+ * @brief Total size of all buffers in the pool.
181
+ */
182
+ size_t pool_size = 0;
183
+
184
+ /**
185
+ * @brief Constructor to initialize the buffer pool for a specific device.
186
+ *
187
+ * @param device The device ID to associate with this buffer pool.
188
+ */
189
+ explicit ggml_cann_pool_leg(int device) : device(device) {}
190
+
191
+ /**
192
+ * @brief Destructor to free all buffers in the pool.
193
+ */
194
+ ~ggml_cann_pool_leg() {
195
+ ggml_cann_set_device(device);
196
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
197
+ ggml_cann_buffer& b = buffer_pool[i];
198
+ if (b.ptr != nullptr) {
199
+ ACL_CHECK(aclrtFree(b.ptr));
200
+ pool_size -= b.size;
201
+ }
202
+ }
203
+ GGML_ASSERT(pool_size == 0);
204
+ }
205
+
206
+ /**
207
+ * @brief Allocate a buffer of the given size.
208
+ *
209
+ * @param size The size of the buffer to allocate.
210
+ * @param actual_size A pointer to a variable to receive the actual size of
211
+ * the allocated buffer.
212
+ * @return A pointer to the allocated buffer.
213
+ */
214
+ void* alloc(size_t size, size_t* actual_size) override {
215
+ const size_t alignment = 128;
216
+ size = GGML_PAD(size, alignment);
217
+ if (size == 0) {
218
+ size = alignment;
219
+ }
220
+ #ifdef DEBUG_CANN_MALLOC
221
+ int nnz = 0;
222
+ size_t max_size = 0;
223
+ #endif
224
+ size_t best_diff = 1ull << 36;
225
+ int ibest = -1;
226
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
227
+ ggml_cann_buffer& b = buffer_pool[i];
228
+ if (b.ptr != nullptr) {
229
+ #ifdef DEBUG_CANN_MALLOC
230
+ ++nnz;
231
+ if (b.size > max_size) max_size = b.size;
232
+ #endif
233
+ if (b.size >= size) {
234
+ size_t diff = b.size - size;
235
+ if (diff < best_diff) {
236
+ best_diff = diff;
237
+ ibest = i;
238
+ if (!best_diff) {
239
+ void* ptr = b.ptr;
240
+ *actual_size = b.size;
241
+ b.ptr = nullptr;
242
+ b.size = 0;
243
+ return ptr;
244
+ }
245
+ }
246
+ }
247
+ }
248
+ }
249
+ if (ibest >= 0) {
250
+ ggml_cann_buffer& b = buffer_pool[ibest];
251
+ void* ptr = b.ptr;
252
+ *actual_size = b.size;
253
+ b.ptr = nullptr;
254
+ b.size = 0;
255
+ return ptr;
256
+ }
257
+ void* ptr;
258
+ ggml_cann_set_device(device);
259
+ ACL_CHECK(
260
+ aclrtMalloc(&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST));
261
+ *actual_size = size;
262
+ pool_size += size;
263
+ #ifdef DEBUG_CANN_MALLOC
264
+ GGML_LOG_INFO(
265
+ "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, "
266
+ "requested %u MB\n",
267
+ __func__, device, nnz, (uint32_t)(max_size / 1024 / 1024),
268
+ (uint32_t)(pool_size / 1024 / 1024),
269
+ (uint32_t)(size / 1024 / 1024));
270
+ #endif
271
+ return ptr;
272
+ }
273
+
274
+ /**
275
+ * @brief Free a buffer and return it to the pool.
276
+ *
277
+ * @param ptr Pointer to the buffer to free.
278
+ * @param size Size of the buffer to free.
279
+ */
280
+ void free(void* ptr, size_t size) override {
281
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
282
+ ggml_cann_buffer& b = buffer_pool[i];
283
+ if (b.ptr == nullptr) {
284
+ b.ptr = ptr;
285
+ b.size = size;
286
+ return;
287
+ }
288
+ }
289
+ // Memory should always be buffered: it may still be needed by
290
+ // tasks in the stream.
291
+ // TODO, fix me.
292
+ GGML_ABORT("Cann buffer pool full, increase MAX_CANN_BUFFERS\n");
293
+ }
294
+ };
295
+
296
+ /**
297
+ * @brief A pool of CANN buffers with virtual memory.
298
+ *
299
+ * This class manages a pool of CANN buffers with virtual memory for a specific
300
+ * device.
301
+ */
302
+ struct ggml_cann_pool_vmm : public ggml_cann_pool {
303
+ /**
304
+ * @brief The maximum size of the virtual memory pool (set from the device's total_vram in the constructor).
305
+ */
306
+ size_t max_size;
307
+
308
+ /**
309
+ * @brief The device ID associated with this buffer pool.
310
+ */
311
+ int device;
312
+
313
+ /**
314
+ * @brief Pointer to the start of the virtual memory pool.
315
+ */
316
+ void* pool_addr = 0;
317
+
318
+ /**
319
+ * @brief Amount of virtual memory used in the pool.
320
+ */
321
+ size_t pool_used = 0;
322
+
323
+ /**
324
+ * @brief Total size of the virtual memory pool.
325
+ */
326
+ size_t pool_size = 0;
327
+
328
+ /**
329
+ * @brief Allocation granularity for the virtual memory pool.
330
+ */
331
+ size_t granularity;
332
+
333
+ /**
334
+ * @brief Handles for the physical memory allocated.
335
+ */
336
+ std::vector<aclrtDrvMemHandle> handles;
337
+
338
+ /**
339
+ * @brief Offsets for the mapped memory regions.
340
+ */
341
+ std::vector<void*> map_offsets;
342
+
343
+ /**
344
+ * @brief Constructor to initialize the buffer pool with virtual memory for
345
+ * a specific device.
346
+ *
347
+ * @param device The device ID to associate with this buffer pool.
348
+ */
349
+ explicit ggml_cann_pool_vmm(int device)
350
+ : device(device),
351
+ granularity(ggml_cann_info().devices[device].vmm_granularity) {
352
+ auto dev = ggml_cann_info().devices[device];
353
+ granularity = dev.vmm_granularity;
354
+ max_size = dev.total_vram;
355
+ }
356
+
357
+ /**
358
+ * @brief Destructor to free all buffers in the virtual memory pool.
359
+ */
360
+ ~ggml_cann_pool_vmm() {
361
+ if (pool_addr != 0) {
362
+ for (auto& offset : map_offsets) {
363
+ ACL_CHECK(aclrtUnmapMem(offset));
364
+ }
365
+ for (auto& handle : handles) {
366
+ ACL_CHECK(aclrtFreePhysical(handle));
367
+ }
368
+ ACL_CHECK(aclrtReleaseMemAddress(pool_addr));
369
+ }
370
+ }
371
+
372
+ /**
373
+ * @brief Allocate a buffer of the given size in the virtual memory pool.
374
+ *
375
+ * @param size The size of the buffer to allocate.
376
+ * @param actual_size A pointer to a variable to receive the actual size of
377
+ * the allocated buffer.
378
+ * @return A pointer to the allocated buffer.
379
+ */
380
+ void* alloc(size_t size, size_t* actual_size) override {
381
+ // round up the allocation size to the alignment to ensure that all
382
+ // allocations are aligned for all data types
383
+ const size_t alignment = 128;
384
+ size = GGML_PAD(size, alignment);
385
+ if (size == 0) {
386
+ size = alignment;
387
+ }
388
+
389
+ size_t avail = pool_size - pool_used;
390
+
391
+ if (size > avail) {
392
+ // round up to the next multiple of the granularity
393
+ size_t reserve_size = size - avail;
394
+ reserve_size = GGML_PAD(reserve_size, granularity);
395
+
396
+ GGML_ASSERT(pool_size + reserve_size <= max_size);
397
+
398
+ // allocate more physical memory
399
+ aclrtPhysicalMemProp prop = {};
400
+ prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
401
+ prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
402
+ prop.memAttr = ACL_HBM_MEM_HUGE;
403
+ prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
404
+ prop.location.id = device;
405
+ prop.reserve = 0;
406
+ aclrtDrvMemHandle handle;
407
+ ACL_CHECK(aclrtMallocPhysical(&handle, reserve_size, &prop, 0));
408
+
409
+ // reserve virtual address space (if not already reserved)
410
+ if (pool_addr == 0) {
411
+ ACL_CHECK(aclrtReserveMemAddress(
412
+ &pool_addr, max_size, 0, NULL, 1));
413
+ }
414
+
415
+ // map at the end of the pool
416
+ ACL_CHECK(aclrtMapMem((char*)pool_addr + pool_size, reserve_size, 0,
417
+ handle, 0));
418
+
419
+ handles.push_back(handle);
420
+ map_offsets.push_back((char*)pool_addr + pool_size);
421
+
422
+ // add to the pool
423
+ pool_size += reserve_size;
424
+
425
+ #ifdef DEBUG_CANN_MALLOC
426
+ GGML_LOG_INFO("cann pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
427
+ device, (unsigned long long) (pool_size/1024/1024),
428
+ (unsigned long long) (reserve_size/1024/1024));
429
+ #endif
430
+ }
431
+
432
+ GGML_ASSERT(pool_addr != 0);
433
+
434
+ void* ptr = (void*)((char*)pool_addr + pool_used);
435
+ *actual_size = size;
436
+ pool_used += size;
437
+
438
+ #ifdef DEBUG_CANN_MALLOC
439
+ GGML_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device,
440
+ (unsigned long long)size, (unsigned long long)ptr);
441
+ #endif
442
+ return ptr;
443
+ }
444
+
445
+ /**
446
+ * @brief Free a buffer and return it to the virtual memory pool.
447
+ *
448
+ * @param ptr Pointer to the buffer to free.
449
+ * @param size Size of the buffer to free.
450
+ */
451
+ void free(void* ptr, size_t size) override {
452
+ #ifdef DEBUG_CANN_MALLOC
453
+ GGML_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device,
454
+ (unsigned long long)size, (unsigned long long)ptr);
455
+ #endif
456
+
457
+ pool_used -= size;
458
+
459
+ // all deallocations must be in reverse order of the allocations
460
+ GGML_ASSERT(ptr == (void*)((char*)pool_addr + pool_used));
461
+ }
462
+ };
463
+
464
+ /**
465
+ * @brief Create a new CANN pool for a specific device.
466
+ *
467
+ * Factory method that creates a virtual-memory CANN pool (ggml_cann_pool_vmm) for the given device.
468
+ *
469
+ * @param device The device ID for which to create the pool.
470
+ * @return A unique pointer to the created CANN pool.
471
+ */
472
+ std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
473
+ int device) {
474
+ return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
475
+ }
476
+
477
+ // cann buffer
478
+ /**
479
+ * @brief Context for managing a CANN buffer associated with a specific device.
480
+ *
481
+ * This structure holds information about a CANN buffer, including the device
482
+ * ID and the device pointer.
483
+ */
484
+ struct ggml_backend_cann_buffer_context {
485
+ int32_t device; ///< The device ID associated with this buffer context.
486
+ void* dev_ptr =
487
+ nullptr; ///< Pointer to the device memory allocated for the buffer.
488
+
489
+ /**
490
+ * @brief Constructor to initialize the CANN buffer context.
491
+ *
492
+ * @param device The device ID associated with this buffer context.
493
+ * @param dev_ptr Pointer to the device memory allocated for the buffer.
494
+ */
495
+ ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
496
+ : device(device),
497
+ dev_ptr(dev_ptr) {}
498
+
499
+ /**
500
+ * @brief Destructor to free the device memory allocated for the buffer.
501
+ */
502
+ ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }
503
+ };
504
+
505
+ /**
506
+ * @brief Check if a buffer is a CANN buffer.
507
+ *
508
+ * This function checks if a given buffer is a CANN buffer by passing its
508
+ * buffer type (`buffer->buft`) to `ggml_backend_buft_is_cann`.
510
+ *
511
+ * @param buffer The buffer to check.
512
+ * @return true if the buffer is a CANN buffer, false otherwise.
513
+ */
514
+ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft);
515
+ static bool ggml_backend_buffer_is_cann(
516
+ ggml_backend_buffer_t buffer) {
517
+ return ggml_backend_buft_is_cann(buffer->buft);
518
+ }
519
+
520
+ /**
521
+ * @brief Free resources associated with a CANN buffer.
522
+ *
523
+ * This function frees the resources associated with a CANN buffer, including
524
+ * its context.
525
+ *
526
+ * @param buffer The CANN buffer to free.
527
+ */
528
+ static void ggml_backend_cann_buffer_free_buffer(
529
+ ggml_backend_buffer_t buffer) {
530
+ ggml_backend_cann_buffer_context* ctx =
531
+ (ggml_backend_cann_buffer_context*)buffer->context;
532
+ delete ctx;
533
+ }
534
+
535
+ /**
536
+ * @brief Retrieve the base pointer of a CANN buffer.
537
+ *
538
+ * This function returns the base pointer of a CANN buffer, which points to the
539
+ * device memory allocated for the buffer.
540
+ *
541
+ * @param buffer The CANN buffer whose base pointer is to be retrieved.
542
+ * @return A pointer to the base of the device memory allocated for the buffer.
543
+ */
544
+ static void* ggml_backend_cann_buffer_get_base(
545
+ ggml_backend_buffer_t buffer) {
546
+ ggml_backend_cann_buffer_context* ctx =
547
+ (ggml_backend_cann_buffer_context*)buffer->context;
548
+ return ctx->dev_ptr;
549
+ }
550
+
551
+ /**
552
+ * @brief Transform quantized Q4.0 tensor data into a format suitable for CANN
553
+ * processing.
554
+ *
555
+ * This function transforms quantized Q4.0 tensor data into a format suitable
556
+ * for CANN processing. It extracts quantization values and scales from the
557
+ * source data and prepares them in a format expected by CANN operations.
558
+ *
559
+ * @param tensor Pointer to the tensor information.
560
+ * @param src Pointer to the source data in Q4.0 format.
561
+ * @param dst Pointer to the destination buffer where transformed data will be
562
+ * stored.
563
+ */
564
+ static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
565
+ const void* src,
566
+ void* dst) {
567
+
568
+ int64_t n_elems = ggml_nelements(tensor);
569
+ int64_t groups = n_elems / QK4_0;
570
+ size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
571
+
572
+ uint8_t* quant_offset = (uint8_t*)dst;
573
+ uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
574
+
575
+ for (int i = 0; i < groups; i++) {
576
+ const block_q4_0* group =
577
+ (const block_q4_0*)((const char*)src + i * sizeof(block_q4_0));
578
+ *scale_offset = group->d;
579
+ scale_offset++;
580
+
581
+ // 0-15
582
+ for (int j = 0; j < QK4_0 / 2; j += 2) {
583
+ (*quant_offset) = (group->qs[j] & 0x0F);
584
+ (*quant_offset) |= ((group->qs[j + 1] << 4));
585
+ quant_offset++;
586
+ }
587
+
588
+ // 16-31
589
+ for (int j = 0; j < QK4_0 / 2; j += 2) {
590
+ (*quant_offset) = (group->qs[j] >> 4);
591
+ (*quant_offset) |= (group->qs[j + 1] & 0xF0);
592
+ quant_offset++;
593
+ }
594
+ }
595
+
596
+ // convert each unsigned 4-bit value to signed (value - 8) by flipping the top bit of every nibble
597
+ for (quant_offset = (uint8_t*)dst;
598
+ quant_offset < (uint8_t*)dst + quant_bytes; quant_offset++) {
599
+ (*quant_offset) ^= 0x88;
600
+ }
601
+ }
602
+
603
+ /**
604
+ * @brief Transform CANN processed data back into quantized Q4.0 format.
605
+ *
606
+ * This function transforms CANN processed data back into quantized Q4.0 format.
607
+ * It reverses the transformation performed by
608
+ * ggml_backend_cann_transform_q4_0(), converting the data back into its
609
+ * original quantized form.
610
+ *
611
+ * @param tensor Pointer to the tensor information.
612
+ * @param src Pointer to the source buffer containing transformed data.
613
+ * @param dst Pointer to the destination buffer where the Q4.0 formatted data
614
+ * will be stored.
615
+ */
616
+ static void ggml_backend_cann_transform_back_q4_0(
617
+ const ggml_tensor* tensor, void* src, void* dst) {
618
+
619
+ int64_t n_elems = ggml_nelements(tensor);
620
+ int64_t groups = n_elems / QK4_0;
621
+ size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
622
+
623
+ uint8_t* quant_offset = (uint8_t*)src;
624
+ uint16_t* scale_offset = (uint16_t*)((char*)src + quant_bytes);
625
+
626
+ for (; quant_offset < (uint8_t*)src + quant_bytes; quant_offset++) {
627
+ (*quant_offset) ^= 0x88;
628
+ }
629
+ quant_offset = (uint8_t*)src;
630
+
631
+ for (int i = 0; i < groups; i++) {
632
+ block_q4_0* group = (block_q4_0*)((char*)dst + i * sizeof(block_q4_0));
633
+ group->d = *scale_offset;
634
+ scale_offset++;
635
+
636
+ // 0-15
637
+ for (int j = 0; j < QK4_0 / 2; j += 2) {
638
+ group->qs[j] = ((*quant_offset) & 0x0F);
639
+ group->qs[j + 1] = ((*quant_offset) >> 4);
640
+ quant_offset++;
641
+ }
642
+
643
+ // 16-31
644
+ for (int j = 0; j < QK4_0 / 2; j += 2) {
645
+ group->qs[j] |= ((*quant_offset) << 4);
646
+ group->qs[j + 1] |= ((*quant_offset) & 0xF0);
647
+ quant_offset++;
648
+ }
649
+ }
650
+ }
651
+
652
+ /**
653
+ * @brief Transform quantized Q8.0 tensor data into a format suitable for CANN
654
+ * processing.
655
+ *
656
+ * This function transforms quantized Q8.0 tensor data into a format suitable
657
+ * for CANN processing. It extracts quantization values and scales from the
658
+ * source data and prepares them in a format expected by CANN operations.
659
+ *
660
+ * @param tensor Pointer to the tensor information.
661
+ * @param src Pointer to the source data in Q8.0 format.
662
+ * @param dst Pointer to the destination buffer where transformed data will be
663
+ * stored.
664
+ */
665
+ static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor,
666
+ const void* src,
667
+ void* dst) {
668
+ int64_t n_elems = ggml_nelements(tensor);
669
+ int64_t groups = n_elems / QK8_0;
670
+ size_t quant_bytes = n_elems * sizeof(uint8_t);
671
+
672
+ uint8_t* quant_offset = (uint8_t*)dst;
673
+ uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
674
+
675
+ for (int i = 0; i < groups; i++) {
676
+ const block_q8_0* group =
677
+ (const block_q8_0*)((const char*)src + i * sizeof(block_q8_0));
678
+ *scale_offset = group->d;
679
+ scale_offset++;
680
+ size_t group_quant_size = QK8_0 * sizeof(uint8_t);
681
+ memcpy(quant_offset, group->qs, group_quant_size);
682
+ quant_offset += group_quant_size;
683
+ }
684
+ }
685
+
686
+ /**
687
+ * @brief Transform CANN processed data back into quantized Q8.0 format.
688
+ *
689
+ * This function transforms CANN processed data back into quantized Q8.0 format.
690
+ * It reverses the transformation performed by
691
+ * ggml_backend_cann_transform_q8_0(), converting the data back into its
692
+ * original quantized form.
693
+ *
694
+ * @param tensor Pointer to the tensor information.
695
+ * @param src Pointer to the source buffer containing transformed data.
696
+ * @param dst Pointer to the destination buffer where the Q8.0 formatted data
697
+ * will be stored.
698
+ */
699
+ static void ggml_backend_cann_transform_back_q8_0(
700
+ const ggml_tensor* tensor, const void* src, void* dst) {
701
+ int64_t n_elems = ggml_nelements(tensor);
702
+ int64_t groups = n_elems / QK8_0;
703
+ size_t quant_bytes = n_elems * sizeof(uint8_t);
704
+
705
+ const uint8_t* quant_offset = (const uint8_t*)src;
706
+ const uint16_t* scale_offset =
707
+ (const uint16_t*)((const char*)src + quant_bytes);
708
+
709
+ for (int i = 0; i < groups; i++) {
710
+ block_q8_0* group = (block_q8_0*)((char*)dst + i * sizeof(block_q8_0));
711
+ group->d = *scale_offset;
712
+ scale_offset++;
713
+ size_t group_quant_size = QK8_0 * sizeof(uint8_t);
714
+ memcpy(group->qs, quant_offset, group_quant_size);
715
+ quant_offset += group_quant_size;
716
+ }
717
+ }
718
+
719
+ /**
720
+ * @brief Transform tensor data based on its type for CANN processing.
721
+ *
722
+ * This function transforms tensor data based on its quantization type for CANN
723
+ * processing. It dispatches the transformation based on the tensor's type to
724
+ * specialized functions handling Q4.0 and Q8.0 formats.
725
+ *
726
+ * @param tensor Pointer to the tensor information.
727
+ * @param src Pointer to the source data to be transformed.
728
+ * @param dst Pointer to the destination buffer where transformed data will be
729
+ * stored.
730
+ */
731
+ static void ggml_backend_cann_transform(ggml_tensor* tensor,
732
+ const void* src, void* dst) {
733
+ switch (tensor->type) {
734
+ case GGML_TYPE_Q4_0:
735
+ ggml_backend_cann_transform_q4_0(tensor, src, dst);
736
+ break;
737
+ case GGML_TYPE_Q8_0:
738
+ ggml_backend_cann_transform_q8_0(tensor, src, dst);
739
+ break;
740
+ default:
741
+ break;
742
+ }
743
+ }
744
+
745
+ /**
746
+ * @brief Transform CANN processed data back into tensor data based on its type.
747
+ *
748
+ * This function transforms CANN processed data back into tensor data based on
749
+ * its quantization type for Q4.0 and Q8.0 formats. It dispatches the
750
+ * transformation based on the tensor's type to specialized functions.
751
+ *
752
+ * @param tensor Pointer to the tensor information.
753
+ * @param src Pointer to the source data containing CANN processed data.
754
+ * @param dst Pointer to the destination buffer where transformed tensor data
755
+ * will be stored.
756
+ */
757
+ static void ggml_backend_cann_transform_back(
758
+ const ggml_tensor* tensor, void* src, void* dst) {
759
+ switch (tensor->type) {
760
+ case GGML_TYPE_Q4_0:
761
+ ggml_backend_cann_transform_back_q4_0(tensor, src, dst);
762
+ break;
763
+ case GGML_TYPE_Q8_0:
764
+ ggml_backend_cann_transform_back_q8_0(tensor, src, dst);
765
+ break;
766
+ default:
767
+ break;
768
+ }
769
+ }
770
+
771
+ /**
772
+ * @brief Check if transformation is needed for a given tensor type.
773
+ *
774
+ * This function checks if transformation is needed for a given tensor type
775
+ * to prepare data for CANN processing.
776
+ *
777
+ * @param type The tensor type to check.
778
+ * @return true if transformation is needed, false otherwise.
779
+ */
780
+ static bool need_transform(ggml_type type) {
781
+ switch (type) {
782
+ case GGML_TYPE_Q4_0:
783
+ case GGML_TYPE_Q8_0:
784
+ return true;
785
+ default:
786
+ return false;
787
+ }
788
+ }
789
+
790
+ /**
791
+ * @brief Initialize a tensor using data from a CANN buffer.
792
+ *
793
+ * This function initializes a tensor using data from a CANN buffer.
794
+ * It handles special cases such as views and quantization.
795
+ *
796
+ * @param buffer The CANN buffer from which to initialize the tensor.
797
+ * @param tensor Pointer to the tensor to be initialized.
798
+ */
799
+ static void ggml_backend_cann_buffer_init_tensor(
800
+ ggml_backend_buffer_t buffer, ggml_tensor* tensor) {
801
+ if (tensor->view_src != NULL && tensor->view_offs == 0) {
802
+ GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
803
+ return;
804
+ }
805
+
806
+ // TODO: can backend doesn't support quantized yet. Just leave the code
807
+ // here.
808
+ if (ggml_is_quantized(tensor->type)) {
809
+ // Initialize padding to 0 to avoid possible NaN values
810
+ size_t original_size = ggml_nbytes(tensor);
811
+ size_t padded_size =
812
+ ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
813
+
814
+ if (padded_size > original_size && tensor->view_src == nullptr) {
815
+ size_t memset_size = padded_size - original_size;
816
+ ACL_CHECK(aclrtMemset((char*)tensor->data + original_size,
817
+ memset_size, 0, memset_size));
818
+ }
819
+ }
820
+ }
821
+
822
+ // TODO: need handle tensor which has paddings.
823
+ /**
824
+ * @brief Set tensor data in a CANN buffer.
825
+ *
826
+ * This function sets tensor data in a CANN buffer, handling transformations
827
+ * if needed based on the tensor's type.
828
+ *
829
+ * @param buffer The CANN buffer where the tensor data will be set.
830
+ * @param tensor Pointer to the tensor whose data will be set.
831
+ * @param data Pointer to the source data to be copied into the tensor.
832
+ * @param offset Offset in the source data from where to start copying.
833
+ * @param size Size of the data to be copied, in bytes.
834
+ */
835
+ static void ggml_backend_cann_buffer_set_tensor(
836
+ ggml_backend_buffer_t buffer, ggml_tensor *tensor, const void *data,
837
+ size_t offset, size_t size) {
838
+ ggml_backend_cann_buffer_context *ctx =
839
+ (ggml_backend_cann_buffer_context *)buffer->context;
840
+
841
+ ggml_cann_set_device(ctx->device);
842
+ // TODO: refer to cann(#6017), it use thread's default stream.
843
+ // For acl, synchronous functions use this default stream.
844
+ // Why aclrtSynchronizeDevice?
845
+
846
+ if (!need_transform(tensor->type)) {
847
+ ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size,
848
+ ACL_MEMCPY_HOST_TO_DEVICE));
849
+ } else {
850
+ void *transform_buffer = malloc(size);
851
+ ggml_backend_cann_transform(tensor, data, transform_buffer);
852
+
853
+ ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size,
854
+ transform_buffer, size,
855
+ ACL_MEMCPY_HOST_TO_DEVICE));
856
+ free(transform_buffer);
857
+ }
858
+ }
859
+
860
+ /**
861
+ * @brief Get tensor data from a CANN buffer.
862
+ *
863
+ * This function retrieves tensor data from a CANN buffer, handling
864
+ * transformations if needed based on the tensor's type.
865
+ *
866
+ * @param buffer The CANN buffer from which to retrieve tensor data.
867
+ * @param tensor Pointer to the tensor whose data will be retrieved.
868
+ * @param data Pointer to the destination buffer where the tensor data will be
869
+ * copied.
870
+ * @param offset Offset in the destination buffer where to start copying.
871
+ * @param size Size of the data to be copied, in bytes.
872
+ */
873
+ static void ggml_backend_cann_buffer_get_tensor(
874
+ ggml_backend_buffer_t buffer, const ggml_tensor* tensor, void* data,
875
+ size_t offset, size_t size) {
876
+ ggml_backend_cann_buffer_context* ctx =
877
+ (ggml_backend_cann_buffer_context*)buffer->context;
878
+
879
+ ggml_cann_set_device(ctx->device);
880
+
881
+ if (!need_transform(tensor->type)) {
882
+ ACL_CHECK(aclrtMemcpy(data, size, (char*)tensor->data + offset, size,
883
+ ACL_MEMCPY_DEVICE_TO_HOST));
884
+ } else {
885
+ void* transform_buffer = malloc(size);
886
+ ACL_CHECK(aclrtMemcpy(transform_buffer, size,
887
+ (char*)tensor->data + offset, size,
888
+ ACL_MEMCPY_DEVICE_TO_HOST));
889
+ ggml_backend_cann_transform_back(tensor, transform_buffer, data);
890
+ free(transform_buffer);
891
+ }
892
+ }
893
+
894
+ /**
895
+ * @brief Copy tensor data between CANN buffers if possible.
896
+ *
897
+ * This function copies tensor data between CANN buffers if the source and
898
+ * destination buffers are CANN buffers and they meet the necessary conditions
899
+ * (same device or devices can access each other).
900
+ *
901
+ * @param buffer The destination CANN buffer where the tensor data will be
902
+ * copied.
903
+ * @param src Pointer to the source tensor whose data will be copied.
904
+ * @param dst Pointer to the destination tensor where the data will be copied.
905
+ * @return true if the copy operation succeeded, false otherwise.
906
+ */
907
+ static bool ggml_backend_cann_buffer_cpy_tensor(
908
+ ggml_backend_buffer_t buffer, const ggml_tensor* src, ggml_tensor* dst) {
909
+ if (ggml_backend_buffer_is_cann(src->buffer)) {
910
+ ggml_backend_cann_buffer_context* src_ctx =
911
+ (ggml_backend_cann_buffer_context*)src->buffer->context;
912
+ ggml_backend_cann_buffer_context* dst_ctx =
913
+ (ggml_backend_cann_buffer_context*)buffer->context;
914
+
915
+ size_t memcpy_size = ggml_nbytes(src);
916
+ // Same device.
917
+ if (src_ctx->device == dst_ctx->device) {
918
+ ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
919
+ (const char*)src->data, memcpy_size,
920
+ ACL_MEMCPY_DEVICE_TO_DEVICE));
921
+ return true;
922
+ } else {
923
+ // Different device but can access by peer.
924
+ int32_t canAccessPeer = 0;
925
+ ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device,
926
+ dst_ctx->device));
927
+ if (canAccessPeer) {
928
+ ggml_cann_set_device(src_ctx->device);
929
+ ACL_CHECK(aclrtDeviceEnablePeerAccess(dst_ctx->device, 0));
930
+ ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
931
+ (const char*)src->data, memcpy_size,
932
+ ACL_MEMCPY_DEVICE_TO_DEVICE));
933
+ return true;
934
+ }
935
+ }
936
+ }
937
+ return false;
938
+ }
939
+
940
+ /**
941
+ * @brief Clear a CANN buffer by setting all its memory to a specified value.
942
+ *
943
+ * This function clears a CANN buffer by setting all its memory to a specified
944
+ * value.
945
+ *
946
+ * @param buffer The CANN buffer to be cleared.
947
+ * @param value The value to which each byte in the buffer will be set.
948
+ */
949
+ static void ggml_backend_cann_buffer_clear(
950
+ ggml_backend_buffer_t buffer, uint8_t value) {
951
+ ggml_backend_cann_buffer_context* ctx =
952
+ (ggml_backend_cann_buffer_context*)buffer->context;
953
+
954
+ ggml_cann_set_device(ctx->device);
955
+ ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size));
956
+ }
957
+
958
+ /**
959
+ * @brief Interface for a CANN buffer in the backend.
960
+ *
961
+ * This structure defines function pointers to operations that can be performed
962
+ * on a CANN buffer within the backend.
963
+ */
964
+ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
965
+ /* .free_buffer = */ ggml_backend_cann_buffer_free_buffer,
966
+ /* .get_base = */ ggml_backend_cann_buffer_get_base,
967
+ /* .init_tensor = */ ggml_backend_cann_buffer_init_tensor,
968
+ /* .memset_tensor = */ NULL,
969
+ /* .set_tensor = */ ggml_backend_cann_buffer_set_tensor,
970
+ /* .get_tensor = */ ggml_backend_cann_buffer_get_tensor,
971
+ /* .cpy_tensor = */ ggml_backend_cann_buffer_cpy_tensor,
972
+ /* .clear = */ ggml_backend_cann_buffer_clear,
973
+ /* .reset = */ NULL,
974
+ };
975
+
976
// cann buffer type
/**
 * @brief Per-device context attached to a CANN buffer type.
 */
struct ggml_backend_cann_buffer_type_context {
    /** Device identifier associated with the buffer type. */
    int32_t device;
    /** Name reported for the buffer type. */
    std::string name;
};
986
+
987
+ /**
988
+ * @brief Retrieves the name associated with a CANN buffer type.
989
+ *
990
+ * This function returns the descriptive name associated with the specified
991
+ * CANN buffer type context.
992
+ *
993
+ * @param buft Pointer to the buffer type context.
994
+ * @return Const pointer to the C-style string containing the name.
995
+ */
996
+ static const char* ggml_backend_cann_buffer_type_name(
997
+ ggml_backend_buffer_type_t buft) {
998
+ ggml_backend_cann_buffer_type_context* buft_ctx =
999
+ (ggml_backend_cann_buffer_type_context*)buft->context;
1000
+
1001
+ return buft_ctx->name.c_str();
1002
+ }
1003
+
1004
+ /**
1005
+ * @brief Allocates a new CANN buffer of the specified type and size.
1006
+ *
1007
+ * This function allocates a new CANN buffer on the specified device with the
1008
+ * given size.
1009
+ *
1010
+ * @param buft Pointer to the buffer type context.
1011
+ * @param size Size in bytes of the buffer to allocate.
1012
+ * @return Pointer to the allocated buffer, or nullptr if allocation fails.
1013
+ */
1014
+ static ggml_backend_buffer_t
1015
+ ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
1016
+ size_t size) {
1017
+ ggml_backend_cann_buffer_type_context* buft_ctx =
1018
+ (ggml_backend_cann_buffer_type_context*)buft->context;
1019
+
1020
+ ggml_cann_set_device(buft_ctx->device);
1021
+
1022
+ size = std::max(size, (size_t)1);
1023
+
1024
+ void* dev_ptr;
1025
+ aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
1026
+ if (err != ACL_SUCCESS) {
1027
+ GGML_LOG_ERROR(
1028
+ "%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n",
1029
+ __func__, size / 1024.0 / 1024.0, buft_ctx->device,
1030
+ aclGetRecentErrMsg());
1031
+ return nullptr;
1032
+ }
1033
+
1034
+ ggml_backend_cann_buffer_context* ctx =
1035
+ new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);
1036
+
1037
+ return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface,
1038
+ ctx, size);
1039
+ }
1040
+
1041
+ /**
1042
+ * @brief Retrieves the memory alignment requirement for CANN buffers of this
1043
+ * type.
1044
+ *
1045
+ * This function returns the alignment requirement in bytes for memory allocated
1046
+ * by the CANN buffer type.
1047
+ *
1048
+ * @param buft Pointer to the buffer type context (unused in this
1049
+ * implementation).
1050
+ * @return The alignment requirement in bytes (fixed at 128 bytes for CANN
1051
+ * buffers).
1052
+ */
1053
+ static size_t ggml_backend_cann_buffer_type_get_alignment(
1054
+ ggml_backend_buffer_type_t buft) {
1055
+ return 128;
1056
+
1057
+ GGML_UNUSED(buft);
1058
+ }
1059
+
1060
+ /**
1061
+ * @brief Calculates the allocation size required for a tensor in a CANN buffer.
1062
+ *
1063
+ * Computes the total allocation size needed for storing the tensor's data in a
1064
+ * CANN buffer, considering any necessary padding or adjustments for quantized
1065
+ * types.
1066
+ *
1067
+ * @param buft Pointer to the buffer type context (unused in this
1068
+ * implementation).
1069
+ * @param tensor Pointer to the tensor for which the allocation size is
1070
+ * calculated.
1071
+ * @return The total allocation size in bytes required for the tensor in the
1072
+ * CANN buffer.
1073
+ */
1074
+ static size_t ggml_backend_cann_buffer_type_get_alloc_size(
1075
+ ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
1076
+ size_t size = ggml_nbytes(tensor);
1077
+ int64_t ne0 = tensor->ne[0];
1078
+
1079
+ // last line must bigger than 32, because every single op deal at
1080
+ // least 32 bytes.
1081
+ // TODO: quantized type?
1082
+ // int64_t line_size = ne0 * ggml_element_size(tensor);
1083
+ // int64_t line_size_align_32 = (line_size + 31) & ~31;
1084
+ // size += (line_size_align_32 - line_size);
1085
+
1086
+ // TODO: not support quantized yet.
1087
+ // TODO: consider un-continue tensor.
1088
+ if (ggml_is_quantized(tensor->type)) {
1089
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
1090
+ size += ggml_row_size(
1091
+ tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
1092
+ }
1093
+ }
1094
+
1095
+ return size;
1096
+
1097
+ GGML_UNUSED(buft);
1098
+ }
1099
+
1100
+ static bool ggml_backend_cann_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
1101
+ return false;
1102
+
1103
+ GGML_UNUSED(buft);
1104
+ }
1105
+
1106
+ /**
1107
+ * @brief Interface for managing CANN buffer types in the GGML backend.
1108
+ *
1109
+ * Provides function pointers for allocating, querying properties, and managing
1110
+ * memory for CANN buffer types in the GGML backend.
1111
+ */
1112
+ static const ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
1113
+ /* .get_name = */ ggml_backend_cann_buffer_type_name,
1114
+ /* .alloc_buffer = */ ggml_backend_cann_buffer_type_alloc_buffer,
1115
+ /* .get_alignment = */ ggml_backend_cann_buffer_type_get_alignment,
1116
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
1117
+ /* .get_alloc_size = */ ggml_backend_cann_buffer_type_get_alloc_size,
1118
+ /* .is_host = */ ggml_backend_cann_buffer_type_is_host,
1119
+ };
1120
+
1121
+ /**
1122
+ * @brief Retrieves the CANN buffer type for a specified device.
1123
+ *
1124
+ * This function initializes and returns the buffer type interface associated
1125
+ * with the given device. It ensures thread-safe access using a mutex.
1126
+ *
1127
+ * @param device The device index for which to retrieve the buffer type.
1128
+ * @return A pointer to the buffer type interface for the specified device, or
1129
+ * nullptr if the device index is out of range.
1130
+ */
1131
+ ggml_backend_buffer_type_t
1132
+ ggml_backend_cann_buffer_type(int32_t device) {
1133
+ static std::mutex mutex;
1134
+ std::lock_guard<std::mutex> lock(mutex);
1135
+
1136
+ if (device >= ggml_backend_cann_get_device_count()) {
1137
+ return nullptr;
1138
+ }
1139
+
1140
+ static ggml_backend_buffer_type
1141
+ ggml_backend_cann_buffer_types[GGML_CANN_MAX_DEVICES];
1142
+
1143
+ static bool ggml_backend_cann_buffer_type_initialized = false;
1144
+
1145
+ if (!ggml_backend_cann_buffer_type_initialized) {
1146
+ for (int32_t i = 0; i < ggml_cann_info().device_count; i++) {
1147
+ ggml_backend_cann_buffer_types[i] = {
1148
+ /* .iface = */ ggml_backend_cann_buffer_type_interface,
1149
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), i),
1150
+ /* .context = */
1151
+ new ggml_backend_cann_buffer_type_context{
1152
+ i, "CANN" + std::to_string(i)},
1153
+ };
1154
+ }
1155
+ ggml_backend_cann_buffer_type_initialized = true;
1156
+ }
1157
+
1158
+ return &ggml_backend_cann_buffer_types[device];
1159
+ }
1160
+
1161
+ /**
1162
+ * @brief Retrieves the name associated with a CANN host buffer type.
1163
+ *
1164
+ * This function returns the descriptive name associated with the specified
1165
+ * CANN host buffer type context.
1166
+ *
1167
+ * @param buft Pointer to the host buffer type context.
1168
+ * @return Const pointer to the C-style string containing the name.
1169
+ */
1170
+ static const char * ggml_backend_cann_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
1171
+ return "CANN_Host";
1172
+
1173
+ GGML_UNUSED(buft);
1174
+ }
1175
+
1176
+ /**
1177
+ * @brief Retrieves the name associated with a CANN host buffer.
1178
+ *
1179
+ * This function returns the descriptive name associated with the specified
1180
+ * CANN host buffer context.
1181
+ *
1182
+ * @param buft Pointer to the host buffer context.
1183
+ * @return Const pointer to the C-style string containing the name.
1184
+ */
1185
+ static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buffer) {
1186
+ return "CANN_Host";
1187
+
1188
+ GGML_UNUSED(buffer);
1189
+ }
1190
+
1191
+ /**
1192
+ * @brief Free resources associated with a CANN host buffer.
1193
+ *
1194
+ * This function frees the resources associated with a CANN host buffer, including
1195
+ * its context.
1196
+ *
1197
+ * @param buffer The CANN host buffer to free.
1198
+ */
1199
+ static void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) {
1200
+ ACL_CHECK(aclrtFreeHost(buffer->context));
1201
+ }
1202
+
1203
+ /**
1204
+ * @brief Allocates a new CANN host buffer of the specified size.
1205
+ *
1206
+ * This function allocates a new CANN host buffer with the given size.
1207
+ * @param size Size in bytes of the host buffer to allocate.
1208
+ * @return Pointer to the allocated host buffer, or nullptr if allocation fails.
1209
+ */
1210
+ static void * ggml_cann_host_malloc(size_t size) {
1211
+ if (getenv("GGML_CANN_NO_PINNED") != nullptr) {
1212
+ return nullptr;
1213
+ }
1214
+
1215
+ const size_t alignment = 128;
1216
+ size = GGML_PAD(size, alignment);
1217
+ if (size == 0) {
1218
+ size = alignment;
1219
+ }
1220
+
1221
+ void * hostPtr = nullptr;
1222
+ aclError err = aclrtMallocHost((void **) &hostPtr, size);
1223
+ if (err != ACL_SUCCESS) {
1224
+ GGML_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
1225
+ size / 1024.0 / 1024.0, aclGetRecentErrMsg());
1226
+ return nullptr;
1227
+ }
1228
+ return hostPtr;
1229
+ }
1230
+
1231
+ /**
1232
+ * @brief Allocates a new CANN host buffer of the specified type and size.
1233
+ *
1234
+ * @param buft Pointer to the host buffer type context.
1235
+ * @param size Size in bytes of the host buffer to allocate.
1236
+ * @return Pointer to the allocated host buffer, or CPU buffer pointer if allocation fails.
1237
+ */
1238
+ static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
1239
+ void * hostPtr = ggml_cann_host_malloc(size);
1240
+
1241
+ if (hostPtr == nullptr) {
1242
+ // fallback to cpu buffer
1243
+ return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
1244
+ }
1245
+
1246
+ ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(hostPtr, size);
1247
+ buffer->buft = buft;
1248
+ buffer->iface.free_buffer = ggml_backend_cann_host_buffer_free;
1249
+
1250
+ return buffer;
1251
+ }
1252
+
1253
+ /**
1254
+ * @brief Interface for managing CANN host buffer types in the GGML backend.
1255
+ *
1256
+ * Provides function pointers for allocating, querying properties, and managing
1257
+ * memory for CANN buffer types in the GGML backend.
1258
+ */
1259
+ ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
1260
+ static struct ggml_backend_buffer_type ggml_backend_cann_buffer_type_host = {
1261
+ /* .iface = */ {
1262
+ /* .get_name = */ ggml_backend_cann_host_buffer_type_name,
1263
+ /* .alloc_buffer = */ ggml_backend_cann_host_buffer_type_alloc_buffer,
1264
+ /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
1265
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
1266
+ /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
1267
+ /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
1268
+ },
1269
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
1270
+ /* .context = */ nullptr,
1271
+ };
1272
+
1273
+ return &ggml_backend_cann_buffer_type_host;
1274
+ }
1275
+
1276
+ /**
1277
+ * @brief Computes the forward operation for a given tensor using CANN
1278
+ * operations.
1279
+ *
1280
+ * This function selects the appropriate CANN operation based on the type of
1281
+ * operation specified in the tensor and performs the computation.
1282
+ *
1283
+ * @param ctx The CANN context containing necessary resources and
1284
+ * configurations.
1285
+ * @param dst The destination tensor where the result of the computation will be
1286
+ * stored.
1287
+ * @return true if the computation was successful; false otherwise.
1288
+ */
1289
+ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1290
+ struct ggml_tensor* dst) {
1291
+ switch (dst->op) {
1292
+ case GGML_OP_REPEAT:
1293
+ ggml_cann_repeat(ctx, dst);
1294
+ break;
1295
+ case GGML_OP_GET_ROWS:
1296
+ ggml_cann_get_rows(ctx, dst);
1297
+ break;
1298
+ case GGML_OP_DUP:
1299
+ ggml_cann_dup(ctx, dst);
1300
+ break;
1301
+ case GGML_OP_ADD:
1302
+ ggml_cann_add(ctx, dst);
1303
+ break;
1304
+ case GGML_OP_ACC:
1305
+ ggml_cann_acc(ctx, dst);
1306
+ break;
1307
+ case GGML_OP_MUL:
1308
+ ggml_cann_mul_div<aclnnMulGetWorkspaceSize, aclnnMul>(ctx, dst);
1309
+ break;
1310
+ case GGML_OP_DIV:
1311
+ ggml_cann_mul_div<aclnnDivGetWorkspaceSize, aclnnDiv>(ctx, dst);
1312
+ break;
1313
+ case GGML_OP_UNARY:
1314
+ switch (ggml_get_unary_op(dst)) {
1315
+ case GGML_UNARY_OP_GELU:
1316
+ ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
1317
+ ctx, dst);
1318
+ break;
1319
+ case GGML_UNARY_OP_SILU:
1320
+ ggml_cann_activation<aclnnSiluGetWorkspaceSize, aclnnSilu>(
1321
+ ctx, dst);
1322
+ break;
1323
+ // TODO: Use faster gelu??
1324
+ case GGML_UNARY_OP_GELU_QUICK:
1325
+ ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
1326
+ ctx, dst);
1327
+ break;
1328
+ case GGML_UNARY_OP_TANH:
1329
+ ggml_cann_activation<aclnnTanhGetWorkspaceSize, aclnnTanh>(
1330
+ ctx, dst);
1331
+ break;
1332
+ case GGML_UNARY_OP_RELU:
1333
+ ggml_cann_activation<aclnnReluGetWorkspaceSize, aclnnRelu>(
1334
+ ctx, dst);
1335
+ break;
1336
+ case GGML_UNARY_OP_HARDSIGMOID:
1337
+ ggml_cann_activation<aclnnHardsigmoidGetWorkspaceSize,
1338
+ aclnnHardsigmoid>(ctx, dst);
1339
+ break;
1340
+ case GGML_UNARY_OP_HARDSWISH:
1341
+ ggml_cann_activation<aclnnHardswishGetWorkspaceSize,
1342
+ aclnnHardswish>(ctx, dst);
1343
+ break;
1344
+ default:
1345
+ return false;
1346
+ }
1347
+ break;
1348
+ case GGML_OP_NORM:
1349
+ ggml_cann_norm(ctx, dst);
1350
+ break;
1351
+ case GGML_OP_GROUP_NORM:
1352
+ ggml_cann_group_norm(ctx, dst);
1353
+ break;
1354
+ case GGML_OP_CONCAT:
1355
+ ggml_cann_concat(ctx, dst);
1356
+ break;
1357
+ case GGML_OP_UPSCALE:
1358
+ ggml_cann_upsample_nearest2d(ctx, dst);
1359
+ break;
1360
+ case GGML_OP_PAD:
1361
+ ggml_cann_pad(ctx, dst);
1362
+ break;
1363
+ case GGML_OP_ARANGE:
1364
+ ggml_cann_arange(ctx, dst);
1365
+ break;
1366
+ case GGML_OP_TIMESTEP_EMBEDDING:
1367
+ ggml_cann_timestep_embedding(ctx, dst);
1368
+ break;
1369
+ case GGML_OP_LEAKY_RELU:
1370
+ ggml_cann_leaky_relu(ctx, dst);
1371
+ break;
1372
+ case GGML_OP_RMS_NORM:
1373
+ ggml_cann_rms_norm(ctx, dst);
1374
+ break;
1375
+ case GGML_OP_MUL_MAT:
1376
+ ggml_cann_mul_mat(ctx, dst);
1377
+ break;
1378
+ case GGML_OP_MUL_MAT_ID:
1379
+ return false;
1380
+ case GGML_OP_SCALE:
1381
+ ggml_cann_scale(ctx, dst);
1382
+ break;
1383
+ case GGML_OP_SQR:
1384
+ ggml_cann_sqr(ctx, dst);
1385
+ break;
1386
+ case GGML_OP_CLAMP:
1387
+ ggml_cann_clamp(ctx, dst);
1388
+ break;
1389
+ case GGML_OP_CPY:
1390
+ ggml_cann_cpy(ctx, dst);
1391
+ break;
1392
+ case GGML_OP_CONT:
1393
+ ggml_cann_dup(ctx, dst);
1394
+ break;
1395
+ case GGML_OP_NONE:
1396
+ case GGML_OP_RESHAPE:
1397
+ case GGML_OP_VIEW:
1398
+ case GGML_OP_PERMUTE:
1399
+ case GGML_OP_TRANSPOSE:
1400
+ break;
1401
+ case GGML_OP_DIAG_MASK_INF:
1402
+ ggml_cann_diag_mask(ctx, dst, -INFINITY);
1403
+ break;
1404
+ case GGML_OP_SOFT_MAX:
1405
+ ggml_cann_softmax(ctx, dst);
1406
+ break;
1407
+ case GGML_OP_ROPE:
1408
+ ggml_cann_rope(ctx, dst);
1409
+ break;
1410
+ case GGML_OP_IM2COL:
1411
+ ggml_cann_im2col(ctx, dst);
1412
+ break;
1413
+ case GGML_OP_POOL_2D:
1414
+ ggml_cann_pool2d(ctx, dst);
1415
+ break;
1416
+ case GGML_OP_SUM_ROWS:
1417
+ ggml_cann_sum_rows(ctx, dst);
1418
+ break;
1419
+ case GGML_OP_ARGSORT:
1420
+ ggml_cann_argsort(ctx, dst);
1421
+ break;
1422
+ default:
1423
+ return false;
1424
+ }
1425
+
1426
+ return true;
1427
+ }
1428
+
1429
+ // backend
1430
+ /**
1431
+ * @brief Retrieves the name associated with the CANN backend.
1432
+ *
1433
+ * This function returns the name assigned to the CANN backend, which is stored
1434
+ * in the context of the provided backend structure.
1435
+ *
1436
+ * @param backend Pointer to the CANN backend structure.
1437
+ * @return A pointer to a constant string representing the backend name.
1438
+ */
1439
+ static const char* ggml_backend_cann_name(ggml_backend_t backend) {
1440
+ ggml_backend_cann_context* cann_ctx =
1441
+ (ggml_backend_cann_context*)backend->context;
1442
+
1443
+ return cann_ctx->name.c_str();
1444
+ }
1445
+
1446
+ /**
1447
+ * @brief Frees resources associated with the CANN backend.
1448
+ *
1449
+ * This function releases resources associated with the CANN backend context
1450
+ * and resets the device associated with the backend to its initial state.
1451
+ *
1452
+ * @param backend Pointer to the CANN backend structure to be freed.
1453
+ */
1454
+ static void ggml_backend_cann_free(ggml_backend_t backend) {
1455
+ ggml_backend_cann_context* cann_ctx =
1456
+ (ggml_backend_cann_context*)backend->context;
1457
+ ACL_CHECK(aclrtSynchronizeDevice());
1458
+ ACL_CHECK(aclrtResetDevice(cann_ctx->device));
1459
+
1460
+ // finalize when last backend freed.
1461
+ if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) {
1462
+ ACL_CHECK(aclFinalize());
1463
+ }
1464
+
1465
+ delete cann_ctx;
1466
+ delete backend;
1467
+ }
1468
+
1469
+ /**
1470
+ * @brief Sets tensor data asynchronously in the CANN backend.
1471
+ *
1472
+ * This function asynchronously sets tensor data in the CANN backend. Depending
1473
+ * on the tensor type, it may perform data transformations before copying data
1474
+ * to the device.
1475
+ *
1476
+ * @param backend Pointer to the CANN backend structure.
1477
+ * @param tensor Pointer to the tensor structure to set data for.
1478
+ * @param data Pointer to the host data to copy to the tensor.
1479
+ * @param offset Offset in bytes within the host data.
1480
+ * @param size Size of the data to copy in bytes.
1481
+ */
1482
+ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
1483
+ ggml_tensor *tensor,
1484
+ const void *data,
1485
+ size_t offset,
1486
+ size_t size) {
1487
+ ggml_backend_cann_context *cann_ctx =
1488
+ (ggml_backend_cann_context *)backend->context;
1489
+
1490
+ if (!need_transform(tensor->type)) {
1491
+ ACL_CHECK(aclrtMemcpyAsync((char *)tensor->data + offset, size, data,
1492
+ size, ACL_MEMCPY_HOST_TO_DEVICE,
1493
+ cann_ctx->stream()));
1494
+ } else {
1495
+ void *transform_buffer = malloc(size);
1496
+ ggml_backend_cann_transform(tensor, data, transform_buffer);
1497
+
1498
+ ACL_CHECK(aclrtMemcpyAsync(
1499
+ (char *)tensor->data + offset, size, transform_buffer, size,
1500
+ ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream()));
1501
+ ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
1502
+ free(transform_buffer);
1503
+ }
1504
+ }
1505
+
1506
+ static void ggml_backend_cann_get_tensor_async(
1507
+ ggml_backend_t backend, const ggml_tensor *tensor, void *data,
1508
+ size_t offset, size_t size) {
1509
+ ggml_backend_cann_context *cann_ctx =
1510
+ (ggml_backend_cann_context *)backend->context;
1511
+ ggml_backend_buffer_t buf =
1512
+ tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
1513
+
1514
+ GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
1515
+ "unsupported buffer type");
1516
+
1517
+ if (!need_transform(tensor->type)) {
1518
+ ACL_CHECK(aclrtMemcpyAsync(data, size, (char *)tensor->data + offset,
1519
+ size, ACL_MEMCPY_DEVICE_TO_HOST,
1520
+ cann_ctx->stream()));
1521
+ } else {
1522
+ void *transform_buffer = malloc(size);
1523
+ ACL_CHECK(aclrtMemcpyAsync(
1524
+ transform_buffer, size, (char *)tensor->data + offset, size,
1525
+ ACL_MEMCPY_DEVICE_TO_HOST, cann_ctx->stream()));
1526
+ ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
1527
+ ggml_backend_cann_transform_back(tensor, transform_buffer, data);
1528
+ free(transform_buffer);
1529
+ }
1530
+ }
1531
+
1532
+ /**
1533
+ * @brief Asynchronously copies tensor data between CANN backends.
1534
+ *
1535
+ * This function copies tensor data asynchronously between two CANN backends. It
1536
+ * checks if both tensors reside in CANN buffers and whether the devices support
1537
+ * peer-to-peer access for direct copying. If not, it returns false.
1538
+ *
1539
+ * @param backend_src Pointer to the source CANN backend structure.
1540
+ * @param backend_dst Pointer to the destination CANN backend structure.
1541
+ * @param src Pointer to the source tensor to copy data from.
1542
+ * @param dst Pointer to the destination tensor to copy data to.
1543
+ * @return true if the copy operation succeeds, false otherwise.
1544
+ */
1545
+ static bool ggml_backend_cann_cpy_tensor_async(
1546
+ ggml_backend_t backend_src, ggml_backend_t backend_dst,
1547
+ const ggml_tensor* src, ggml_tensor* dst) {
1548
+ GGML_ASSERT(ggml_backend_is_cann(backend_src) ||
1549
+ ggml_backend_is_cann(backend_dst));
1550
+
1551
+ if (!ggml_backend_buffer_is_cann(src->buffer) ||
1552
+ !ggml_backend_buffer_is_cann(dst->buffer)) {
1553
+ return false;
1554
+ }
1555
+
1556
+ ggml_backend_buffer_t buf_src =
1557
+ src->view_src ? src->view_src->buffer : src->buffer;
1558
+ ggml_backend_buffer_t buf_dst =
1559
+ dst->view_src ? dst->view_src->buffer : dst->buffer;
1560
+
1561
+ ggml_backend_cann_context* cann_ctx_src =
1562
+ (ggml_backend_cann_context*)backend_src->context;
1563
+ ggml_backend_cann_context* cann_ctx_dst =
1564
+ (ggml_backend_cann_context*)backend_dst->context;
1565
+
1566
+ size_t copy_size = ggml_nbytes(dst);
1567
+ if (backend_src != backend_dst) {
1568
+ ggml_backend_cann_buffer_context* buf_ctx_src =
1569
+ (ggml_backend_cann_buffer_context*)buf_src->context;
1570
+ ggml_backend_cann_buffer_context* buf_ctx_dst =
1571
+ (ggml_backend_cann_buffer_context*)buf_dst->context;
1572
+
1573
+ GGML_ASSERT(cann_ctx_src->device == buf_ctx_src->device);
1574
+ GGML_ASSERT(cann_ctx_dst->device == buf_ctx_dst->device);
1575
+
1576
+ int32_t canAccessPeer = 0;
1577
+ ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, cann_ctx_src->device,
1578
+ cann_ctx_dst->device));
1579
+ if (!canAccessPeer) {
1580
+ return false;
1581
+ }
1582
+
1583
+ // need open both directions for memcpyasync between devices.
1584
+ ggml_cann_set_device(cann_ctx_dst->device);
1585
+ ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0));
1586
+ ggml_cann_set_device(cann_ctx_src->device);
1587
+ ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
1588
+
1589
+ ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
1590
+ ACL_MEMCPY_DEVICE_TO_DEVICE,
1591
+ cann_ctx_src->stream()));
1592
+
1593
+ //TODO: workaround for Event didn`t work here.
1594
+ aclrtSynchronizeStream(cann_ctx_src->stream());
1595
+ } else {
1596
+ // src and dst are on the same backend
1597
+ ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
1598
+ ACL_MEMCPY_DEVICE_TO_DEVICE,
1599
+ cann_ctx_dst->stream()));
1600
+ }
1601
+
1602
+ return true;
1603
+ }
1604
+
1605
+ /**
1606
+ * @brief Synchronizes a CANN backend.
1607
+ *
1608
+ * This function synchronizes the specified CANN backend by waiting for all
1609
+ * operations in its associated stream to complete.
1610
+ *
1611
+ * @param backend Pointer to the CANN backend structure to synchronize.
1612
+ */
1613
+ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
1614
+ ggml_backend_cann_context* cann_ctx =
1615
+ (ggml_backend_cann_context*)backend->context;
1616
+
1617
+ ggml_cann_set_device(cann_ctx->device);
1618
+
1619
+ ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
1620
+ }
1621
+
1622
+ /**
1623
+ * @brief Computes a computational graph using a CANN backend.
1624
+ *
1625
+ * This function computes the operations defined in the computational graph
1626
+ * using the specified CANN backend.
1627
+ *
1628
+ * @param backend Pointer to the CANN backend structure to use for computation.
1629
+ * @param cgraph Pointer to the computational graph structure containing nodes
1630
+ * representing operations to be computed.
1631
+ * @return enum ggml_status Returns GGML_STATUS_SUCCESS if computation
1632
+ * completes successfully, otherwise an appropriate error status.
1633
+ */
1634
+ static enum ggml_status ggml_backend_cann_graph_compute(
1635
+ ggml_backend_t backend, ggml_cgraph* cgraph) {
1636
+ ggml_backend_cann_context* cann_ctx =
1637
+ (ggml_backend_cann_context*)backend->context;
1638
+
1639
+ ggml_cann_set_device(cann_ctx->device);
1640
+
1641
+ for (int i = 0; i < cgraph->n_nodes; i++) {
1642
+ ggml_tensor* node = cgraph->nodes[i];
1643
+
1644
+ if (ggml_is_empty(node) || node->op == GGML_OP_NONE) {
1645
+ continue;
1646
+ }
1647
+
1648
+ bool ok = ggml_cann_compute_forward(*cann_ctx, node);
1649
+
1650
+ if (!ok) {
1651
+ GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__,
1652
+ node->name, ggml_op_name(node->op));
1653
+ }
1654
+ GGML_ASSERT(ok);
1655
+ }
1656
+
1657
+ return GGML_STATUS_SUCCESS;
1658
+ }
1659
+
1660
+ /**
1661
+ * @brief Checks if the CANN backend supports a specific operation.
1662
+ *
1663
+ * This function checks whether the specified operation is supported by the
1664
+ * CANN backend.
1665
+ *
1666
+ * @param backend Pointer to the CANN backend structure to check support for
1667
+ * the operation.
1668
+ * @param op Pointer to the tensor representing the operation to check.
1669
+ * @return bool Returns true if the operation is supported by the backend,
1670
+ * otherwise false.
1671
+ */
1672
+ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
1673
+ const ggml_tensor* op) {
1674
+ switch (op->op) {
1675
+ case GGML_OP_UNARY:
1676
+ switch (ggml_get_unary_op(op)) {
1677
+ case GGML_UNARY_OP_GELU:
1678
+ case GGML_UNARY_OP_SILU:
1679
+ case GGML_UNARY_OP_RELU:
1680
+ case GGML_UNARY_OP_HARDSIGMOID:
1681
+ case GGML_UNARY_OP_HARDSWISH:
1682
+ case GGML_UNARY_OP_GELU_QUICK:
1683
+ case GGML_UNARY_OP_TANH:
1684
+ return true;
1685
+ default:
1686
+ return false;
1687
+ }
1688
+ case GGML_OP_MUL_MAT: {
1689
+ switch (op->src[0]->type) {
1690
+ case GGML_TYPE_Q8_0:
1691
+ // Current groupsize should not be greater than k-1 in
1692
+ // aclnnWeightQuantBatchMatmulV2GetWorkspaceSize
1693
+ if (op->src[0]->ne[0] <= QK8_0) {
1694
+ return false;
1695
+ }
1696
+ case GGML_TYPE_F16:
1697
+ case GGML_TYPE_F32:
1698
+ case GGML_TYPE_Q4_0:
1699
+ return true;
1700
+ default:
1701
+ return false;
1702
+ }
1703
+ }
1704
+ case GGML_OP_MUL_MAT_ID:
1705
+ return false;
1706
+ // embedding
1707
+ case GGML_OP_GET_ROWS: {
1708
+ switch (op->src[0]->type) {
1709
+ case GGML_TYPE_F32:
1710
+ case GGML_TYPE_F16:
1711
+ case GGML_TYPE_Q4_0:
1712
+ case GGML_TYPE_Q8_0:
1713
+ return true;
1714
+ default:
1715
+ return false;
1716
+ }
1717
+ } break;
1718
+ case GGML_OP_CPY: {
1719
+ switch (op->type) {
1720
+ case GGML_TYPE_F32:
1721
+ case GGML_TYPE_F16:
1722
+ case GGML_TYPE_Q8_0:
1723
+ case GGML_TYPE_Q4_0:
1724
+ return true;
1725
+ default:
1726
+ return false;
1727
+ }
1728
+ }
1729
+ case GGML_OP_CONT: {
1730
+ // TODO: support GGML_TYPE_BF16
1731
+ switch (op->src[0]->type) {
1732
+ case GGML_TYPE_F32:
1733
+ case GGML_TYPE_F16:
1734
+ return true;
1735
+ default:
1736
+ return false;
1737
+ }
1738
+ }
1739
+ case GGML_OP_ROPE: {
1740
+ // TODO: with ops-test v == 1
1741
+ float * ext_factor = (float*)((int32_t*)op->op_params + 7);
1742
+ // TODO: n_dims <= ne0
1743
+ if (op->src[0]->ne[0] != op->op_params[1]) {
1744
+ return false;
1745
+ }
1746
+ // TODO: ext_factor != 0
1747
+ if (*ext_factor != 0) {
1748
+ return false;
1749
+ }
1750
+
1751
+ const int mode = ((const int32_t *) op->op_params)[2];
1752
+ if (mode & GGML_ROPE_TYPE_MROPE) {
1753
+ return false;
1754
+ }
1755
+ if (mode & GGML_ROPE_TYPE_VISION) {
1756
+ return false;
1757
+ }
1758
+
1759
+ return true;
1760
+ }
1761
+ case GGML_OP_UPSCALE: {
1762
+ // aclnnUpsampleNearest2dGetWorkspaceSize not support
1763
+ // selfDimN[2]/outDimN[2] or selfDimC[3]/outDimC[3] not equal
1764
+ if (op->src[0]->ne[2] * op->ne[3] != op->src[0]->ne[3] * op->ne[2]) {
1765
+ return false;
1766
+ }
1767
+ return true;
1768
+ }
1769
+ case GGML_OP_IM2COL:
1770
+ case GGML_OP_CONCAT:
1771
+ case GGML_OP_DUP:
1772
+ case GGML_OP_REPEAT:
1773
+ case GGML_OP_NONE:
1774
+ case GGML_OP_RESHAPE:
1775
+ case GGML_OP_VIEW:
1776
+ case GGML_OP_PERMUTE:
1777
+ case GGML_OP_TRANSPOSE:
1778
+ case GGML_OP_NORM:
1779
+ case GGML_OP_ADD:
1780
+ case GGML_OP_MUL:
1781
+ case GGML_OP_DIV:
1782
+ case GGML_OP_RMS_NORM:
1783
+ case GGML_OP_SCALE:
1784
+ case GGML_OP_SQR:
1785
+ case GGML_OP_CLAMP:
1786
+ case GGML_OP_DIAG_MASK_INF:
1787
+ case GGML_OP_SOFT_MAX:
1788
+ case GGML_OP_POOL_2D:
1789
+ case GGML_OP_SUM_ROWS:
1790
+ case GGML_OP_ARGSORT:
1791
+ case GGML_OP_ACC:
1792
+ case GGML_OP_GROUP_NORM:
1793
+ case GGML_OP_PAD:
1794
+ case GGML_OP_ARANGE:
1795
+ case GGML_OP_TIMESTEP_EMBEDDING:
1796
+ case GGML_OP_LEAKY_RELU:
1797
+ return true;
1798
+ default:
1799
+ return false;
1800
+ }
1801
+
1802
+ GGML_UNUSED(dev);
1803
+ }
1804
+
1805
+ /**
1806
+ * @brief Checks if the backend buffer type is associated with the CANN backend.
1807
+ *
1808
+ * This function checks whether the provided backend buffer type is associated
1809
+ * with the CANN backend based on the comparison of its name retrieval function
1810
+ * pointer.
1811
+ *
1812
+ * @param buft Pointer to the backend buffer type to check.
1813
+ * @return bool Returns true if the buffer type is associated with the CANN
1814
+ * backend, otherwise false.
1815
+ */
1816
+ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
1817
+ return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
1818
+ }
1819
+
1820
+ /**
1821
+ * @brief Determines if a tensor operation should be offloaded to the CANN
1822
+ * backend.
1823
+ *
1824
+ * This function checks if a given tensor operation should be offloaded to the
1825
+ * CANN backend based on the operation type and the size of the tensor. It
1826
+ * returns true if the second dimension (ne[1]) of the tensor is greater than or
1827
+ * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.
1828
+ *
1829
+ * @param backend Pointer to the CANN backend.
1830
+ * @param op Pointer to the tensor operation to check.
1831
+ * @return bool Returns true if the operation should be offloaded, otherwise
1832
+ * false.
1833
+ */
1834
+ static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev,
1835
+ const ggml_tensor* op) {
1836
+ const int min_batch_size = 32;
1837
+ GGML_UNUSED(dev);
1838
+
1839
+ return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
1840
+ }
1841
+
1842
+ /**
1843
+ * @brief Records an event on the CANN backend stream.
1844
+ *
1845
+ * This function records the given event on the ACL runtime stream associated
1846
+ * with the backend context.
1847
+ *
1848
+ * @param event Pointer to the event structure to be recorded.
1849
+ */
1850
+ static void ggml_backend_cann_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
1851
+ ggml_backend_cann_context* cann_ctx =
1852
+ (ggml_backend_cann_context*)backend->context;
1853
+ ACL_CHECK(aclrtRecordEvent((aclrtEvent)event->context, cann_ctx->stream()));
1854
+ }
1855
+
1856
+ /**
1857
+ * @brief Waits for a recorded event to complete on the CANN backend stream.
1858
+ *
1859
+ * This function makes the given backend wait for the event to complete on its
1860
+ * ACL runtime stream.
1861
+ *
1862
+ * @param backend Pointer to the backend structure.
1863
+ * @param event Pointer to the event structure that the backend needs to wait
1864
+ * for.
1865
+ */
1866
+ static void ggml_backend_cann_event_wait(ggml_backend_t backend,
1867
+ ggml_backend_event_t event) {
1868
+ ggml_backend_cann_context* cann_ctx =
1869
+ (ggml_backend_cann_context*)backend->context;
1870
+ if (ggml_backend_is_cann(backend)) {
1871
+ ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(),
1872
+ (aclrtEvent)event->context));
1873
+ } else {
1874
+ GGML_ABORT("fatal error");
1875
+ }
1876
+ }
1877
+
1878
+ /**
1879
+ * @brief Structure defining the interface for the CANN backend.
1880
+ *
1881
+ * This structure contains function pointers for various operations
1882
+ * supported by the CANN backend, including name retrieval, memory
1883
+ * management, tensor operations, synchronization, and event handling.
1884
+ */
1885
+ static const ggml_backend_i ggml_backend_cann_interface = {
1886
+ /* .get_name = */ ggml_backend_cann_name,
1887
+ /* .free = */ ggml_backend_cann_free,
1888
+ /* .set_tensor_async = */ ggml_backend_cann_set_tensor_async,
1889
+ /* .get_tensor_async = */ ggml_backend_cann_get_tensor_async,
1890
+ /* .cpy_tensor_async = */ ggml_backend_cann_cpy_tensor_async,
1891
+ /* .synchronize = */ ggml_backend_cann_synchronize,
1892
+ /* .graph_plan_create = */ NULL,
1893
+ /* .graph_plan_free = */ NULL,
1894
+ /* .graph_plan_update = */ NULL,
1895
+ /* .graph_plan_compute = */ NULL,
1896
+ /* .graph_compute = */ ggml_backend_cann_graph_compute,
1897
+ /* .event_record = */ ggml_backend_cann_event_record,
1898
+ /* .event_wait = */ ggml_backend_cann_event_wait,
1899
+ };
1900
+
1901
+ /**
1902
+ * @brief Return the hardcoded GUID for the CANN backend.
1903
+ *
1904
+ * This function returns a static GUID which uniquely identifies the CANN
1905
+ * backend.
1906
+ *
1907
+ * @return A pointer to the static GUID.
1908
+ */
1909
+ static ggml_guid_t ggml_backend_cann_guid() {
1910
+ static ggml_guid guid = {0xa1, 0x94, 0xaf, 0xac, 0xbd, 0x4f, 0x47, 0x34,
1911
+ 0xbe, 0x1a, 0x9e, 0x71, 0x1f, 0x9e, 0xed, 0x64};
1912
+ return &guid;
1913
+ }
1914
+
1915
// backend device

/**
 * @brief Per-device state attached to a CANN ggml_backend_device.
 */
struct ggml_backend_cann_device_context {
    int         device;       // device index
    std::string name;         // returned by the device get_name callback
    std::string description;  // returned by the device get_description callback
};
1921
+
1922
+ static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) {
1923
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
1924
+ return ctx->name.c_str();
1925
+ }
1926
+
1927
+ static const char* ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) {
1928
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
1929
+ return ctx->description.c_str();
1930
+ }
1931
+
1932
+ static void ggml_backend_cann_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
1933
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
1934
+ ggml_backend_cann_get_device_memory(ctx->device, free, total);
1935
+ }
1936
+
1937
+ static enum ggml_backend_dev_type ggml_backend_cann_device_get_type(ggml_backend_dev_t dev) {
1938
+ GGML_UNUSED(dev);
1939
+ return GGML_BACKEND_DEVICE_TYPE_GPU;
1940
+ }
1941
+
1942
+ static void ggml_backend_cann_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
1943
+ props->name = ggml_backend_cann_device_get_name(dev);
1944
+ props->description = ggml_backend_cann_device_get_description(dev);
1945
+ props->type = ggml_backend_cann_device_get_type(dev);
1946
+ ggml_backend_cann_device_get_memory(dev, &props->memory_free, &props->memory_total);
1947
+
1948
+ bool host_buffer = getenv("GGML_CANN_NO_PINNED") == nullptr;
1949
+
1950
+ props->caps = {
1951
+ /* .async = */ false,
1952
+ /* .host_buffer = */ host_buffer,
1953
+ /* .buffer_from_host_ptr = */ false,
1954
+ /* .events = */ true,
1955
+ };
1956
+ }
1957
+
1958
+ static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, const char * params) {
1959
+ GGML_UNUSED(params);
1960
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
1961
+ return ggml_backend_cann_init(ctx->device);
1962
+ }
1963
+
1964
+ /**
1965
+ * @brief Checks if the CANN backend supports a specific backend buffer type.
1966
+ *
1967
+ * This function determines whether the CANN backend supports the given backend
1968
+ * buffer type by comparing the device context of the backend and buffer type.
1969
+ * It returns true if the devices are same between the backend context and
1970
+ * buffer type context.
1971
+ *
1972
+ * @param backend Pointer to the CANN backend.
1973
+ * @param buft Pointer to the backend buffer type to check.
1974
+ * @return bool Returns true if the CANN backend supports the buffer type,
1975
+ * otherwise false.
1976
+ */
1977
+ static bool ggml_backend_cann_supports_buft(
1978
+ ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
1979
+ if (ggml_backend_buft_is_cann(buft)) {
1980
+ ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
1981
+ ggml_backend_cann_buffer_type_context * buft_ctx =
1982
+ (ggml_backend_cann_buffer_type_context *)buft->context;
1983
+ return buft_ctx->device == dev_ctx->device;
1984
+ }
1985
+ return false;
1986
+ }
1987
+
1988
+ static ggml_backend_buffer_type_t ggml_backend_cann_device_get_buffer_type(ggml_backend_dev_t dev) {
1989
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
1990
+ return ggml_backend_cann_buffer_type(ctx->device);
1991
+ }
1992
+
1993
+ static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(ggml_backend_dev_t dev) {
1994
+ GGML_UNUSED(dev);
1995
+ return ggml_backend_cann_host_buffer_type();
1996
+ }
1997
+
1998
+ /**
1999
+ * @brief Creates a new event for the CANN backend device.
2000
+ *
2001
+ * This function initializes a new event for the CANN backend by setting the
2002
+ * device and creating an ACL runtime event. The created event is then wrapped
2003
+ * in a ggml_backend_event structure and returned.
2004
+ *
2005
+ * @param backend Pointer to the CANN backend.
2006
+ * @return ggml_backend_event_t Returns a pointer to the new event structure.
2007
+ */
2008
+ static ggml_backend_event_t ggml_backend_cann_device_event_new(
2009
+ ggml_backend_dev_t dev) {
2010
+ ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
2011
+
2012
+ ggml_cann_set_device(dev_ctx->device);
2013
+
2014
+ aclrtEvent event;
2015
+ ACL_CHECK(aclrtCreateEvent(&event));
2016
+
2017
+ return new ggml_backend_event{
2018
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), dev_ctx->device),
2019
+ /* .context = */ event,
2020
+ };
2021
+ }
2022
+
2023
+ /**
2024
+ * @brief Frees a CANN backend event.
2025
+ *
2026
+ * This function destroys the ACL runtime event associated with the given CANN
2027
+ * backend event and then deletes the event structure itself.
2028
+ *
2029
+ * @param event Pointer to the event structure to be freed.
2030
+ */
2031
+ static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
2032
+ ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
2033
+
2034
+ delete event;
2035
+ GGML_UNUSED(dev);
2036
+ }
2037
+
2038
+ /**
2039
+ * @brief Synchronizes the given event on the CANN backend.
2040
+ *
2041
+ * This function waits for the specified event to complete on the ACL runtime.
2042
+ *
2043
+ * @param event Pointer to the event structure to be synchronized.
2044
+ */
2045
+ static void ggml_backend_cann_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
2046
+ ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
2047
+
2048
+ GGML_UNUSED(dev);
2049
+ }
2050
+
2051
+ static const ggml_backend_device_i ggml_backend_cann_device_interface = {
2052
+ /* .get_name = */ ggml_backend_cann_device_get_name,
2053
+ /* .get_description = */ ggml_backend_cann_device_get_description,
2054
+ /* .get_memory = */ ggml_backend_cann_device_get_memory,
2055
+ /* .get_type = */ ggml_backend_cann_device_get_type,
2056
+ /* .get_props = */ ggml_backend_cann_device_get_props,
2057
+ /* .init_backend = */ ggml_backend_cann_device_init, // called for every card
2058
+ /* .get_buffer_type = */ ggml_backend_cann_device_get_buffer_type,
2059
+ /* .get_host_buffer_type = */ ggml_backend_cann_device_get_host_buffer_type,
2060
+ /* .buffer_from_host_ptr = */ NULL, // not supported for CANN
2061
+ /* .supports_op = */ ggml_backend_cann_supports_op,
2062
+ /* .supports_buft = */ ggml_backend_cann_supports_buft,
2063
+ /* .offload_op = */ ggml_backend_cann_offload_op,
2064
+ /* .event_new = */ ggml_backend_cann_device_event_new,
2065
+ /* .event_free = */ ggml_backend_cann_device_event_free,
2066
+ /* .event_synchronize = */ ggml_backend_cann_device_event_synchronize,
2067
+ };
2068
+
2069
+
2070
+ // backend reg
2071
+ struct ggml_backend_cann_reg_context {
2072
+ std::vector<ggml_backend_dev_t> devices;
2073
+ };
2074
+
2075
+ static const char * ggml_backend_cann_reg_get_name(ggml_backend_reg_t reg) {
2076
+ GGML_UNUSED(reg);
2077
+ return GGML_CANN_NAME;
2078
+ }
2079
+
2080
+ static size_t ggml_backend_cann_reg_get_device_count(ggml_backend_reg_t reg) {
2081
+ ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
2082
+ return ctx->devices.size();
2083
+ }
2084
+
2085
+ static ggml_backend_dev_t ggml_backend_cann_reg_get_device(ggml_backend_reg_t reg, size_t index) {
2086
+ ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
2087
+ GGML_ASSERT(index < ctx->devices.size());
2088
+ return ctx->devices[index];
2089
+ }
2090
+
2091
+ static void * ggml_backend_cann_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
2092
+ GGML_UNUSED(reg);
2093
+ GGML_UNUSED(name);
2094
+ // reserved for future use
2095
+ return nullptr;
2096
+ }
2097
+
2098
+ static const ggml_backend_reg_i ggml_backend_cann_reg_interface = {
2099
+ /* .get_name = */ ggml_backend_cann_reg_get_name,
2100
+ /* .get_device_count = */ ggml_backend_cann_reg_get_device_count,
2101
+ /* .get_device = */ ggml_backend_cann_reg_get_device,
2102
+ /* .get_proc_address = */ ggml_backend_cann_reg_get_proc_address,
2103
+ };
2104
+
2105
+ // backend registry, called only once for cann backend
2106
+ ggml_backend_reg_t ggml_backend_cann_reg() {
2107
+ static ggml_backend_reg reg;
2108
+ static bool initialized = false;
2109
+
2110
+ {
2111
+ static std::mutex mutex;
2112
+ std::lock_guard<std::mutex> lock(mutex);
2113
+ if (!initialized) {
2114
+ aclInit(nullptr);
2115
+ ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context;
2116
+
2117
+ for (int i = 0; i < ggml_cann_info().device_count; i++) {
2118
+ ggml_backend_cann_device_context* dev_ctx = new ggml_backend_cann_device_context();
2119
+ dev_ctx->description = aclrtGetSocName();
2120
+ dev_ctx->device = i;
2121
+ dev_ctx->name = GGML_CANN_NAME + std::to_string(i);
2122
+ ggml_cann_set_device(i);
2123
+ ggml_backend_dev_t dev = new ggml_backend_device {
2124
+ /* .iface = */ ggml_backend_cann_device_interface,
2125
+ /* .reg = */ &reg,
2126
+ /* .context = */ dev_ctx
2127
+ };
2128
+ ctx->devices.push_back(dev);
2129
+ }
2130
+
2131
+ reg = ggml_backend_reg {
2132
+ /* .api_version = */ GGML_BACKEND_API_VERSION,
2133
+ /* .iface = */ ggml_backend_cann_reg_interface,
2134
+ /* .context = */ ctx
2135
+ };
2136
+ }
2137
+
2138
+ initialized = true;
2139
+ }
2140
+
2141
+ return &reg;
2142
+ }
2143
+
2144
+ ggml_backend_t ggml_backend_cann_init(int32_t device) {
2145
+ aclInit(nullptr);
2146
+ if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
2147
+ GGML_LOG_ERROR("%s: error: invalid device %d\n", __func__, device);
2148
+ return nullptr;
2149
+ }
2150
+
2151
+ ggml_backend_cann_context* ctx = new ggml_backend_cann_context(device);
2152
+ if (ctx == nullptr) {
2153
+ GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
2154
+ return nullptr;
2155
+ }
2156
+ ggml_cann_set_device(ctx->device);
2157
+ ggml_backend_t cann_backend =
2158
+ new ggml_backend{/* .guid = */ ggml_backend_cann_guid(),
2159
+ /* .interface = */ ggml_backend_cann_interface,
2160
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
2161
+ /* .context = */ ctx};
2162
+
2163
+ return cann_backend;
2164
+ }
2165
+
2166
+ bool ggml_backend_is_cann(ggml_backend_t backend) {
2167
+ return backend != NULL &&
2168
+ ggml_guid_matches(backend->guid, ggml_backend_cann_guid());
2169
+ }
2170
+
2171
+ int32_t ggml_backend_cann_get_device_count() {
2172
+ return ggml_cann_info().device_count;
2173
+ }
2174
+
2175
+ void ggml_backend_cann_get_device_description(
2176
+ int32_t device, char* description, size_t description_size) {
2177
+ ggml_cann_set_device(device);
2178
+ const char* soc_name = aclrtGetSocName();
2179
+ snprintf(description, description_size, "%s", soc_name);
2180
+ }
2181
+
2182
+ void ggml_backend_cann_get_device_memory(int32_t device, size_t* free,
2183
+ size_t* total) {
2184
+ ggml_cann_set_device(device);
2185
+ ACL_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, free, total));
2186
+ }
2187
+
2188
// NOTE(review): macro defined outside this file; presumably emits the
// dynamic-loading entry points for this backend using
// ggml_backend_cann_reg — confirm against its definition.
GGML_BACKEND_DL_IMPL(ggml_backend_cann_reg)