whispercpp 1.3.0 → 1.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +5 -0
- data/LICENSE +1 -1
- data/README.md +165 -434
- data/Rakefile +60 -11
- data/ext/.gitignore +13 -0
- data/ext/cpu.mk +9 -0
- data/ext/{dr_wav.h → examples/dr_wav.h} +3560 -1179
- data/ext/extconf.rb +185 -16
- data/ext/ggml/include/ggml-alloc.h +76 -0
- data/ext/ggml/include/ggml-backend.h +352 -0
- data/ext/ggml/include/ggml-blas.h +25 -0
- data/ext/ggml/include/ggml-cann.h +123 -0
- data/ext/ggml/include/ggml-cpp.h +38 -0
- data/ext/ggml/include/ggml-cpu.h +135 -0
- data/ext/ggml/include/ggml-cuda.h +47 -0
- data/ext/ggml/include/ggml-kompute.h +50 -0
- data/ext/ggml/include/ggml-metal.h +66 -0
- data/ext/ggml/include/ggml-opencl.h +26 -0
- data/ext/ggml/include/ggml-opt.h +216 -0
- data/ext/ggml/include/ggml-rpc.h +28 -0
- data/ext/ggml/include/ggml-sycl.h +49 -0
- data/ext/ggml/include/ggml-vulkan.h +31 -0
- data/ext/{ggml.h → ggml/include/ggml.h} +479 -596
- data/ext/ggml/src/ggml-alloc.c +1037 -0
- data/ext/ggml/src/ggml-amx/common.h +94 -0
- data/ext/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- data/ext/ggml/src/ggml-amx/mmq.cpp +2510 -0
- data/ext/ggml/src/ggml-amx/mmq.h +17 -0
- data/ext/ggml/src/ggml-backend-impl.h +256 -0
- data/ext/ggml/src/ggml-backend-reg.cpp +552 -0
- data/ext/ggml/src/ggml-backend.cpp +1999 -0
- data/ext/ggml/src/ggml-blas/ggml-blas.cpp +517 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.cpp +175 -0
- data/ext/ggml/src/ggml-cann/acl_tensor.h +258 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.cpp +3427 -0
- data/ext/ggml/src/ggml-cann/aclnn_ops.h +592 -0
- data/ext/ggml/src/ggml-cann/common.h +286 -0
- data/ext/ggml/src/ggml-cann/ggml-cann.cpp +2188 -0
- data/ext/ggml/src/ggml-cann/kernels/ascendc_kernels.h +19 -0
- data/ext/ggml/src/ggml-cann/kernels/dup.cpp +236 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f16.cpp +197 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_f32.cpp +190 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +204 -0
- data/ext/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +218 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +216 -0
- data/ext/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +295 -0
- data/ext/ggml/src/ggml-common.h +1853 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- data/ext/ggml/src/ggml-cpu/amx/amx.h +8 -0
- data/ext/ggml/src/ggml-cpu/amx/common.h +91 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
- data/ext/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- data/ext/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +4262 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-impl.h +386 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
- data/ext/ggml/src/ggml-cpu/ggml-cpu.cpp +622 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
- data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- data/ext/ggml/src/ggml-cuda/vendors/hip.h +186 -0
- data/ext/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- data/ext/ggml/src/ggml-impl.h +556 -0
- data/ext/ggml/src/ggml-kompute/ggml-kompute.cpp +2251 -0
- data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.m +4884 -0
- data/ext/ggml/src/ggml-metal/ggml-metal.metal +6732 -0
- data/ext/ggml/src/ggml-opt.cpp +854 -0
- data/ext/ggml/src/ggml-quants.c +5238 -0
- data/ext/ggml/src/ggml-quants.h +100 -0
- data/ext/ggml/src/ggml-rpc/ggml-rpc.cpp +1406 -0
- data/ext/ggml/src/ggml-sycl/common.cpp +95 -0
- data/ext/ggml/src/ggml-sycl/concat.cpp +196 -0
- data/ext/ggml/src/ggml-sycl/conv.cpp +99 -0
- data/ext/ggml/src/ggml-sycl/convert.cpp +547 -0
- data/ext/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
- data/ext/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
- data/ext/ggml/src/ggml-sycl/ggml-sycl.cpp +4729 -0
- data/ext/ggml/src/ggml-sycl/im2col.cpp +126 -0
- data/ext/ggml/src/ggml-sycl/mmq.cpp +3031 -0
- data/ext/ggml/src/ggml-sycl/mmvq.cpp +1015 -0
- data/ext/ggml/src/ggml-sycl/norm.cpp +378 -0
- data/ext/ggml/src/ggml-sycl/outprod.cpp +56 -0
- data/ext/ggml/src/ggml-sycl/rope.cpp +276 -0
- data/ext/ggml/src/ggml-sycl/softmax.cpp +251 -0
- data/ext/ggml/src/ggml-sycl/tsembd.cpp +72 -0
- data/ext/ggml/src/ggml-sycl/wkv6.cpp +141 -0
- data/ext/ggml/src/ggml-threading.cpp +12 -0
- data/ext/ggml/src/ggml-threading.h +14 -0
- data/ext/ggml/src/ggml-vulkan/ggml-vulkan.cpp +8657 -0
- data/ext/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
- data/ext/ggml/src/ggml.c +7694 -0
- data/ext/{whisper.h → include/whisper.h} +23 -22
- data/ext/metal-embed.mk +17 -0
- data/ext/metal.mk +6 -0
- data/ext/ruby_whisper.cpp +1492 -9
- data/ext/ruby_whisper.h +10 -0
- data/ext/scripts/get-flags.mk +38 -0
- data/ext/src/coreml/whisper-decoder-impl.h +146 -0
- data/ext/src/coreml/whisper-decoder-impl.m +201 -0
- data/ext/src/coreml/whisper-encoder-impl.h +142 -0
- data/ext/src/coreml/whisper-encoder-impl.m +197 -0
- data/ext/src/coreml/whisper-encoder.h +26 -0
- data/ext/src/openvino/whisper-openvino-encoder.cpp +108 -0
- data/ext/src/openvino/whisper-openvino-encoder.h +31 -0
- data/ext/{whisper.cpp → src/whisper.cpp} +661 -492
- data/extsources.rb +6 -0
- data/lib/whisper/model/uri.rb +157 -0
- data/lib/whisper.rb +2 -0
- data/tests/helper.rb +7 -0
- data/tests/jfk_reader/.gitignore +5 -0
- data/tests/jfk_reader/extconf.rb +3 -0
- data/tests/jfk_reader/jfk_reader.c +68 -0
- data/tests/test_callback.rb +160 -0
- data/tests/test_error.rb +20 -0
- data/tests/test_model.rb +71 -0
- data/tests/test_package.rb +31 -0
- data/tests/test_params.rb +160 -0
- data/tests/test_segment.rb +83 -0
- data/tests/test_whisper.rb +211 -123
- data/whispercpp.gemspec +36 -0
- metadata +137 -11
- data/ext/ggml.c +0 -21755
@@ -0,0 +1,2188 @@
|
|
1
|
+
/*
|
2
|
+
* Copyright (c) 2023-2024 The ggml authors
|
3
|
+
*
|
4
|
+
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
5
|
+
* of this software and associated documentation files (the "Software"), to
|
6
|
+
* deal in the Software without restriction, including without limitation the
|
7
|
+
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
8
|
+
* sell copies of the Software, and to permit persons to whom the Software is
|
9
|
+
* furnished to do so, subject to the following conditions:
|
10
|
+
*
|
11
|
+
* The above copyright notice and this permission notice shall be included in
|
12
|
+
* all copies or substantial portions of the Software.
|
13
|
+
*
|
14
|
+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
15
|
+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
16
|
+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
17
|
+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
18
|
+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
19
|
+
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
20
|
+
* IN THE SOFTWARE.
|
21
|
+
*/
|
22
|
+
|
23
|
+
#include "ggml-cann.h"
|
24
|
+
|
25
|
+
#include <acl/acl.h>
|
26
|
+
#include <stdarg.h>
|
27
|
+
|
28
|
+
#include <cmath>
|
29
|
+
#include <cstdio>
|
30
|
+
#include <cstring>
|
31
|
+
#include <mutex>
|
32
|
+
|
33
|
+
#include "ggml-impl.h"
|
34
|
+
#include "ggml-backend-impl.h"
|
35
|
+
#include "ggml-cann/aclnn_ops.h"
|
36
|
+
#include "ggml-cann/common.h"
|
37
|
+
|
38
|
+
#define GGML_COMMON_DECL_C
|
39
|
+
|
40
|
+
#include "ggml-common.h"
|
41
|
+
|
42
|
+
#define GGML_CANN_NAME "CANN"
|
43
|
+
|
44
|
+
/**
|
45
|
+
* @brief Handles CANN errors by printing an error message and aborting.
|
46
|
+
*
|
47
|
+
* @param stmt The statement that caused the error.
|
48
|
+
* @param func The function in which the error occurred.
|
49
|
+
* @param file The file in which the error occurred.
|
50
|
+
* @param line The line number where the error occurred.
|
51
|
+
* @param msg The error message.
|
52
|
+
*/
|
53
|
+
[[noreturn]] void ggml_cann_error(const char* stmt, const char* func,
|
54
|
+
const char* file, int line, const char* msg) {
|
55
|
+
int32_t id = -1;
|
56
|
+
aclrtGetDevice(&id);
|
57
|
+
|
58
|
+
GGML_LOG_ERROR("CANN error: %s\n", msg);
|
59
|
+
GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func,
|
60
|
+
file, line);
|
61
|
+
GGML_LOG_ERROR(" %s\n", stmt);
|
62
|
+
// abort with GGML_ASSERT to get a stack trace
|
63
|
+
GGML_ABORT("CANN error");
|
64
|
+
}
|
65
|
+
|
66
|
+
/**
|
67
|
+
* @brief Sets the device to be used by CANN.
|
68
|
+
*
|
69
|
+
* @param device The device ID to set.
|
70
|
+
*/
|
71
|
+
void ggml_cann_set_device(const int32_t device) {
|
72
|
+
// TODO: uncomment these lines after empty context has fixed.
|
73
|
+
// int current_device;
|
74
|
+
// ACL_CHECK(aclrtGetDevice(¤t_device));
|
75
|
+
|
76
|
+
// if (device == current_device) {
|
77
|
+
// return;
|
78
|
+
// }
|
79
|
+
ACL_CHECK(aclrtSetDevice(device));
|
80
|
+
}
|
81
|
+
|
82
|
+
/**
|
83
|
+
* @brief Retrieves the current device ID.
|
84
|
+
*
|
85
|
+
* @return The current device ID.
|
86
|
+
*/
|
87
|
+
int32_t ggml_cann_get_device() {
|
88
|
+
int32_t id;
|
89
|
+
ACL_CHECK(aclrtGetDevice(&id));
|
90
|
+
return id;
|
91
|
+
}
|
92
|
+
|
93
|
+
/**
|
94
|
+
* @brief Initialize the CANN device information.
|
95
|
+
*
|
96
|
+
* This function initializes the CANN device information by obtaining the
|
97
|
+
* device count and setting the memory allocation granularity for each device.
|
98
|
+
*
|
99
|
+
* @return A structure containing the device information.
|
100
|
+
*/
|
101
|
+
static ggml_cann_device_info ggml_cann_init() {
|
102
|
+
ggml_cann_device_info info = {};
|
103
|
+
|
104
|
+
aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count);
|
105
|
+
|
106
|
+
if (err != ACL_SUCCESS) {
|
107
|
+
GGML_LOG_ERROR("%s: failed to initialize CANN: %s\n",
|
108
|
+
__func__, aclGetRecentErrMsg());
|
109
|
+
return info;
|
110
|
+
}
|
111
|
+
|
112
|
+
GGML_ASSERT(info.device_count <= GGML_CANN_MAX_DEVICES);
|
113
|
+
|
114
|
+
for (int id = 0; id < info.device_count; ++id) {
|
115
|
+
aclrtPhysicalMemProp prop = {};
|
116
|
+
prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
|
117
|
+
prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
|
118
|
+
prop.memAttr = ACL_HBM_MEM_HUGE;
|
119
|
+
prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
|
120
|
+
prop.location.id = id;
|
121
|
+
prop.reserve = 0;
|
122
|
+
ACL_CHECK(aclrtMemGetAllocationGranularity(
|
123
|
+
&prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
|
124
|
+
&info.devices[id].vmm_granularity));
|
125
|
+
|
126
|
+
size_t free, total;
|
127
|
+
ggml_backend_cann_get_device_memory(id, &free, &total);
|
128
|
+
info.devices[id].total_vram = free;
|
129
|
+
}
|
130
|
+
|
131
|
+
// TODO: add more device info later.
|
132
|
+
return info;
|
133
|
+
}
|
134
|
+
|
135
|
+
/**
|
136
|
+
* @brief Retrieve the CANN device information.
|
137
|
+
*
|
138
|
+
* This function returns a reference to a structure containing the CANN device
|
139
|
+
* information. The device information is initialized once and reused on
|
140
|
+
* subsequent calls.
|
141
|
+
*
|
142
|
+
* @return A reference to the structure containing the device information.
|
143
|
+
*/
|
144
|
+
const ggml_cann_device_info& ggml_cann_info() {
|
145
|
+
static ggml_cann_device_info info = ggml_cann_init();
|
146
|
+
return info;
|
147
|
+
}
|
148
|
+
|
149
|
+
//#define DEBUG_CANN_MALLOC
|
150
|
+
/**
|
151
|
+
* @brief A pool of CANN buffers(legacy).
|
152
|
+
*
|
153
|
+
* This class manages a pool of CANN buffers for a specific device.
|
154
|
+
*/
|
155
|
+
struct ggml_cann_pool_leg : public ggml_cann_pool {
|
156
|
+
/**
|
157
|
+
* @brief The maximum number of buffers in the pool.
|
158
|
+
*/
|
159
|
+
static const int MAX_BUFFERS = 256;
|
160
|
+
|
161
|
+
/**
|
162
|
+
* @brief The device ID associated with this buffer pool.
|
163
|
+
*/
|
164
|
+
int device;
|
165
|
+
|
166
|
+
/**
|
167
|
+
* @brief Structure representing a CANN buffer.
|
168
|
+
*/
|
169
|
+
struct ggml_cann_buffer {
|
170
|
+
void* ptr = nullptr; ///< Pointer to the buffer memory.
|
171
|
+
size_t size = 0; ///< Size of the buffer.
|
172
|
+
};
|
173
|
+
|
174
|
+
/**
|
175
|
+
* @brief Array of CANN buffers in the pool.
|
176
|
+
*/
|
177
|
+
ggml_cann_buffer buffer_pool[MAX_BUFFERS] = {};
|
178
|
+
|
179
|
+
/**
|
180
|
+
* @brief Total size of all buffers in the pool.
|
181
|
+
*/
|
182
|
+
size_t pool_size = 0;
|
183
|
+
|
184
|
+
/**
|
185
|
+
* @brief Constructor to initialize the buffer pool for a specific device.
|
186
|
+
*
|
187
|
+
* @param device The device ID to associate with this buffer pool.
|
188
|
+
*/
|
189
|
+
explicit ggml_cann_pool_leg(int device) : device(device) {}
|
190
|
+
|
191
|
+
/**
|
192
|
+
* @brief Destructor to free all buffers in the pool.
|
193
|
+
*/
|
194
|
+
~ggml_cann_pool_leg() {
|
195
|
+
ggml_cann_set_device(device);
|
196
|
+
for (int i = 0; i < MAX_BUFFERS; ++i) {
|
197
|
+
ggml_cann_buffer& b = buffer_pool[i];
|
198
|
+
if (b.ptr != nullptr) {
|
199
|
+
ACL_CHECK(aclrtFree(b.ptr));
|
200
|
+
pool_size -= b.size;
|
201
|
+
}
|
202
|
+
}
|
203
|
+
GGML_ASSERT(pool_size == 0);
|
204
|
+
}
|
205
|
+
|
206
|
+
/**
|
207
|
+
* @brief Allocate a buffer of the given size.
|
208
|
+
*
|
209
|
+
* @param size The size of the buffer to allocate.
|
210
|
+
* @param actual_size A pointer to a variable to receive the actual size of
|
211
|
+
* the allocated buffer.
|
212
|
+
* @return A pointer to the allocated buffer.
|
213
|
+
*/
|
214
|
+
void* alloc(size_t size, size_t* actual_size) override {
|
215
|
+
const size_t alignment = 128;
|
216
|
+
size = GGML_PAD(size, alignment);
|
217
|
+
if (size == 0) {
|
218
|
+
size = alignment;
|
219
|
+
}
|
220
|
+
#ifdef DEBUG_CANN_MALLOC
|
221
|
+
int nnz = 0;
|
222
|
+
size_t max_size = 0;
|
223
|
+
#endif
|
224
|
+
size_t best_diff = 1ull << 36;
|
225
|
+
int ibest = -1;
|
226
|
+
for (int i = 0; i < MAX_BUFFERS; ++i) {
|
227
|
+
ggml_cann_buffer& b = buffer_pool[i];
|
228
|
+
if (b.ptr != nullptr) {
|
229
|
+
#ifdef DEBUG_CANN_MALLOC
|
230
|
+
++nnz;
|
231
|
+
if (b.size > max_size) max_size = b.size;
|
232
|
+
#endif
|
233
|
+
if (b.size >= size) {
|
234
|
+
size_t diff = b.size - size;
|
235
|
+
if (diff < best_diff) {
|
236
|
+
best_diff = diff;
|
237
|
+
ibest = i;
|
238
|
+
if (!best_diff) {
|
239
|
+
void* ptr = b.ptr;
|
240
|
+
*actual_size = b.size;
|
241
|
+
b.ptr = nullptr;
|
242
|
+
b.size = 0;
|
243
|
+
return ptr;
|
244
|
+
}
|
245
|
+
}
|
246
|
+
}
|
247
|
+
}
|
248
|
+
}
|
249
|
+
if (ibest >= 0) {
|
250
|
+
ggml_cann_buffer& b = buffer_pool[ibest];
|
251
|
+
void* ptr = b.ptr;
|
252
|
+
*actual_size = b.size;
|
253
|
+
b.ptr = nullptr;
|
254
|
+
b.size = 0;
|
255
|
+
return ptr;
|
256
|
+
}
|
257
|
+
void* ptr;
|
258
|
+
ggml_cann_set_device(device);
|
259
|
+
ACL_CHECK(
|
260
|
+
aclrtMalloc(&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST));
|
261
|
+
*actual_size = size;
|
262
|
+
pool_size += size;
|
263
|
+
#ifdef DEBUG_CANN_MALLOC
|
264
|
+
GGML_LOG_INFO(
|
265
|
+
"%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, "
|
266
|
+
"requested %u MB\n",
|
267
|
+
__func__, device, nnz, (uint32_t)(max_size / 1024 / 1024),
|
268
|
+
(uint32_t)(pool_size / 1024 / 1024),
|
269
|
+
(uint32_t)(size / 1024 / 1024));
|
270
|
+
#endif
|
271
|
+
return ptr;
|
272
|
+
}
|
273
|
+
|
274
|
+
/**
|
275
|
+
* @brief Free a buffer and return it to the pool.
|
276
|
+
*
|
277
|
+
* @param ptr Pointer to the buffer to free.
|
278
|
+
* @param size Size of the buffer to free.
|
279
|
+
*/
|
280
|
+
void free(void* ptr, size_t size) override {
|
281
|
+
for (int i = 0; i < MAX_BUFFERS; ++i) {
|
282
|
+
ggml_cann_buffer& b = buffer_pool[i];
|
283
|
+
if (b.ptr == nullptr) {
|
284
|
+
b.ptr = ptr;
|
285
|
+
b.size = size;
|
286
|
+
return;
|
287
|
+
}
|
288
|
+
}
|
289
|
+
// memory should always buffered. these memory may still needed by
|
290
|
+
// tasks in stream.
|
291
|
+
// TODO, fix me.
|
292
|
+
GGML_ABORT("Cann buffer pool full, increase MAX_CANN_BUFFERS\n");
|
293
|
+
}
|
294
|
+
};
|
295
|
+
|
296
|
+
/**
|
297
|
+
* @brief A pool of CANN buffers with virtual memory.
|
298
|
+
*
|
299
|
+
* This class manages a pool of CANN buffers with virtual memory for a specific
|
300
|
+
* device.
|
301
|
+
*/
|
302
|
+
struct ggml_cann_pool_vmm : public ggml_cann_pool {
|
303
|
+
/**
|
304
|
+
* @brief The maximum size of the virtual memory pool (32 GB).
|
305
|
+
*/
|
306
|
+
size_t max_size;
|
307
|
+
|
308
|
+
/**
|
309
|
+
* @brief The device ID associated with this buffer pool.
|
310
|
+
*/
|
311
|
+
int device;
|
312
|
+
|
313
|
+
/**
|
314
|
+
* @brief Pointer to the start of the virtual memory pool.
|
315
|
+
*/
|
316
|
+
void* pool_addr = 0;
|
317
|
+
|
318
|
+
/**
|
319
|
+
* @brief Amount of virtual memory used in the pool.
|
320
|
+
*/
|
321
|
+
size_t pool_used = 0;
|
322
|
+
|
323
|
+
/**
|
324
|
+
* @brief Total size of the virtual memory pool.
|
325
|
+
*/
|
326
|
+
size_t pool_size = 0;
|
327
|
+
|
328
|
+
/**
|
329
|
+
* @brief Allocation granularity for the virtual memory pool.
|
330
|
+
*/
|
331
|
+
size_t granularity;
|
332
|
+
|
333
|
+
/**
|
334
|
+
* @brief Handles for the physical memory allocated.
|
335
|
+
*/
|
336
|
+
std::vector<aclrtDrvMemHandle> handles;
|
337
|
+
|
338
|
+
/**
|
339
|
+
* @brief Offsets for the mapped memory regions.
|
340
|
+
*/
|
341
|
+
std::vector<void*> map_offsets;
|
342
|
+
|
343
|
+
/**
|
344
|
+
* @brief Constructor to initialize the buffer pool with virtual memory for
|
345
|
+
* a specific device.
|
346
|
+
*
|
347
|
+
* @param device The device ID to associate with this buffer pool.
|
348
|
+
*/
|
349
|
+
explicit ggml_cann_pool_vmm(int device)
|
350
|
+
: device(device),
|
351
|
+
granularity(ggml_cann_info().devices[device].vmm_granularity) {
|
352
|
+
auto dev = ggml_cann_info().devices[device];
|
353
|
+
granularity = dev.vmm_granularity;
|
354
|
+
max_size = dev.total_vram;
|
355
|
+
}
|
356
|
+
|
357
|
+
/**
|
358
|
+
* @brief Destructor to free all buffers in the virtual memory pool.
|
359
|
+
*/
|
360
|
+
~ggml_cann_pool_vmm() {
|
361
|
+
if (pool_addr != 0) {
|
362
|
+
for (auto& offset : map_offsets) {
|
363
|
+
ACL_CHECK(aclrtUnmapMem(offset));
|
364
|
+
}
|
365
|
+
for (auto& handle : handles) {
|
366
|
+
ACL_CHECK(aclrtFreePhysical(handle));
|
367
|
+
}
|
368
|
+
ACL_CHECK(aclrtReleaseMemAddress(pool_addr));
|
369
|
+
}
|
370
|
+
}
|
371
|
+
|
372
|
+
/**
|
373
|
+
* @brief Allocate a buffer of the given size in the virtual memory pool.
|
374
|
+
*
|
375
|
+
* @param size The size of the buffer to allocate.
|
376
|
+
* @param actual_size A pointer to a variable to receive the actual size of
|
377
|
+
* the allocated buffer.
|
378
|
+
* @return A pointer to the allocated buffer.
|
379
|
+
*/
|
380
|
+
void* alloc(size_t size, size_t* actual_size) override {
|
381
|
+
// round up the allocation size to the alignment to ensure that all
|
382
|
+
// allocations are aligned for all data types
|
383
|
+
const size_t alignment = 128;
|
384
|
+
size = GGML_PAD(size, alignment);
|
385
|
+
if (size == 0) {
|
386
|
+
size = alignment;
|
387
|
+
}
|
388
|
+
|
389
|
+
size_t avail = pool_size - pool_used;
|
390
|
+
|
391
|
+
if (size > avail) {
|
392
|
+
// round up to the next multiple of the granularity
|
393
|
+
size_t reserve_size = size - avail;
|
394
|
+
reserve_size = GGML_PAD(reserve_size, granularity);
|
395
|
+
|
396
|
+
GGML_ASSERT(pool_size + reserve_size <= max_size);
|
397
|
+
|
398
|
+
// allocate more physical memory
|
399
|
+
aclrtPhysicalMemProp prop = {};
|
400
|
+
prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
|
401
|
+
prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
|
402
|
+
prop.memAttr = ACL_HBM_MEM_HUGE;
|
403
|
+
prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
|
404
|
+
prop.location.id = device;
|
405
|
+
prop.reserve = 0;
|
406
|
+
aclrtDrvMemHandle handle;
|
407
|
+
ACL_CHECK(aclrtMallocPhysical(&handle, reserve_size, &prop, 0));
|
408
|
+
|
409
|
+
// reserve virtual address space (if not already reserved)
|
410
|
+
if (pool_addr == 0) {
|
411
|
+
ACL_CHECK(aclrtReserveMemAddress(
|
412
|
+
&pool_addr, max_size, 0, NULL, 1));
|
413
|
+
}
|
414
|
+
|
415
|
+
// map at the end of the pool
|
416
|
+
ACL_CHECK(aclrtMapMem((char*)pool_addr + pool_size, reserve_size, 0,
|
417
|
+
handle, 0));
|
418
|
+
|
419
|
+
handles.push_back(handle);
|
420
|
+
map_offsets.push_back((char*)pool_addr + pool_size);
|
421
|
+
|
422
|
+
// add to the pool
|
423
|
+
pool_size += reserve_size;
|
424
|
+
|
425
|
+
#ifdef DEBUG_CANN_MALLOC
|
426
|
+
GGML_LOG_INFO("cann pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
|
427
|
+
device, (unsigned long long) (pool_size/1024/1024),
|
428
|
+
(unsigned long long) (reserve_size/1024/1024));
|
429
|
+
#endif
|
430
|
+
}
|
431
|
+
|
432
|
+
GGML_ASSERT(pool_addr != 0);
|
433
|
+
|
434
|
+
void* ptr = (void*)((char*)pool_addr + pool_used);
|
435
|
+
*actual_size = size;
|
436
|
+
pool_used += size;
|
437
|
+
|
438
|
+
#ifdef DEBUG_CANN_MALLOC
|
439
|
+
GGML_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device,
|
440
|
+
(unsigned long long)size, (unsigned long long)ptr);
|
441
|
+
#endif
|
442
|
+
return ptr;
|
443
|
+
}
|
444
|
+
|
445
|
+
/**
|
446
|
+
* @brief Free a buffer and return it to the virtual memory pool.
|
447
|
+
*
|
448
|
+
* @param ptr Pointer to the buffer to free.
|
449
|
+
* @param size Size of the buffer to free.
|
450
|
+
*/
|
451
|
+
void free(void* ptr, size_t size) override {
|
452
|
+
#ifdef DEBUG_CANN_MALLOC
|
453
|
+
GGML_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device,
|
454
|
+
(unsigned long long)size, (unsigned long long)ptr);
|
455
|
+
#endif
|
456
|
+
|
457
|
+
pool_used -= size;
|
458
|
+
|
459
|
+
// all deallocations must be in reverse order of the allocations
|
460
|
+
GGML_ASSERT(ptr == (void*)((char*)pool_addr + pool_used));
|
461
|
+
}
|
462
|
+
};
|
463
|
+
|
464
|
+
/**
|
465
|
+
* @brief Create a new CANN pool for a specific device.
|
466
|
+
*
|
467
|
+
* Factory method to create a new CANN pool object based on the device type.
|
468
|
+
*
|
469
|
+
* @param device The device ID for which to create the pool.
|
470
|
+
* @return A unique pointer to the created CANN pool.
|
471
|
+
*/
|
472
|
+
std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
|
473
|
+
int device) {
|
474
|
+
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
|
475
|
+
}
|
476
|
+
|
477
|
+
// cann buffer
|
478
|
+
/**
|
479
|
+
* @brief Context for managing a CANN buffer associated with a specific device.
|
480
|
+
*
|
481
|
+
* This structure holds information about a CANN buffer, including the device
|
482
|
+
* ID, device pointer, and a name derived from GGML_CANN_NAME and the device ID.
|
483
|
+
*/
|
484
|
+
struct ggml_backend_cann_buffer_context {
|
485
|
+
int32_t device; ///< The device ID associated with this buffer context.
|
486
|
+
void* dev_ptr =
|
487
|
+
nullptr; ///< Pointer to the device memory allocated for the buffer.
|
488
|
+
|
489
|
+
/**
|
490
|
+
* @brief Constructor to initialize the CANN buffer context.
|
491
|
+
*
|
492
|
+
* @param device The device ID associated with this buffer context.
|
493
|
+
* @param dev_ptr Pointer to the device memory allocated for the buffer.
|
494
|
+
*/
|
495
|
+
ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
|
496
|
+
: device(device),
|
497
|
+
dev_ptr(dev_ptr) {}
|
498
|
+
|
499
|
+
/**
|
500
|
+
* @brief Destructor to free the device memory allocated for the buffer.
|
501
|
+
*/
|
502
|
+
~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }
|
503
|
+
};
|
504
|
+
|
505
|
+
/**
|
506
|
+
* @brief Check if a buffer is a CANN buffer.
|
507
|
+
*
|
508
|
+
* This function checks if a given buffer is a CANN buffer by comparing its
|
509
|
+
* `get_name` function pointer to `ggml_backend_cann_buffer_get_name`.
|
510
|
+
*
|
511
|
+
* @param buffer The buffer to check.
|
512
|
+
* @return true if the buffer is a CANN buffer, false otherwise.
|
513
|
+
*/
|
514
|
+
static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft);
|
515
|
+
static bool ggml_backend_buffer_is_cann(
|
516
|
+
ggml_backend_buffer_t buffer) {
|
517
|
+
return ggml_backend_buft_is_cann(buffer->buft);
|
518
|
+
}
|
519
|
+
|
520
|
+
/**
|
521
|
+
* @brief Free resources associated with a CANN buffer.
|
522
|
+
*
|
523
|
+
* This function frees the resources associated with a CANN buffer, including
|
524
|
+
* its context.
|
525
|
+
*
|
526
|
+
* @param buffer The CANN buffer to free.
|
527
|
+
*/
|
528
|
+
static void ggml_backend_cann_buffer_free_buffer(
|
529
|
+
ggml_backend_buffer_t buffer) {
|
530
|
+
ggml_backend_cann_buffer_context* ctx =
|
531
|
+
(ggml_backend_cann_buffer_context*)buffer->context;
|
532
|
+
delete ctx;
|
533
|
+
}
|
534
|
+
|
535
|
+
/**
|
536
|
+
* @brief Retrieve the base pointer of a CANN buffer.
|
537
|
+
*
|
538
|
+
* This function returns the base pointer of a CANN buffer, which points to the
|
539
|
+
* device memory allocated for the buffer.
|
540
|
+
*
|
541
|
+
* @param buffer The CANN buffer whose base pointer is to be retrieved.
|
542
|
+
* @return A pointer to the base of the device memory allocated for the buffer.
|
543
|
+
*/
|
544
|
+
static void* ggml_backend_cann_buffer_get_base(
|
545
|
+
ggml_backend_buffer_t buffer) {
|
546
|
+
ggml_backend_cann_buffer_context* ctx =
|
547
|
+
(ggml_backend_cann_buffer_context*)buffer->context;
|
548
|
+
return ctx->dev_ptr;
|
549
|
+
}
|
550
|
+
|
551
|
+
/**
|
552
|
+
* @brief Transform quantized Q4.0 tensor data into a format suitable for CANN
|
553
|
+
* processing.
|
554
|
+
*
|
555
|
+
* This function transforms quantized Q4.0 tensor data into a format suitable
|
556
|
+
* for CANN processing. It extracts quantization values and scales from the
|
557
|
+
* source data and prepares them in a format expected by CANN operations.
|
558
|
+
*
|
559
|
+
* @param tensor Pointer to the tensor information.
|
560
|
+
* @param src Pointer to the source data in Q4.0 format.
|
561
|
+
* @param dst Pointer to the destination buffer where transformed data will be
|
562
|
+
* stored.
|
563
|
+
*/
|
564
|
+
static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
|
565
|
+
const void* src,
|
566
|
+
void* dst) {
|
567
|
+
|
568
|
+
int64_t n_elems = ggml_nelements(tensor);
|
569
|
+
int64_t groups = n_elems / QK4_0;
|
570
|
+
size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
|
571
|
+
|
572
|
+
uint8_t* quant_offset = (uint8_t*)dst;
|
573
|
+
uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
|
574
|
+
|
575
|
+
for (int i = 0; i < groups; i++) {
|
576
|
+
const block_q4_0* group =
|
577
|
+
(const block_q4_0*)((const char*)src + i * sizeof(block_q4_0));
|
578
|
+
*scale_offset = group->d;
|
579
|
+
scale_offset++;
|
580
|
+
|
581
|
+
// 0-15
|
582
|
+
for (int j = 0; j < QK4_0 / 2; j += 2) {
|
583
|
+
(*quant_offset) = (group->qs[j] & 0x0F);
|
584
|
+
(*quant_offset) |= ((group->qs[j + 1] << 4));
|
585
|
+
quant_offset++;
|
586
|
+
}
|
587
|
+
|
588
|
+
// 16-31
|
589
|
+
for (int j = 0; j < QK4_0 / 2; j += 2) {
|
590
|
+
(*quant_offset) = (group->qs[j] >> 4);
|
591
|
+
(*quant_offset) |= (group->qs[j + 1] & 0xF0);
|
592
|
+
quant_offset++;
|
593
|
+
}
|
594
|
+
}
|
595
|
+
|
596
|
+
// put (uint4b_t -8) into int4b_t
|
597
|
+
for (quant_offset = (uint8_t*)dst;
|
598
|
+
quant_offset < (uint8_t*)dst + quant_bytes; quant_offset++) {
|
599
|
+
(*quant_offset) ^= 0x88;
|
600
|
+
}
|
601
|
+
}
|
602
|
+
|
603
|
+
/**
|
604
|
+
* @brief Transform CANN processed data back into quantized Q4.0 format.
|
605
|
+
*
|
606
|
+
* This function transforms CANN processed data back into quantized Q4.0 format.
|
607
|
+
* It reverses the transformation performed by
|
608
|
+
* ggml_backend_cann_transform_q4_0(), converting the data back into its
|
609
|
+
* original quantized form.
|
610
|
+
*
|
611
|
+
* @param tensor Pointer to the tensor information.
|
612
|
+
* @param src Pointer to the source buffer containing transformed data.
|
613
|
+
* @param dst Pointer to the destination buffer where the Q4.0 formatted data
|
614
|
+
* will be stored.
|
615
|
+
*/
|
616
|
+
static void ggml_backend_cann_transform_back_q4_0(
|
617
|
+
const ggml_tensor* tensor, void* src, void* dst) {
|
618
|
+
|
619
|
+
int64_t n_elems = ggml_nelements(tensor);
|
620
|
+
int64_t groups = n_elems / QK4_0;
|
621
|
+
size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
|
622
|
+
|
623
|
+
uint8_t* quant_offset = (uint8_t*)src;
|
624
|
+
uint16_t* scale_offset = (uint16_t*)((char*)src + quant_bytes);
|
625
|
+
|
626
|
+
for (; quant_offset < (uint8_t*)src + quant_bytes; quant_offset++) {
|
627
|
+
(*quant_offset) ^= 0x88;
|
628
|
+
}
|
629
|
+
quant_offset = (uint8_t*)src;
|
630
|
+
|
631
|
+
for (int i = 0; i < groups; i++) {
|
632
|
+
block_q4_0* group = (block_q4_0*)((char*)dst + i * sizeof(block_q4_0));
|
633
|
+
group->d = *scale_offset;
|
634
|
+
scale_offset++;
|
635
|
+
|
636
|
+
// 0-15
|
637
|
+
for (int j = 0; j < QK4_0 / 2; j += 2) {
|
638
|
+
group->qs[j] = ((*quant_offset) & 0x0F);
|
639
|
+
group->qs[j + 1] = ((*quant_offset) >> 4);
|
640
|
+
quant_offset++;
|
641
|
+
}
|
642
|
+
|
643
|
+
// 16-31
|
644
|
+
for (int j = 0; j < QK4_0 / 2; j += 2) {
|
645
|
+
group->qs[j] |= ((*quant_offset) << 4);
|
646
|
+
group->qs[j + 1] |= ((*quant_offset) & 0xF0);
|
647
|
+
quant_offset++;
|
648
|
+
}
|
649
|
+
}
|
650
|
+
}
|
651
|
+
|
652
|
+
/**
|
653
|
+
* @brief Transform quantized Q8.0 tensor data into a format suitable for CANN
|
654
|
+
* processing.
|
655
|
+
*
|
656
|
+
* This function transforms quantized Q8.0 tensor data into a format suitable
|
657
|
+
* for CANN processing. It extracts quantization values and scales from the
|
658
|
+
* source data and prepares them in a format expected by CANN operations.
|
659
|
+
*
|
660
|
+
* @param tensor Pointer to the tensor information.
|
661
|
+
* @param src Pointer to the source data in Q8.0 format.
|
662
|
+
* @param dst Pointer to the destination buffer where transformed data will be
|
663
|
+
* stored.
|
664
|
+
*/
|
665
|
+
static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor,
|
666
|
+
const void* src,
|
667
|
+
void* dst) {
|
668
|
+
int64_t n_elems = ggml_nelements(tensor);
|
669
|
+
int64_t groups = n_elems / QK8_0;
|
670
|
+
size_t quant_bytes = n_elems * sizeof(uint8_t);
|
671
|
+
|
672
|
+
uint8_t* quant_offset = (uint8_t*)dst;
|
673
|
+
uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
|
674
|
+
|
675
|
+
for (int i = 0; i < groups; i++) {
|
676
|
+
const block_q8_0* group =
|
677
|
+
(const block_q8_0*)((const char*)src + i * sizeof(block_q8_0));
|
678
|
+
*scale_offset = group->d;
|
679
|
+
scale_offset++;
|
680
|
+
size_t group_quant_size = QK8_0 * sizeof(uint8_t);
|
681
|
+
memcpy(quant_offset, group->qs, group_quant_size);
|
682
|
+
quant_offset += group_quant_size;
|
683
|
+
}
|
684
|
+
}
|
685
|
+
|
686
|
+
/**
|
687
|
+
* @brief Transform CANN processed data back into quantized Q8.0 format.
|
688
|
+
*
|
689
|
+
* This function transforms CANN processed data back into quantized Q8.0 format.
|
690
|
+
* It reverses the transformation performed by
|
691
|
+
* ggml_backend_cann_transform_q8_0(), converting the data back into its
|
692
|
+
* original quantized form.
|
693
|
+
*
|
694
|
+
* @param tensor Pointer to the tensor information.
|
695
|
+
* @param src Pointer to the source buffer containing transformed data.
|
696
|
+
* @param dst Pointer to the destination buffer where the Q8.0 formatted data
|
697
|
+
* will be stored.
|
698
|
+
*/
|
699
|
+
static void ggml_backend_cann_transform_back_q8_0(
|
700
|
+
const ggml_tensor* tensor, const void* src, void* dst) {
|
701
|
+
int64_t n_elems = ggml_nelements(tensor);
|
702
|
+
int64_t groups = n_elems / QK8_0;
|
703
|
+
size_t quant_bytes = n_elems * sizeof(uint8_t);
|
704
|
+
|
705
|
+
const uint8_t* quant_offset = (const uint8_t*)src;
|
706
|
+
const uint16_t* scale_offset =
|
707
|
+
(const uint16_t*)((const char*)src + quant_bytes);
|
708
|
+
|
709
|
+
for (int i = 0; i < groups; i++) {
|
710
|
+
block_q8_0* group = (block_q8_0*)((char*)dst + i * sizeof(block_q8_0));
|
711
|
+
group->d = *scale_offset;
|
712
|
+
scale_offset++;
|
713
|
+
size_t group_quant_size = QK8_0 * sizeof(uint8_t);
|
714
|
+
memcpy(group->qs, quant_offset, group_quant_size);
|
715
|
+
quant_offset += group_quant_size;
|
716
|
+
}
|
717
|
+
}
|
718
|
+
|
719
|
+
/**
|
720
|
+
* @brief Transform tensor data based on its type for CANN processing.
|
721
|
+
*
|
722
|
+
* This function transforms tensor data based on its quantization type for CANN
|
723
|
+
* processing. It dispatches the transformation based on the tensor's type to
|
724
|
+
* specialized functions handling Q4.0 and Q8.0 formats.
|
725
|
+
*
|
726
|
+
* @param tensor Pointer to the tensor information.
|
727
|
+
* @param src Pointer to the source data to be transformed.
|
728
|
+
* @param dst Pointer to the destination buffer where transformed data will be
|
729
|
+
* stored.
|
730
|
+
*/
|
731
|
+
static void ggml_backend_cann_transform(ggml_tensor* tensor,
|
732
|
+
const void* src, void* dst) {
|
733
|
+
switch (tensor->type) {
|
734
|
+
case GGML_TYPE_Q4_0:
|
735
|
+
ggml_backend_cann_transform_q4_0(tensor, src, dst);
|
736
|
+
break;
|
737
|
+
case GGML_TYPE_Q8_0:
|
738
|
+
ggml_backend_cann_transform_q8_0(tensor, src, dst);
|
739
|
+
break;
|
740
|
+
default:
|
741
|
+
break;
|
742
|
+
}
|
743
|
+
}
|
744
|
+
|
745
|
+
/**
|
746
|
+
* @brief Transform CANN processed data back into tensor data based on its type.
|
747
|
+
*
|
748
|
+
* This function transforms CANN processed data back into tensor data based on
|
749
|
+
* its quantization type for Q4.0 and Q8.0 formats. It dispatches the
|
750
|
+
* transformation based on the tensor's type to specialized functions.
|
751
|
+
*
|
752
|
+
* @param tensor Pointer to the tensor information.
|
753
|
+
* @param src Pointer to the source data containing CANN processed data.
|
754
|
+
* @param dst Pointer to the destination buffer where transformed tensor data
|
755
|
+
* will be stored.
|
756
|
+
*/
|
757
|
+
static void ggml_backend_cann_transform_back(
|
758
|
+
const ggml_tensor* tensor, void* src, void* dst) {
|
759
|
+
switch (tensor->type) {
|
760
|
+
case GGML_TYPE_Q4_0:
|
761
|
+
ggml_backend_cann_transform_back_q4_0(tensor, src, dst);
|
762
|
+
break;
|
763
|
+
case GGML_TYPE_Q8_0:
|
764
|
+
ggml_backend_cann_transform_back_q8_0(tensor, src, dst);
|
765
|
+
break;
|
766
|
+
default:
|
767
|
+
break;
|
768
|
+
}
|
769
|
+
}
|
770
|
+
|
771
|
+
/**
|
772
|
+
* @brief Check if transformation is needed for a given tensor type.
|
773
|
+
*
|
774
|
+
* This function checks if transformation is needed for a given tensor type
|
775
|
+
* to prepare data for CANN processing.
|
776
|
+
*
|
777
|
+
* @param type The tensor type to check.
|
778
|
+
* @return true if transformation is needed, false otherwise.
|
779
|
+
*/
|
780
|
+
static bool need_transform(ggml_type type) {
|
781
|
+
switch (type) {
|
782
|
+
case GGML_TYPE_Q4_0:
|
783
|
+
case GGML_TYPE_Q8_0:
|
784
|
+
return true;
|
785
|
+
default:
|
786
|
+
return false;
|
787
|
+
}
|
788
|
+
}
|
789
|
+
|
790
|
+
/**
|
791
|
+
* @brief Initialize a tensor using data from a CANN buffer.
|
792
|
+
*
|
793
|
+
* This function initializes a tensor using data from a CANN buffer.
|
794
|
+
* It handles special cases such as views and quantization.
|
795
|
+
*
|
796
|
+
* @param buffer The CANN buffer from which to initialize the tensor.
|
797
|
+
* @param tensor Pointer to the tensor to be initialized.
|
798
|
+
*/
|
799
|
+
static void ggml_backend_cann_buffer_init_tensor(
|
800
|
+
ggml_backend_buffer_t buffer, ggml_tensor* tensor) {
|
801
|
+
if (tensor->view_src != NULL && tensor->view_offs == 0) {
|
802
|
+
GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
|
803
|
+
return;
|
804
|
+
}
|
805
|
+
|
806
|
+
// TODO: can backend doesn't support quantized yet. Just leave the code
|
807
|
+
// here.
|
808
|
+
if (ggml_is_quantized(tensor->type)) {
|
809
|
+
// Initialize padding to 0 to avoid possible NaN values
|
810
|
+
size_t original_size = ggml_nbytes(tensor);
|
811
|
+
size_t padded_size =
|
812
|
+
ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
|
813
|
+
|
814
|
+
if (padded_size > original_size && tensor->view_src == nullptr) {
|
815
|
+
size_t memset_size = padded_size - original_size;
|
816
|
+
ACL_CHECK(aclrtMemset((char*)tensor->data + original_size,
|
817
|
+
memset_size, 0, memset_size));
|
818
|
+
}
|
819
|
+
}
|
820
|
+
}
|
821
|
+
|
822
|
+
// TODO: need handle tensor which has paddings.
|
823
|
+
/**
|
824
|
+
* @brief Set tensor data in a CANN buffer.
|
825
|
+
*
|
826
|
+
* This function sets tensor data in a CANN buffer, handling transformations
|
827
|
+
* if needed based on the tensor's type.
|
828
|
+
*
|
829
|
+
* @param buffer The CANN buffer where the tensor data will be set.
|
830
|
+
* @param tensor Pointer to the tensor whose data will be set.
|
831
|
+
* @param data Pointer to the source data to be copied into the tensor.
|
832
|
+
* @param offset Offset in the source data from where to start copying.
|
833
|
+
* @param size Size of the data to be copied, in bytes.
|
834
|
+
*/
|
835
|
+
static void ggml_backend_cann_buffer_set_tensor(
|
836
|
+
ggml_backend_buffer_t buffer, ggml_tensor *tensor, const void *data,
|
837
|
+
size_t offset, size_t size) {
|
838
|
+
ggml_backend_cann_buffer_context *ctx =
|
839
|
+
(ggml_backend_cann_buffer_context *)buffer->context;
|
840
|
+
|
841
|
+
ggml_cann_set_device(ctx->device);
|
842
|
+
// TODO: refer to cann(#6017), it use thread's default stream.
|
843
|
+
// For acl, synchronous functions use this default stream.
|
844
|
+
// Why aclrtSynchronizeDevice?
|
845
|
+
|
846
|
+
if (!need_transform(tensor->type)) {
|
847
|
+
ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size,
|
848
|
+
ACL_MEMCPY_HOST_TO_DEVICE));
|
849
|
+
} else {
|
850
|
+
void *transform_buffer = malloc(size);
|
851
|
+
ggml_backend_cann_transform(tensor, data, transform_buffer);
|
852
|
+
|
853
|
+
ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size,
|
854
|
+
transform_buffer, size,
|
855
|
+
ACL_MEMCPY_HOST_TO_DEVICE));
|
856
|
+
free(transform_buffer);
|
857
|
+
}
|
858
|
+
}
|
859
|
+
|
860
|
+
/**
|
861
|
+
* @brief Get tensor data from a CANN buffer.
|
862
|
+
*
|
863
|
+
* This function retrieves tensor data from a CANN buffer, handling
|
864
|
+
* transformations if needed based on the tensor's type.
|
865
|
+
*
|
866
|
+
* @param buffer The CANN buffer from which to retrieve tensor data.
|
867
|
+
* @param tensor Pointer to the tensor whose data will be retrieved.
|
868
|
+
* @param data Pointer to the destination buffer where the tensor data will be
|
869
|
+
* copied.
|
870
|
+
* @param offset Offset in the destination buffer where to start copying.
|
871
|
+
* @param size Size of the data to be copied, in bytes.
|
872
|
+
*/
|
873
|
+
static void ggml_backend_cann_buffer_get_tensor(
|
874
|
+
ggml_backend_buffer_t buffer, const ggml_tensor* tensor, void* data,
|
875
|
+
size_t offset, size_t size) {
|
876
|
+
ggml_backend_cann_buffer_context* ctx =
|
877
|
+
(ggml_backend_cann_buffer_context*)buffer->context;
|
878
|
+
|
879
|
+
ggml_cann_set_device(ctx->device);
|
880
|
+
|
881
|
+
if (!need_transform(tensor->type)) {
|
882
|
+
ACL_CHECK(aclrtMemcpy(data, size, (char*)tensor->data + offset, size,
|
883
|
+
ACL_MEMCPY_DEVICE_TO_HOST));
|
884
|
+
} else {
|
885
|
+
void* transform_buffer = malloc(size);
|
886
|
+
ACL_CHECK(aclrtMemcpy(transform_buffer, size,
|
887
|
+
(char*)tensor->data + offset, size,
|
888
|
+
ACL_MEMCPY_DEVICE_TO_HOST));
|
889
|
+
ggml_backend_cann_transform_back(tensor, transform_buffer, data);
|
890
|
+
free(transform_buffer);
|
891
|
+
}
|
892
|
+
}
|
893
|
+
|
894
|
+
/**
|
895
|
+
* @brief Copy tensor data between CANN buffers if possible.
|
896
|
+
*
|
897
|
+
* This function copies tensor data between CANN buffers if the source and
|
898
|
+
* destination buffers are CANN buffers and they meet the necessary conditions
|
899
|
+
* (same device or devices can access each other).
|
900
|
+
*
|
901
|
+
* @param buffer The destination CANN buffer where the tensor data will be
|
902
|
+
* copied.
|
903
|
+
* @param src Pointer to the source tensor whose data will be copied.
|
904
|
+
* @param dst Pointer to the destination tensor where the data will be copied.
|
905
|
+
* @return true if the copy operation succeeded, false otherwise.
|
906
|
+
*/
|
907
|
+
static bool ggml_backend_cann_buffer_cpy_tensor(
|
908
|
+
ggml_backend_buffer_t buffer, const ggml_tensor* src, ggml_tensor* dst) {
|
909
|
+
if (ggml_backend_buffer_is_cann(src->buffer)) {
|
910
|
+
ggml_backend_cann_buffer_context* src_ctx =
|
911
|
+
(ggml_backend_cann_buffer_context*)src->buffer->context;
|
912
|
+
ggml_backend_cann_buffer_context* dst_ctx =
|
913
|
+
(ggml_backend_cann_buffer_context*)buffer->context;
|
914
|
+
|
915
|
+
size_t memcpy_size = ggml_nbytes(src);
|
916
|
+
// Same device.
|
917
|
+
if (src_ctx->device == dst_ctx->device) {
|
918
|
+
ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
|
919
|
+
(const char*)src->data, memcpy_size,
|
920
|
+
ACL_MEMCPY_DEVICE_TO_DEVICE));
|
921
|
+
return true;
|
922
|
+
} else {
|
923
|
+
// Different device but can access by peer.
|
924
|
+
int32_t canAccessPeer = 0;
|
925
|
+
ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device,
|
926
|
+
dst_ctx->device));
|
927
|
+
if (canAccessPeer) {
|
928
|
+
ggml_cann_set_device(src_ctx->device);
|
929
|
+
ACL_CHECK(aclrtDeviceEnablePeerAccess(dst_ctx->device, 0));
|
930
|
+
ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
|
931
|
+
(const char*)src->data, memcpy_size,
|
932
|
+
ACL_MEMCPY_DEVICE_TO_DEVICE));
|
933
|
+
return true;
|
934
|
+
}
|
935
|
+
}
|
936
|
+
}
|
937
|
+
return false;
|
938
|
+
}
|
939
|
+
|
940
|
+
/**
|
941
|
+
* @brief Clear a CANN buffer by setting all its memory to a specified value.
|
942
|
+
*
|
943
|
+
* This function clears a CANN buffer by setting all its memory to a specified
|
944
|
+
* value.
|
945
|
+
*
|
946
|
+
* @param buffer The CANN buffer to be cleared.
|
947
|
+
* @param value The value to which each byte in the buffer will be set.
|
948
|
+
*/
|
949
|
+
static void ggml_backend_cann_buffer_clear(
|
950
|
+
ggml_backend_buffer_t buffer, uint8_t value) {
|
951
|
+
ggml_backend_cann_buffer_context* ctx =
|
952
|
+
(ggml_backend_cann_buffer_context*)buffer->context;
|
953
|
+
|
954
|
+
ggml_cann_set_device(ctx->device);
|
955
|
+
ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size));
|
956
|
+
}
|
957
|
+
|
958
|
+
/**
|
959
|
+
* @brief Interface for a CANN buffer in the backend.
|
960
|
+
*
|
961
|
+
* This structure defines function pointers to operations that can be performed
|
962
|
+
* on a CANN buffer within the backend.
|
963
|
+
*/
|
964
|
+
static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
|
965
|
+
/* .free_buffer = */ ggml_backend_cann_buffer_free_buffer,
|
966
|
+
/* .get_base = */ ggml_backend_cann_buffer_get_base,
|
967
|
+
/* .init_tensor = */ ggml_backend_cann_buffer_init_tensor,
|
968
|
+
/* .memset_tensor = */ NULL,
|
969
|
+
/* .set_tensor = */ ggml_backend_cann_buffer_set_tensor,
|
970
|
+
/* .get_tensor = */ ggml_backend_cann_buffer_get_tensor,
|
971
|
+
/* .cpy_tensor = */ ggml_backend_cann_buffer_cpy_tensor,
|
972
|
+
/* .clear = */ ggml_backend_cann_buffer_clear,
|
973
|
+
/* .reset = */ NULL,
|
974
|
+
};
|
975
|
+
|
976
|
+
// cann buffer type
/**
 * @brief Per-device state attached to a CANN buffer type.
 */
struct ggml_backend_cann_buffer_type_context {
    int32_t device;   /**< Index of the device this buffer type allocates on. */
    std::string name; /**< Name reported for the buffer type. */
};
|
986
|
+
|
987
|
+
/**
|
988
|
+
* @brief Retrieves the name associated with a CANN buffer type.
|
989
|
+
*
|
990
|
+
* This function returns the descriptive name associated with the specified
|
991
|
+
* CANN buffer type context.
|
992
|
+
*
|
993
|
+
* @param buft Pointer to the buffer type context.
|
994
|
+
* @return Const pointer to the C-style string containing the name.
|
995
|
+
*/
|
996
|
+
static const char* ggml_backend_cann_buffer_type_name(
|
997
|
+
ggml_backend_buffer_type_t buft) {
|
998
|
+
ggml_backend_cann_buffer_type_context* buft_ctx =
|
999
|
+
(ggml_backend_cann_buffer_type_context*)buft->context;
|
1000
|
+
|
1001
|
+
return buft_ctx->name.c_str();
|
1002
|
+
}
|
1003
|
+
|
1004
|
+
/**
|
1005
|
+
* @brief Allocates a new CANN buffer of the specified type and size.
|
1006
|
+
*
|
1007
|
+
* This function allocates a new CANN buffer on the specified device with the
|
1008
|
+
* given size.
|
1009
|
+
*
|
1010
|
+
* @param buft Pointer to the buffer type context.
|
1011
|
+
* @param size Size in bytes of the buffer to allocate.
|
1012
|
+
* @return Pointer to the allocated buffer, or nullptr if allocation fails.
|
1013
|
+
*/
|
1014
|
+
static ggml_backend_buffer_t
|
1015
|
+
ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
|
1016
|
+
size_t size) {
|
1017
|
+
ggml_backend_cann_buffer_type_context* buft_ctx =
|
1018
|
+
(ggml_backend_cann_buffer_type_context*)buft->context;
|
1019
|
+
|
1020
|
+
ggml_cann_set_device(buft_ctx->device);
|
1021
|
+
|
1022
|
+
size = std::max(size, (size_t)1);
|
1023
|
+
|
1024
|
+
void* dev_ptr;
|
1025
|
+
aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
|
1026
|
+
if (err != ACL_SUCCESS) {
|
1027
|
+
GGML_LOG_ERROR(
|
1028
|
+
"%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n",
|
1029
|
+
__func__, size / 1024.0 / 1024.0, buft_ctx->device,
|
1030
|
+
aclGetRecentErrMsg());
|
1031
|
+
return nullptr;
|
1032
|
+
}
|
1033
|
+
|
1034
|
+
ggml_backend_cann_buffer_context* ctx =
|
1035
|
+
new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);
|
1036
|
+
|
1037
|
+
return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface,
|
1038
|
+
ctx, size);
|
1039
|
+
}
|
1040
|
+
|
1041
|
+
/**
|
1042
|
+
* @brief Retrieves the memory alignment requirement for CANN buffers of this
|
1043
|
+
* type.
|
1044
|
+
*
|
1045
|
+
* This function returns the alignment requirement in bytes for memory allocated
|
1046
|
+
* by the CANN buffer type.
|
1047
|
+
*
|
1048
|
+
* @param buft Pointer to the buffer type context (unused in this
|
1049
|
+
* implementation).
|
1050
|
+
* @return The alignment requirement in bytes (fixed at 128 bytes for CANN
|
1051
|
+
* buffers).
|
1052
|
+
*/
|
1053
|
+
static size_t ggml_backend_cann_buffer_type_get_alignment(
|
1054
|
+
ggml_backend_buffer_type_t buft) {
|
1055
|
+
return 128;
|
1056
|
+
|
1057
|
+
GGML_UNUSED(buft);
|
1058
|
+
}
|
1059
|
+
|
1060
|
+
/**
|
1061
|
+
* @brief Calculates the allocation size required for a tensor in a CANN buffer.
|
1062
|
+
*
|
1063
|
+
* Computes the total allocation size needed for storing the tensor's data in a
|
1064
|
+
* CANN buffer, considering any necessary padding or adjustments for quantized
|
1065
|
+
* types.
|
1066
|
+
*
|
1067
|
+
* @param buft Pointer to the buffer type context (unused in this
|
1068
|
+
* implementation).
|
1069
|
+
* @param tensor Pointer to the tensor for which the allocation size is
|
1070
|
+
* calculated.
|
1071
|
+
* @return The total allocation size in bytes required for the tensor in the
|
1072
|
+
* CANN buffer.
|
1073
|
+
*/
|
1074
|
+
static size_t ggml_backend_cann_buffer_type_get_alloc_size(
|
1075
|
+
ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
|
1076
|
+
size_t size = ggml_nbytes(tensor);
|
1077
|
+
int64_t ne0 = tensor->ne[0];
|
1078
|
+
|
1079
|
+
// last line must bigger than 32, because every single op deal at
|
1080
|
+
// least 32 bytes.
|
1081
|
+
// TODO: quantized type?
|
1082
|
+
// int64_t line_size = ne0 * ggml_element_size(tensor);
|
1083
|
+
// int64_t line_size_align_32 = (line_size + 31) & ~31;
|
1084
|
+
// size += (line_size_align_32 - line_size);
|
1085
|
+
|
1086
|
+
// TODO: not support quantized yet.
|
1087
|
+
// TODO: consider un-continue tensor.
|
1088
|
+
if (ggml_is_quantized(tensor->type)) {
|
1089
|
+
if (ne0 % MATRIX_ROW_PADDING != 0) {
|
1090
|
+
size += ggml_row_size(
|
1091
|
+
tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
|
1092
|
+
}
|
1093
|
+
}
|
1094
|
+
|
1095
|
+
return size;
|
1096
|
+
|
1097
|
+
GGML_UNUSED(buft);
|
1098
|
+
}
|
1099
|
+
|
1100
|
+
/**
 * @brief Report whether CANN buffers are host memory.
 *
 * @param buft Buffer type (ignored).
 * @return Always false: CANN buffers of this type are device allocations.
 */
static bool ggml_backend_cann_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
    GGML_UNUSED(buft);
    const bool host_accessible = false;
    return host_accessible;
}
|
1105
|
+
|
1106
|
+
/**
|
1107
|
+
* @brief Interface for managing CANN buffer types in the GGML backend.
|
1108
|
+
*
|
1109
|
+
* Provides function pointers for allocating, querying properties, and managing
|
1110
|
+
* memory for CANN buffer types in the GGML backend.
|
1111
|
+
*/
|
1112
|
+
static const ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
|
1113
|
+
/* .get_name = */ ggml_backend_cann_buffer_type_name,
|
1114
|
+
/* .alloc_buffer = */ ggml_backend_cann_buffer_type_alloc_buffer,
|
1115
|
+
/* .get_alignment = */ ggml_backend_cann_buffer_type_get_alignment,
|
1116
|
+
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
|
1117
|
+
/* .get_alloc_size = */ ggml_backend_cann_buffer_type_get_alloc_size,
|
1118
|
+
/* .is_host = */ ggml_backend_cann_buffer_type_is_host,
|
1119
|
+
};
|
1120
|
+
|
1121
|
+
/**
|
1122
|
+
* @brief Retrieves the CANN buffer type for a specified device.
|
1123
|
+
*
|
1124
|
+
* This function initializes and returns the buffer type interface associated
|
1125
|
+
* with the given device. It ensures thread-safe access using a mutex.
|
1126
|
+
*
|
1127
|
+
* @param device The device index for which to retrieve the buffer type.
|
1128
|
+
* @return A pointer to the buffer type interface for the specified device, or
|
1129
|
+
* nullptr if the device index is out of range.
|
1130
|
+
*/
|
1131
|
+
ggml_backend_buffer_type_t
|
1132
|
+
ggml_backend_cann_buffer_type(int32_t device) {
|
1133
|
+
static std::mutex mutex;
|
1134
|
+
std::lock_guard<std::mutex> lock(mutex);
|
1135
|
+
|
1136
|
+
if (device >= ggml_backend_cann_get_device_count()) {
|
1137
|
+
return nullptr;
|
1138
|
+
}
|
1139
|
+
|
1140
|
+
static ggml_backend_buffer_type
|
1141
|
+
ggml_backend_cann_buffer_types[GGML_CANN_MAX_DEVICES];
|
1142
|
+
|
1143
|
+
static bool ggml_backend_cann_buffer_type_initialized = false;
|
1144
|
+
|
1145
|
+
if (!ggml_backend_cann_buffer_type_initialized) {
|
1146
|
+
for (int32_t i = 0; i < ggml_cann_info().device_count; i++) {
|
1147
|
+
ggml_backend_cann_buffer_types[i] = {
|
1148
|
+
/* .iface = */ ggml_backend_cann_buffer_type_interface,
|
1149
|
+
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), i),
|
1150
|
+
/* .context = */
|
1151
|
+
new ggml_backend_cann_buffer_type_context{
|
1152
|
+
i, "CANN" + std::to_string(i)},
|
1153
|
+
};
|
1154
|
+
}
|
1155
|
+
ggml_backend_cann_buffer_type_initialized = true;
|
1156
|
+
}
|
1157
|
+
|
1158
|
+
return &ggml_backend_cann_buffer_types[device];
|
1159
|
+
}
|
1160
|
+
|
1161
|
+
/**
|
1162
|
+
* @brief Retrieves the name associated with a CANN host buffer type.
|
1163
|
+
*
|
1164
|
+
* This function returns the descriptive name associated with the specified
|
1165
|
+
* CANN host buffer type context.
|
1166
|
+
*
|
1167
|
+
* @param buft Pointer to the host buffer type context.
|
1168
|
+
* @return Const pointer to the C-style string containing the name.
|
1169
|
+
*/
|
1170
|
+
static const char * ggml_backend_cann_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
|
1171
|
+
return "CANN_Host";
|
1172
|
+
|
1173
|
+
GGML_UNUSED(buft);
|
1174
|
+
}
|
1175
|
+
|
1176
|
+
/**
|
1177
|
+
* @brief Retrieves the name associated with a CANN host buffer.
|
1178
|
+
*
|
1179
|
+
* This function returns the descriptive name associated with the specified
|
1180
|
+
* CANN host buffer context.
|
1181
|
+
*
|
1182
|
+
* @param buft Pointer to the host buffer context.
|
1183
|
+
* @return Const pointer to the C-style string containing the name.
|
1184
|
+
*/
|
1185
|
+
static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buffer) {
|
1186
|
+
return "CANN_Host";
|
1187
|
+
|
1188
|
+
GGML_UNUSED(buffer);
|
1189
|
+
}
|
1190
|
+
|
1191
|
+
/**
|
1192
|
+
* @brief Free resources associated with a CANN host buffer.
|
1193
|
+
*
|
1194
|
+
* This function frees the resources associated with a CANN host buffer, including
|
1195
|
+
* its context.
|
1196
|
+
*
|
1197
|
+
* @param buffer The CANN host buffer to free.
|
1198
|
+
*/
|
1199
|
+
static void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) {
|
1200
|
+
ACL_CHECK(aclrtFreeHost(buffer->context));
|
1201
|
+
}
|
1202
|
+
|
1203
|
+
/**
|
1204
|
+
* @brief Allocates a new CANN host buffer of the specified size.
|
1205
|
+
*
|
1206
|
+
* This function allocates a new CANN host buffer with the given size.
|
1207
|
+
* @param size Size in bytes of the host buffer to allocate.
|
1208
|
+
* @return Pointer to the allocated host buffer, or nullptr if allocation fails.
|
1209
|
+
*/
|
1210
|
+
static void * ggml_cann_host_malloc(size_t size) {
|
1211
|
+
if (getenv("GGML_CANN_NO_PINNED") != nullptr) {
|
1212
|
+
return nullptr;
|
1213
|
+
}
|
1214
|
+
|
1215
|
+
const size_t alignment = 128;
|
1216
|
+
size = GGML_PAD(size, alignment);
|
1217
|
+
if (size == 0) {
|
1218
|
+
size = alignment;
|
1219
|
+
}
|
1220
|
+
|
1221
|
+
void * hostPtr = nullptr;
|
1222
|
+
aclError err = aclrtMallocHost((void **) &hostPtr, size);
|
1223
|
+
if (err != ACL_SUCCESS) {
|
1224
|
+
GGML_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
|
1225
|
+
size / 1024.0 / 1024.0, aclGetRecentErrMsg());
|
1226
|
+
return nullptr;
|
1227
|
+
}
|
1228
|
+
return hostPtr;
|
1229
|
+
}
|
1230
|
+
|
1231
|
+
/**
|
1232
|
+
* @brief Allocates a new CANN host buffer of the specified type and size.
|
1233
|
+
*
|
1234
|
+
* @param buft Pointer to the host buffer type context.
|
1235
|
+
* @param size Size in bytes of the host buffer to allocate.
|
1236
|
+
* @return Pointer to the allocated host buffer, or CPU buffer pointer if allocation fails.
|
1237
|
+
*/
|
1238
|
+
static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
1239
|
+
void * hostPtr = ggml_cann_host_malloc(size);
|
1240
|
+
|
1241
|
+
if (hostPtr == nullptr) {
|
1242
|
+
// fallback to cpu buffer
|
1243
|
+
return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
|
1244
|
+
}
|
1245
|
+
|
1246
|
+
ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(hostPtr, size);
|
1247
|
+
buffer->buft = buft;
|
1248
|
+
buffer->iface.free_buffer = ggml_backend_cann_host_buffer_free;
|
1249
|
+
|
1250
|
+
return buffer;
|
1251
|
+
}
|
1252
|
+
|
1253
|
+
/**
|
1254
|
+
* @brief Interface for managing CANN host buffer types in the GGML backend.
|
1255
|
+
*
|
1256
|
+
* Provides function pointers for allocating, querying properties, and managing
|
1257
|
+
* memory for CANN buffer types in the GGML backend.
|
1258
|
+
*/
|
1259
|
+
ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
|
1260
|
+
static struct ggml_backend_buffer_type ggml_backend_cann_buffer_type_host = {
|
1261
|
+
/* .iface = */ {
|
1262
|
+
/* .get_name = */ ggml_backend_cann_host_buffer_type_name,
|
1263
|
+
/* .alloc_buffer = */ ggml_backend_cann_host_buffer_type_alloc_buffer,
|
1264
|
+
/* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
|
1265
|
+
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
|
1266
|
+
/* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
|
1267
|
+
/* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
|
1268
|
+
},
|
1269
|
+
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
|
1270
|
+
/* .context = */ nullptr,
|
1271
|
+
};
|
1272
|
+
|
1273
|
+
return &ggml_backend_cann_buffer_type_host;
|
1274
|
+
}
|
1275
|
+
|
1276
|
+
/**
|
1277
|
+
* @brief Computes the forward operation for a given tensor using CANN
|
1278
|
+
* operations.
|
1279
|
+
*
|
1280
|
+
* This function selects the appropriate CANN operation based on the type of
|
1281
|
+
* operation specified in the tensor and performs the computation.
|
1282
|
+
*
|
1283
|
+
* @param ctx The CANN context containing necessary resources and
|
1284
|
+
* configurations.
|
1285
|
+
* @param dst The destination tensor where the result of the computation will be
|
1286
|
+
* stored.
|
1287
|
+
* @return true if the computation was successful; false otherwise.
|
1288
|
+
*/
|
1289
|
+
static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
|
1290
|
+
struct ggml_tensor* dst) {
|
1291
|
+
switch (dst->op) {
|
1292
|
+
case GGML_OP_REPEAT:
|
1293
|
+
ggml_cann_repeat(ctx, dst);
|
1294
|
+
break;
|
1295
|
+
case GGML_OP_GET_ROWS:
|
1296
|
+
ggml_cann_get_rows(ctx, dst);
|
1297
|
+
break;
|
1298
|
+
case GGML_OP_DUP:
|
1299
|
+
ggml_cann_dup(ctx, dst);
|
1300
|
+
break;
|
1301
|
+
case GGML_OP_ADD:
|
1302
|
+
ggml_cann_add(ctx, dst);
|
1303
|
+
break;
|
1304
|
+
case GGML_OP_ACC:
|
1305
|
+
ggml_cann_acc(ctx, dst);
|
1306
|
+
break;
|
1307
|
+
case GGML_OP_MUL:
|
1308
|
+
ggml_cann_mul_div<aclnnMulGetWorkspaceSize, aclnnMul>(ctx, dst);
|
1309
|
+
break;
|
1310
|
+
case GGML_OP_DIV:
|
1311
|
+
ggml_cann_mul_div<aclnnDivGetWorkspaceSize, aclnnDiv>(ctx, dst);
|
1312
|
+
break;
|
1313
|
+
case GGML_OP_UNARY:
|
1314
|
+
switch (ggml_get_unary_op(dst)) {
|
1315
|
+
case GGML_UNARY_OP_GELU:
|
1316
|
+
ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
|
1317
|
+
ctx, dst);
|
1318
|
+
break;
|
1319
|
+
case GGML_UNARY_OP_SILU:
|
1320
|
+
ggml_cann_activation<aclnnSiluGetWorkspaceSize, aclnnSilu>(
|
1321
|
+
ctx, dst);
|
1322
|
+
break;
|
1323
|
+
// TODO: Use faster gelu??
|
1324
|
+
case GGML_UNARY_OP_GELU_QUICK:
|
1325
|
+
ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
|
1326
|
+
ctx, dst);
|
1327
|
+
break;
|
1328
|
+
case GGML_UNARY_OP_TANH:
|
1329
|
+
ggml_cann_activation<aclnnTanhGetWorkspaceSize, aclnnTanh>(
|
1330
|
+
ctx, dst);
|
1331
|
+
break;
|
1332
|
+
case GGML_UNARY_OP_RELU:
|
1333
|
+
ggml_cann_activation<aclnnReluGetWorkspaceSize, aclnnRelu>(
|
1334
|
+
ctx, dst);
|
1335
|
+
break;
|
1336
|
+
case GGML_UNARY_OP_HARDSIGMOID:
|
1337
|
+
ggml_cann_activation<aclnnHardsigmoidGetWorkspaceSize,
|
1338
|
+
aclnnHardsigmoid>(ctx, dst);
|
1339
|
+
break;
|
1340
|
+
case GGML_UNARY_OP_HARDSWISH:
|
1341
|
+
ggml_cann_activation<aclnnHardswishGetWorkspaceSize,
|
1342
|
+
aclnnHardswish>(ctx, dst);
|
1343
|
+
break;
|
1344
|
+
default:
|
1345
|
+
return false;
|
1346
|
+
}
|
1347
|
+
break;
|
1348
|
+
case GGML_OP_NORM:
|
1349
|
+
ggml_cann_norm(ctx, dst);
|
1350
|
+
break;
|
1351
|
+
case GGML_OP_GROUP_NORM:
|
1352
|
+
ggml_cann_group_norm(ctx, dst);
|
1353
|
+
break;
|
1354
|
+
case GGML_OP_CONCAT:
|
1355
|
+
ggml_cann_concat(ctx, dst);
|
1356
|
+
break;
|
1357
|
+
case GGML_OP_UPSCALE:
|
1358
|
+
ggml_cann_upsample_nearest2d(ctx, dst);
|
1359
|
+
break;
|
1360
|
+
case GGML_OP_PAD:
|
1361
|
+
ggml_cann_pad(ctx, dst);
|
1362
|
+
break;
|
1363
|
+
case GGML_OP_ARANGE:
|
1364
|
+
ggml_cann_arange(ctx, dst);
|
1365
|
+
break;
|
1366
|
+
case GGML_OP_TIMESTEP_EMBEDDING:
|
1367
|
+
ggml_cann_timestep_embedding(ctx, dst);
|
1368
|
+
break;
|
1369
|
+
case GGML_OP_LEAKY_RELU:
|
1370
|
+
ggml_cann_leaky_relu(ctx, dst);
|
1371
|
+
break;
|
1372
|
+
case GGML_OP_RMS_NORM:
|
1373
|
+
ggml_cann_rms_norm(ctx, dst);
|
1374
|
+
break;
|
1375
|
+
case GGML_OP_MUL_MAT:
|
1376
|
+
ggml_cann_mul_mat(ctx, dst);
|
1377
|
+
break;
|
1378
|
+
case GGML_OP_MUL_MAT_ID:
|
1379
|
+
return false;
|
1380
|
+
case GGML_OP_SCALE:
|
1381
|
+
ggml_cann_scale(ctx, dst);
|
1382
|
+
break;
|
1383
|
+
case GGML_OP_SQR:
|
1384
|
+
ggml_cann_sqr(ctx, dst);
|
1385
|
+
break;
|
1386
|
+
case GGML_OP_CLAMP:
|
1387
|
+
ggml_cann_clamp(ctx, dst);
|
1388
|
+
break;
|
1389
|
+
case GGML_OP_CPY:
|
1390
|
+
ggml_cann_cpy(ctx, dst);
|
1391
|
+
break;
|
1392
|
+
case GGML_OP_CONT:
|
1393
|
+
ggml_cann_dup(ctx, dst);
|
1394
|
+
break;
|
1395
|
+
case GGML_OP_NONE:
|
1396
|
+
case GGML_OP_RESHAPE:
|
1397
|
+
case GGML_OP_VIEW:
|
1398
|
+
case GGML_OP_PERMUTE:
|
1399
|
+
case GGML_OP_TRANSPOSE:
|
1400
|
+
break;
|
1401
|
+
case GGML_OP_DIAG_MASK_INF:
|
1402
|
+
ggml_cann_diag_mask(ctx, dst, -INFINITY);
|
1403
|
+
break;
|
1404
|
+
case GGML_OP_SOFT_MAX:
|
1405
|
+
ggml_cann_softmax(ctx, dst);
|
1406
|
+
break;
|
1407
|
+
case GGML_OP_ROPE:
|
1408
|
+
ggml_cann_rope(ctx, dst);
|
1409
|
+
break;
|
1410
|
+
case GGML_OP_IM2COL:
|
1411
|
+
ggml_cann_im2col(ctx, dst);
|
1412
|
+
break;
|
1413
|
+
case GGML_OP_POOL_2D:
|
1414
|
+
ggml_cann_pool2d(ctx, dst);
|
1415
|
+
break;
|
1416
|
+
case GGML_OP_SUM_ROWS:
|
1417
|
+
ggml_cann_sum_rows(ctx, dst);
|
1418
|
+
break;
|
1419
|
+
case GGML_OP_ARGSORT:
|
1420
|
+
ggml_cann_argsort(ctx, dst);
|
1421
|
+
break;
|
1422
|
+
default:
|
1423
|
+
return false;
|
1424
|
+
}
|
1425
|
+
|
1426
|
+
return true;
|
1427
|
+
}
|
1428
|
+
|
1429
|
+
// backend
|
1430
|
+
/**
|
1431
|
+
* @brief Retrieves the name associated with the CANN backend.
|
1432
|
+
*
|
1433
|
+
* This function returns the name assigned to the CANN backend, which is stored
|
1434
|
+
* in the context of the provided backend structure.
|
1435
|
+
*
|
1436
|
+
* @param backend Pointer to the CANN backend structure.
|
1437
|
+
* @return A pointer to a constant string representing the backend name.
|
1438
|
+
*/
|
1439
|
+
static const char* ggml_backend_cann_name(ggml_backend_t backend) {
|
1440
|
+
ggml_backend_cann_context* cann_ctx =
|
1441
|
+
(ggml_backend_cann_context*)backend->context;
|
1442
|
+
|
1443
|
+
return cann_ctx->name.c_str();
|
1444
|
+
}
|
1445
|
+
|
1446
|
+
/**
|
1447
|
+
* @brief Frees resources associated with the CANN backend.
|
1448
|
+
*
|
1449
|
+
* This function releases resources associated with the CANN backend context
|
1450
|
+
* and resets the device associated with the backend to its initial state.
|
1451
|
+
*
|
1452
|
+
* @param backend Pointer to the CANN backend structure to be freed.
|
1453
|
+
*/
|
1454
|
+
static void ggml_backend_cann_free(ggml_backend_t backend) {
|
1455
|
+
ggml_backend_cann_context* cann_ctx =
|
1456
|
+
(ggml_backend_cann_context*)backend->context;
|
1457
|
+
ACL_CHECK(aclrtSynchronizeDevice());
|
1458
|
+
ACL_CHECK(aclrtResetDevice(cann_ctx->device));
|
1459
|
+
|
1460
|
+
// finalize when last backend freed.
|
1461
|
+
if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) {
|
1462
|
+
ACL_CHECK(aclFinalize());
|
1463
|
+
}
|
1464
|
+
|
1465
|
+
delete cann_ctx;
|
1466
|
+
delete backend;
|
1467
|
+
}
|
1468
|
+
|
1469
|
+
/**
|
1470
|
+
* @brief Sets tensor data asynchronously in the CANN backend.
|
1471
|
+
*
|
1472
|
+
* This function asynchronously sets tensor data in the CANN backend. Depending
|
1473
|
+
* on the tensor type, it may perform data transformations before copying data
|
1474
|
+
* to the device.
|
1475
|
+
*
|
1476
|
+
* @param backend Pointer to the CANN backend structure.
|
1477
|
+
* @param tensor Pointer to the tensor structure to set data for.
|
1478
|
+
* @param data Pointer to the host data to copy to the tensor.
|
1479
|
+
* @param offset Offset in bytes within the host data.
|
1480
|
+
* @param size Size of the data to copy in bytes.
|
1481
|
+
*/
|
1482
|
+
static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
|
1483
|
+
ggml_tensor *tensor,
|
1484
|
+
const void *data,
|
1485
|
+
size_t offset,
|
1486
|
+
size_t size) {
|
1487
|
+
ggml_backend_cann_context *cann_ctx =
|
1488
|
+
(ggml_backend_cann_context *)backend->context;
|
1489
|
+
|
1490
|
+
if (!need_transform(tensor->type)) {
|
1491
|
+
ACL_CHECK(aclrtMemcpyAsync((char *)tensor->data + offset, size, data,
|
1492
|
+
size, ACL_MEMCPY_HOST_TO_DEVICE,
|
1493
|
+
cann_ctx->stream()));
|
1494
|
+
} else {
|
1495
|
+
void *transform_buffer = malloc(size);
|
1496
|
+
ggml_backend_cann_transform(tensor, data, transform_buffer);
|
1497
|
+
|
1498
|
+
ACL_CHECK(aclrtMemcpyAsync(
|
1499
|
+
(char *)tensor->data + offset, size, transform_buffer, size,
|
1500
|
+
ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream()));
|
1501
|
+
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
|
1502
|
+
free(transform_buffer);
|
1503
|
+
}
|
1504
|
+
}
|
1505
|
+
|
1506
|
+
/**
 * @brief Reads tensor data back to the host from the CANN backend.
 *
 * Tensors whose type needs no layout transform are copied with an
 * asynchronous device-to-host memcpy on the backend stream.
 *
 * Transformed types are first copied into a temporary host staging buffer.
 * The stream is then synchronized and the data is transformed back into
 * `data`, so that path is effectively synchronous.
 *
 * @param backend Pointer to the CANN backend structure.
 * @param tensor  Tensor to read; must live in this backend's device buffer.
 * @param data    Host destination buffer.
 * @param offset  Offset in bytes within the tensor's device data.
 * @param size    Number of bytes to copy.
 */
static void ggml_backend_cann_get_tensor_async(
    ggml_backend_t backend, const ggml_tensor *tensor, void *data,
    size_t offset, size_t size) {
    ggml_backend_cann_context *cann_ctx =
        (ggml_backend_cann_context *)backend->context;
    ggml_backend_buffer_t buf =
        tensor->view_src ? tensor->view_src->buffer : tensor->buffer;

    GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
                "unsupported buffer type");

    if (!need_transform(tensor->type)) {
        ACL_CHECK(aclrtMemcpyAsync(data, size, (char *)tensor->data + offset,
                                   size, ACL_MEMCPY_DEVICE_TO_HOST,
                                   cann_ctx->stream()));
        return;
    }

    void *transform_buffer = malloc(size);
    GGML_ASSERT(transform_buffer != nullptr);
    ACL_CHECK(aclrtMemcpyAsync(
        transform_buffer, size, (char *)tensor->data + offset, size,
        ACL_MEMCPY_DEVICE_TO_HOST, cann_ctx->stream()));
    // The staged bytes must have arrived before they can be transformed back.
    ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
    ggml_backend_cann_transform_back(tensor, transform_buffer, data);
    free(transform_buffer);
}
|
1531
|
+
|
1532
|
+
/**
|
1533
|
+
* @brief Asynchronously copies tensor data between CANN backends.
|
1534
|
+
*
|
1535
|
+
* This function copies tensor data asynchronously between two CANN backends. It
|
1536
|
+
* checks if both tensors reside in CANN buffers and whether the devices support
|
1537
|
+
* peer-to-peer access for direct copying. If not, it returns false.
|
1538
|
+
*
|
1539
|
+
* @param backend_src Pointer to the source CANN backend structure.
|
1540
|
+
* @param backend_dst Pointer to the destination CANN backend structure.
|
1541
|
+
* @param src Pointer to the source tensor to copy data from.
|
1542
|
+
* @param dst Pointer to the destination tensor to copy data to.
|
1543
|
+
* @return true if the copy operation succeeds, false otherwise.
|
1544
|
+
*/
|
1545
|
+
static bool ggml_backend_cann_cpy_tensor_async(
|
1546
|
+
ggml_backend_t backend_src, ggml_backend_t backend_dst,
|
1547
|
+
const ggml_tensor* src, ggml_tensor* dst) {
|
1548
|
+
GGML_ASSERT(ggml_backend_is_cann(backend_src) ||
|
1549
|
+
ggml_backend_is_cann(backend_dst));
|
1550
|
+
|
1551
|
+
if (!ggml_backend_buffer_is_cann(src->buffer) ||
|
1552
|
+
!ggml_backend_buffer_is_cann(dst->buffer)) {
|
1553
|
+
return false;
|
1554
|
+
}
|
1555
|
+
|
1556
|
+
ggml_backend_buffer_t buf_src =
|
1557
|
+
src->view_src ? src->view_src->buffer : src->buffer;
|
1558
|
+
ggml_backend_buffer_t buf_dst =
|
1559
|
+
dst->view_src ? dst->view_src->buffer : dst->buffer;
|
1560
|
+
|
1561
|
+
ggml_backend_cann_context* cann_ctx_src =
|
1562
|
+
(ggml_backend_cann_context*)backend_src->context;
|
1563
|
+
ggml_backend_cann_context* cann_ctx_dst =
|
1564
|
+
(ggml_backend_cann_context*)backend_dst->context;
|
1565
|
+
|
1566
|
+
size_t copy_size = ggml_nbytes(dst);
|
1567
|
+
if (backend_src != backend_dst) {
|
1568
|
+
ggml_backend_cann_buffer_context* buf_ctx_src =
|
1569
|
+
(ggml_backend_cann_buffer_context*)buf_src->context;
|
1570
|
+
ggml_backend_cann_buffer_context* buf_ctx_dst =
|
1571
|
+
(ggml_backend_cann_buffer_context*)buf_dst->context;
|
1572
|
+
|
1573
|
+
GGML_ASSERT(cann_ctx_src->device == buf_ctx_src->device);
|
1574
|
+
GGML_ASSERT(cann_ctx_dst->device == buf_ctx_dst->device);
|
1575
|
+
|
1576
|
+
int32_t canAccessPeer = 0;
|
1577
|
+
ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, cann_ctx_src->device,
|
1578
|
+
cann_ctx_dst->device));
|
1579
|
+
if (!canAccessPeer) {
|
1580
|
+
return false;
|
1581
|
+
}
|
1582
|
+
|
1583
|
+
// need open both directions for memcpyasync between devices.
|
1584
|
+
ggml_cann_set_device(cann_ctx_dst->device);
|
1585
|
+
ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0));
|
1586
|
+
ggml_cann_set_device(cann_ctx_src->device);
|
1587
|
+
ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
|
1588
|
+
|
1589
|
+
ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
|
1590
|
+
ACL_MEMCPY_DEVICE_TO_DEVICE,
|
1591
|
+
cann_ctx_src->stream()));
|
1592
|
+
|
1593
|
+
//TODO: workaround for Event didn`t work here.
|
1594
|
+
aclrtSynchronizeStream(cann_ctx_src->stream());
|
1595
|
+
} else {
|
1596
|
+
// src and dst are on the same backend
|
1597
|
+
ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
|
1598
|
+
ACL_MEMCPY_DEVICE_TO_DEVICE,
|
1599
|
+
cann_ctx_dst->stream()));
|
1600
|
+
}
|
1601
|
+
|
1602
|
+
return true;
|
1603
|
+
}
|
1604
|
+
|
1605
|
+
/**
|
1606
|
+
* @brief Synchronizes a CANN backend.
|
1607
|
+
*
|
1608
|
+
* This function synchronizes the specified CANN backend by waiting for all
|
1609
|
+
* operations in its associated stream to complete.
|
1610
|
+
*
|
1611
|
+
* @param backend Pointer to the CANN backend structure to synchronize.
|
1612
|
+
*/
|
1613
|
+
static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
|
1614
|
+
ggml_backend_cann_context* cann_ctx =
|
1615
|
+
(ggml_backend_cann_context*)backend->context;
|
1616
|
+
|
1617
|
+
ggml_cann_set_device(cann_ctx->device);
|
1618
|
+
|
1619
|
+
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
|
1620
|
+
}
|
1621
|
+
|
1622
|
+
/**
|
1623
|
+
* @brief Computes a computational graph using a CANN backend.
|
1624
|
+
*
|
1625
|
+
* This function computes the operations defined in the computational graph
|
1626
|
+
* using the specified CANN backend.
|
1627
|
+
*
|
1628
|
+
* @param backend Pointer to the CANN backend structure to use for computation.
|
1629
|
+
* @param cgraph Pointer to the computational graph structure containing nodes
|
1630
|
+
* representing operations to be computed.
|
1631
|
+
* @return enum ggml_status Returns GGML_STATUS_SUCCESS if computation
|
1632
|
+
* completes successfully, otherwise an appropriate error status.
|
1633
|
+
*/
|
1634
|
+
static enum ggml_status ggml_backend_cann_graph_compute(
|
1635
|
+
ggml_backend_t backend, ggml_cgraph* cgraph) {
|
1636
|
+
ggml_backend_cann_context* cann_ctx =
|
1637
|
+
(ggml_backend_cann_context*)backend->context;
|
1638
|
+
|
1639
|
+
ggml_cann_set_device(cann_ctx->device);
|
1640
|
+
|
1641
|
+
for (int i = 0; i < cgraph->n_nodes; i++) {
|
1642
|
+
ggml_tensor* node = cgraph->nodes[i];
|
1643
|
+
|
1644
|
+
if (ggml_is_empty(node) || node->op == GGML_OP_NONE) {
|
1645
|
+
continue;
|
1646
|
+
}
|
1647
|
+
|
1648
|
+
bool ok = ggml_cann_compute_forward(*cann_ctx, node);
|
1649
|
+
|
1650
|
+
if (!ok) {
|
1651
|
+
GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__,
|
1652
|
+
node->name, ggml_op_name(node->op));
|
1653
|
+
}
|
1654
|
+
GGML_ASSERT(ok);
|
1655
|
+
}
|
1656
|
+
|
1657
|
+
return GGML_STATUS_SUCCESS;
|
1658
|
+
}
|
1659
|
+
|
1660
|
+
/**
|
1661
|
+
* @brief Checks if the CANN backend supports a specific operation.
|
1662
|
+
*
|
1663
|
+
* This function checks whether the specified operation is supported by the
|
1664
|
+
* CANN backend.
|
1665
|
+
*
|
1666
|
+
* @param backend Pointer to the CANN backend structure to check support for
|
1667
|
+
* the operation.
|
1668
|
+
* @param op Pointer to the tensor representing the operation to check.
|
1669
|
+
* @return bool Returns true if the operation is supported by the backend,
|
1670
|
+
* otherwise false.
|
1671
|
+
*/
|
1672
|
+
static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
1673
|
+
const ggml_tensor* op) {
|
1674
|
+
switch (op->op) {
|
1675
|
+
case GGML_OP_UNARY:
|
1676
|
+
switch (ggml_get_unary_op(op)) {
|
1677
|
+
case GGML_UNARY_OP_GELU:
|
1678
|
+
case GGML_UNARY_OP_SILU:
|
1679
|
+
case GGML_UNARY_OP_RELU:
|
1680
|
+
case GGML_UNARY_OP_HARDSIGMOID:
|
1681
|
+
case GGML_UNARY_OP_HARDSWISH:
|
1682
|
+
case GGML_UNARY_OP_GELU_QUICK:
|
1683
|
+
case GGML_UNARY_OP_TANH:
|
1684
|
+
return true;
|
1685
|
+
default:
|
1686
|
+
return false;
|
1687
|
+
}
|
1688
|
+
case GGML_OP_MUL_MAT: {
|
1689
|
+
switch (op->src[0]->type) {
|
1690
|
+
case GGML_TYPE_Q8_0:
|
1691
|
+
// Current groupsize should not be greater than k-1 in
|
1692
|
+
// aclnnWeightQuantBatchMatmulV2GetWorkspaceSize
|
1693
|
+
if (op->src[0]->ne[0] <= QK8_0) {
|
1694
|
+
return false;
|
1695
|
+
}
|
1696
|
+
case GGML_TYPE_F16:
|
1697
|
+
case GGML_TYPE_F32:
|
1698
|
+
case GGML_TYPE_Q4_0:
|
1699
|
+
return true;
|
1700
|
+
default:
|
1701
|
+
return false;
|
1702
|
+
}
|
1703
|
+
}
|
1704
|
+
case GGML_OP_MUL_MAT_ID:
|
1705
|
+
return false;
|
1706
|
+
// embedding
|
1707
|
+
case GGML_OP_GET_ROWS: {
|
1708
|
+
switch (op->src[0]->type) {
|
1709
|
+
case GGML_TYPE_F32:
|
1710
|
+
case GGML_TYPE_F16:
|
1711
|
+
case GGML_TYPE_Q4_0:
|
1712
|
+
case GGML_TYPE_Q8_0:
|
1713
|
+
return true;
|
1714
|
+
default:
|
1715
|
+
return false;
|
1716
|
+
}
|
1717
|
+
} break;
|
1718
|
+
case GGML_OP_CPY: {
|
1719
|
+
switch (op->type) {
|
1720
|
+
case GGML_TYPE_F32:
|
1721
|
+
case GGML_TYPE_F16:
|
1722
|
+
case GGML_TYPE_Q8_0:
|
1723
|
+
case GGML_TYPE_Q4_0:
|
1724
|
+
return true;
|
1725
|
+
default:
|
1726
|
+
return false;
|
1727
|
+
}
|
1728
|
+
}
|
1729
|
+
case GGML_OP_CONT: {
|
1730
|
+
// TODO: support GGML_TYPE_BF16
|
1731
|
+
switch (op->src[0]->type) {
|
1732
|
+
case GGML_TYPE_F32:
|
1733
|
+
case GGML_TYPE_F16:
|
1734
|
+
return true;
|
1735
|
+
default:
|
1736
|
+
return false;
|
1737
|
+
}
|
1738
|
+
}
|
1739
|
+
case GGML_OP_ROPE: {
|
1740
|
+
// TODO: with ops-test v == 1
|
1741
|
+
float * ext_factor = (float*)((int32_t*)op->op_params + 7);
|
1742
|
+
// TODO: n_dims <= ne0
|
1743
|
+
if (op->src[0]->ne[0] != op->op_params[1]) {
|
1744
|
+
return false;
|
1745
|
+
}
|
1746
|
+
// TODO: ext_factor != 0
|
1747
|
+
if (*ext_factor != 0) {
|
1748
|
+
return false;
|
1749
|
+
}
|
1750
|
+
|
1751
|
+
const int mode = ((const int32_t *) op->op_params)[2];
|
1752
|
+
if (mode & GGML_ROPE_TYPE_MROPE) {
|
1753
|
+
return false;
|
1754
|
+
}
|
1755
|
+
if (mode & GGML_ROPE_TYPE_VISION) {
|
1756
|
+
return false;
|
1757
|
+
}
|
1758
|
+
|
1759
|
+
return true;
|
1760
|
+
}
|
1761
|
+
case GGML_OP_UPSCALE: {
|
1762
|
+
// aclnnUpsampleNearest2dGetWorkspaceSize not support
|
1763
|
+
// selfDimN[2]/outDimN[2] or selfDimC[3]/outDimC[3] not equal
|
1764
|
+
if (op->src[0]->ne[2] * op->ne[3] != op->src[0]->ne[3] * op->ne[2]) {
|
1765
|
+
return false;
|
1766
|
+
}
|
1767
|
+
return true;
|
1768
|
+
}
|
1769
|
+
case GGML_OP_IM2COL:
|
1770
|
+
case GGML_OP_CONCAT:
|
1771
|
+
case GGML_OP_DUP:
|
1772
|
+
case GGML_OP_REPEAT:
|
1773
|
+
case GGML_OP_NONE:
|
1774
|
+
case GGML_OP_RESHAPE:
|
1775
|
+
case GGML_OP_VIEW:
|
1776
|
+
case GGML_OP_PERMUTE:
|
1777
|
+
case GGML_OP_TRANSPOSE:
|
1778
|
+
case GGML_OP_NORM:
|
1779
|
+
case GGML_OP_ADD:
|
1780
|
+
case GGML_OP_MUL:
|
1781
|
+
case GGML_OP_DIV:
|
1782
|
+
case GGML_OP_RMS_NORM:
|
1783
|
+
case GGML_OP_SCALE:
|
1784
|
+
case GGML_OP_SQR:
|
1785
|
+
case GGML_OP_CLAMP:
|
1786
|
+
case GGML_OP_DIAG_MASK_INF:
|
1787
|
+
case GGML_OP_SOFT_MAX:
|
1788
|
+
case GGML_OP_POOL_2D:
|
1789
|
+
case GGML_OP_SUM_ROWS:
|
1790
|
+
case GGML_OP_ARGSORT:
|
1791
|
+
case GGML_OP_ACC:
|
1792
|
+
case GGML_OP_GROUP_NORM:
|
1793
|
+
case GGML_OP_PAD:
|
1794
|
+
case GGML_OP_ARANGE:
|
1795
|
+
case GGML_OP_TIMESTEP_EMBEDDING:
|
1796
|
+
case GGML_OP_LEAKY_RELU:
|
1797
|
+
return true;
|
1798
|
+
default:
|
1799
|
+
return false;
|
1800
|
+
}
|
1801
|
+
|
1802
|
+
GGML_UNUSED(dev);
|
1803
|
+
}
|
1804
|
+
|
1805
|
+
/**
|
1806
|
+
* @brief Checks if the backend buffer type is associated with the CANN backend.
|
1807
|
+
*
|
1808
|
+
* This function checks whether the provided backend buffer type is associated
|
1809
|
+
* with the CANN backend based on the comparison of its name retrieval function
|
1810
|
+
* pointer.
|
1811
|
+
*
|
1812
|
+
* @param buft Pointer to the backend buffer type to check.
|
1813
|
+
* @return bool Returns true if the buffer type is associated with the CANN
|
1814
|
+
* backend, otherwise false.
|
1815
|
+
*/
|
1816
|
+
static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
|
1817
|
+
return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
|
1818
|
+
}
|
1819
|
+
|
1820
|
+
/**
|
1821
|
+
* @brief Determines if a tensor operation should be offloaded to the CANN
|
1822
|
+
* backend.
|
1823
|
+
*
|
1824
|
+
* This function checks if a given tensor operation should be offloaded to the
|
1825
|
+
* CANN backend based on the operation type and the size of the tensor. It
|
1826
|
+
* returns true if the second dimension (ne[1]) of the tensor is greater than or
|
1827
|
+
* equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.
|
1828
|
+
*
|
1829
|
+
* @param backend Pointer to the CANN backend.
|
1830
|
+
* @param op Pointer to the tensor operation to check.
|
1831
|
+
* @return bool Returns true if the operation should be offloaded, otherwise
|
1832
|
+
* false.
|
1833
|
+
*/
|
1834
|
+
static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev,
|
1835
|
+
const ggml_tensor* op) {
|
1836
|
+
const int min_batch_size = 32;
|
1837
|
+
GGML_UNUSED(dev);
|
1838
|
+
|
1839
|
+
return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
|
1840
|
+
}
|
1841
|
+
|
1842
|
+
/**
|
1843
|
+
* @brief Records an event on the CANN backend stream.
|
1844
|
+
*
|
1845
|
+
* This function records the given event on the ACL runtime stream associated
|
1846
|
+
* with the backend context.
|
1847
|
+
*
|
1848
|
+
* @param event Pointer to the event structure to be recorded.
|
1849
|
+
*/
|
1850
|
+
static void ggml_backend_cann_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
|
1851
|
+
ggml_backend_cann_context* cann_ctx =
|
1852
|
+
(ggml_backend_cann_context*)backend->context;
|
1853
|
+
ACL_CHECK(aclrtRecordEvent((aclrtEvent)event->context, cann_ctx->stream()));
|
1854
|
+
}
|
1855
|
+
|
1856
|
+
/**
|
1857
|
+
* @brief Waits for a recorded event to complete on the CANN backend stream.
|
1858
|
+
*
|
1859
|
+
* This function makes the given backend wait for the event to complete on its
|
1860
|
+
* ACL runtime stream.
|
1861
|
+
*
|
1862
|
+
* @param backend Pointer to the backend structure.
|
1863
|
+
* @param event Pointer to the event structure that the backend needs to wait
|
1864
|
+
* for.
|
1865
|
+
*/
|
1866
|
+
static void ggml_backend_cann_event_wait(ggml_backend_t backend,
|
1867
|
+
ggml_backend_event_t event) {
|
1868
|
+
ggml_backend_cann_context* cann_ctx =
|
1869
|
+
(ggml_backend_cann_context*)backend->context;
|
1870
|
+
if (ggml_backend_is_cann(backend)) {
|
1871
|
+
ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(),
|
1872
|
+
(aclrtEvent)event->context));
|
1873
|
+
} else {
|
1874
|
+
GGML_ABORT("fatal error");
|
1875
|
+
}
|
1876
|
+
}
|
1877
|
+
|
1878
|
+
/**
|
1879
|
+
* @brief Structure defining the interface for the CANN backend.
|
1880
|
+
*
|
1881
|
+
* This structure contains function pointers for various operations
|
1882
|
+
* supported by the CANN backend, including name retrieval, memory
|
1883
|
+
* management, tensor operations, synchronization, and event handling.
|
1884
|
+
*/
|
1885
|
+
static const ggml_backend_i ggml_backend_cann_interface = {
|
1886
|
+
/* .get_name = */ ggml_backend_cann_name,
|
1887
|
+
/* .free = */ ggml_backend_cann_free,
|
1888
|
+
/* .set_tensor_async = */ ggml_backend_cann_set_tensor_async,
|
1889
|
+
/* .get_tensor_async = */ ggml_backend_cann_get_tensor_async,
|
1890
|
+
/* .cpy_tensor_async = */ ggml_backend_cann_cpy_tensor_async,
|
1891
|
+
/* .synchronize = */ ggml_backend_cann_synchronize,
|
1892
|
+
/* .graph_plan_create = */ NULL,
|
1893
|
+
/* .graph_plan_free = */ NULL,
|
1894
|
+
/* .graph_plan_update = */ NULL,
|
1895
|
+
/* .graph_plan_compute = */ NULL,
|
1896
|
+
/* .graph_compute = */ ggml_backend_cann_graph_compute,
|
1897
|
+
/* .event_record = */ ggml_backend_cann_event_record,
|
1898
|
+
/* .event_wait = */ ggml_backend_cann_event_wait,
|
1899
|
+
};
|
1900
|
+
|
1901
|
+
/**
|
1902
|
+
* @brief Return the hardcoded GUID for the CANN backend.
|
1903
|
+
*
|
1904
|
+
* This function returns a static GUID which uniquely identifies the CANN
|
1905
|
+
* backend.
|
1906
|
+
*
|
1907
|
+
* @return A pointer to the static GUID.
|
1908
|
+
*/
|
1909
|
+
static ggml_guid_t ggml_backend_cann_guid() {
|
1910
|
+
static ggml_guid guid = {0xa1, 0x94, 0xaf, 0xac, 0xbd, 0x4f, 0x47, 0x34,
|
1911
|
+
0xbe, 0x1a, 0x9e, 0x71, 0x1f, 0x9e, 0xed, 0x64};
|
1912
|
+
return &guid;
|
1913
|
+
}
|
1914
|
+
|
1915
|
+
// backend device

// Per-device state attached to each registered ggml_backend_device.
struct ggml_backend_cann_device_context {
    int         device;       // ACL device index
    std::string name;         // GGML_CANN_NAME followed by the device index
    std::string description;  // SoC name reported by aclrtGetSocName()
};
|
1921
|
+
|
1922
|
+
static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) {
|
1923
|
+
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
|
1924
|
+
return ctx->name.c_str();
|
1925
|
+
}
|
1926
|
+
|
1927
|
+
static const char* ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) {
|
1928
|
+
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
|
1929
|
+
return ctx->description.c_str();
|
1930
|
+
}
|
1931
|
+
|
1932
|
+
static void ggml_backend_cann_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
1933
|
+
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
|
1934
|
+
ggml_backend_cann_get_device_memory(ctx->device, free, total);
|
1935
|
+
}
|
1936
|
+
|
1937
|
+
static enum ggml_backend_dev_type ggml_backend_cann_device_get_type(ggml_backend_dev_t dev) {
|
1938
|
+
GGML_UNUSED(dev);
|
1939
|
+
return GGML_BACKEND_DEVICE_TYPE_GPU;
|
1940
|
+
}
|
1941
|
+
|
1942
|
+
// Fills `props` with the device's identity, memory and capability flags.
// host_buffer is enabled unless the GGML_CANN_NO_PINNED environment variable
// is set.
static void ggml_backend_cann_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
    props->name        = ggml_backend_cann_device_get_name(dev);
    props->description = ggml_backend_cann_device_get_description(dev);
    props->type        = ggml_backend_cann_device_get_type(dev);
    ggml_backend_cann_device_get_memory(dev, &props->memory_free, &props->memory_total);

    const char * no_pinned = getenv("GGML_CANN_NO_PINNED");
    const bool use_pinned_host_buffer = (no_pinned == nullptr);

    props->caps = {
        /* .async                = */ false,
        /* .host_buffer          = */ use_pinned_host_buffer,
        /* .buffer_from_host_ptr = */ false,
        /* .events               = */ true,
    };
}
|
1957
|
+
|
1958
|
+
// Creates a backend for this device; `params` is ignored.
static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, const char * params) {
    GGML_UNUSED(params);
    const auto * ctx = static_cast<const ggml_backend_cann_device_context *>(dev->context);
    const int32_t device_index = ctx->device;
    return ggml_backend_cann_init(device_index);
}
|
1963
|
+
|
1964
|
+
/**
|
1965
|
+
* @brief Checks if the CANN backend supports a specific backend buffer type.
|
1966
|
+
*
|
1967
|
+
* This function determines whether the CANN backend supports the given backend
|
1968
|
+
* buffer type by comparing the device context of the backend and buffer type.
|
1969
|
+
* It returns true if the devices are same between the backend context and
|
1970
|
+
* buffer type context.
|
1971
|
+
*
|
1972
|
+
* @param backend Pointer to the CANN backend.
|
1973
|
+
* @param buft Pointer to the backend buffer type to check.
|
1974
|
+
* @return bool Returns true if the CANN backend supports the buffer type,
|
1975
|
+
* otherwise false.
|
1976
|
+
*/
|
1977
|
+
static bool ggml_backend_cann_supports_buft(
|
1978
|
+
ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
1979
|
+
if (ggml_backend_buft_is_cann(buft)) {
|
1980
|
+
ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
|
1981
|
+
ggml_backend_cann_buffer_type_context * buft_ctx =
|
1982
|
+
(ggml_backend_cann_buffer_type_context *)buft->context;
|
1983
|
+
return buft_ctx->device == dev_ctx->device;
|
1984
|
+
}
|
1985
|
+
return false;
|
1986
|
+
}
|
1987
|
+
|
1988
|
+
// Returns the device-memory buffer type for this device's index.
static ggml_backend_buffer_type_t ggml_backend_cann_device_get_buffer_type(ggml_backend_dev_t dev) {
    const auto * ctx = static_cast<const ggml_backend_cann_device_context *>(dev->context);
    const int32_t device_index = ctx->device;
    return ggml_backend_cann_buffer_type(device_index);
}

// Returns the shared CANN host buffer type; it is the same for every device.
static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(ggml_backend_dev_t dev) {
    GGML_UNUSED(dev);
    ggml_backend_buffer_type_t host_buft = ggml_backend_cann_host_buffer_type();
    return host_buft;
}
|
1997
|
+
|
1998
|
+
/**
|
1999
|
+
* @brief Creates a new event for the CANN backend device.
|
2000
|
+
*
|
2001
|
+
* This function initializes a new event for the CANN backend by setting the
|
2002
|
+
* device and creating an ACL runtime event. The created event is then wrapped
|
2003
|
+
* in a ggml_backend_event structure and returned.
|
2004
|
+
*
|
2005
|
+
* @param backend Pointer to the CANN backend.
|
2006
|
+
* @return ggml_backend_event_t Returns a pointer to the new event structure.
|
2007
|
+
*/
|
2008
|
+
static ggml_backend_event_t ggml_backend_cann_device_event_new(
|
2009
|
+
ggml_backend_dev_t dev) {
|
2010
|
+
ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
|
2011
|
+
|
2012
|
+
ggml_cann_set_device(dev_ctx->device);
|
2013
|
+
|
2014
|
+
aclrtEvent event;
|
2015
|
+
ACL_CHECK(aclrtCreateEvent(&event));
|
2016
|
+
|
2017
|
+
return new ggml_backend_event{
|
2018
|
+
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), dev_ctx->device),
|
2019
|
+
/* .context = */ event,
|
2020
|
+
};
|
2021
|
+
}
|
2022
|
+
|
2023
|
+
/**
|
2024
|
+
* @brief Frees a CANN backend event.
|
2025
|
+
*
|
2026
|
+
* This function destroys the ACL runtime event associated with the given CANN
|
2027
|
+
* backend event and then deletes the event structure itself.
|
2028
|
+
*
|
2029
|
+
* @param event Pointer to the event structure to be freed.
|
2030
|
+
*/
|
2031
|
+
static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
|
2032
|
+
ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
|
2033
|
+
|
2034
|
+
delete event;
|
2035
|
+
GGML_UNUSED(dev);
|
2036
|
+
}
|
2037
|
+
|
2038
|
+
/**
|
2039
|
+
* @brief Synchronizes the given event on the CANN backend.
|
2040
|
+
*
|
2041
|
+
* This function waits for the specified event to complete on the ACL runtime.
|
2042
|
+
*
|
2043
|
+
* @param event Pointer to the event structure to be synchronized.
|
2044
|
+
*/
|
2045
|
+
static void ggml_backend_cann_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
|
2046
|
+
ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
|
2047
|
+
|
2048
|
+
GGML_UNUSED(dev);
|
2049
|
+
}
|
2050
|
+
|
2051
|
+
// Device vtable for CANN devices. Positional initializer: entries must follow
// the member order of ggml_backend_device_i.
static const ggml_backend_device_i ggml_backend_cann_device_interface = {
    // identity and memory
    /* .get_name             = */ ggml_backend_cann_device_get_name,
    /* .get_description      = */ ggml_backend_cann_device_get_description,
    /* .get_memory           = */ ggml_backend_cann_device_get_memory,
    /* .get_type             = */ ggml_backend_cann_device_get_type,
    /* .get_props            = */ ggml_backend_cann_device_get_props,
    // backend and buffer creation
    /* .init_backend         = */ ggml_backend_cann_device_init,    // called for every card
    /* .get_buffer_type      = */ ggml_backend_cann_device_get_buffer_type,
    /* .get_host_buffer_type = */ ggml_backend_cann_device_get_host_buffer_type,
    /* .buffer_from_host_ptr = */ NULL,                             // not supported for CANN
    // capability queries
    /* .supports_op          = */ ggml_backend_cann_supports_op,
    /* .supports_buft        = */ ggml_backend_cann_supports_buft,
    /* .offload_op           = */ ggml_backend_cann_offload_op,
    // events
    /* .event_new            = */ ggml_backend_cann_device_event_new,
    /* .event_free           = */ ggml_backend_cann_device_event_free,
    /* .event_synchronize    = */ ggml_backend_cann_device_event_synchronize,
};
|
2068
|
+
|
2069
|
+
|
2070
|
+
// backend reg
|
2071
|
+
struct ggml_backend_cann_reg_context {
|
2072
|
+
std::vector<ggml_backend_dev_t> devices;
|
2073
|
+
};
|
2074
|
+
|
2075
|
+
static const char * ggml_backend_cann_reg_get_name(ggml_backend_reg_t reg) {
|
2076
|
+
GGML_UNUSED(reg);
|
2077
|
+
return GGML_CANN_NAME;
|
2078
|
+
}
|
2079
|
+
|
2080
|
+
static size_t ggml_backend_cann_reg_get_device_count(ggml_backend_reg_t reg) {
|
2081
|
+
ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
|
2082
|
+
return ctx->devices.size();
|
2083
|
+
}
|
2084
|
+
|
2085
|
+
static ggml_backend_dev_t ggml_backend_cann_reg_get_device(ggml_backend_reg_t reg, size_t index) {
|
2086
|
+
ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
|
2087
|
+
GGML_ASSERT(index < ctx->devices.size());
|
2088
|
+
return ctx->devices[index];
|
2089
|
+
}
|
2090
|
+
|
2091
|
+
static void * ggml_backend_cann_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
|
2092
|
+
GGML_UNUSED(reg);
|
2093
|
+
GGML_UNUSED(name);
|
2094
|
+
// reserved for future use
|
2095
|
+
return nullptr;
|
2096
|
+
}
|
2097
|
+
|
2098
|
+
static const ggml_backend_reg_i ggml_backend_cann_reg_interface = {
|
2099
|
+
/* .get_name = */ ggml_backend_cann_reg_get_name,
|
2100
|
+
/* .get_device_count = */ ggml_backend_cann_reg_get_device_count,
|
2101
|
+
/* .get_device = */ ggml_backend_cann_reg_get_device,
|
2102
|
+
/* .get_proc_address = */ ggml_backend_cann_reg_get_proc_address,
|
2103
|
+
};
|
2104
|
+
|
2105
|
+
// backend registry, called only once for cann backend
|
2106
|
+
ggml_backend_reg_t ggml_backend_cann_reg() {
|
2107
|
+
static ggml_backend_reg reg;
|
2108
|
+
static bool initialized = false;
|
2109
|
+
|
2110
|
+
{
|
2111
|
+
static std::mutex mutex;
|
2112
|
+
std::lock_guard<std::mutex> lock(mutex);
|
2113
|
+
if (!initialized) {
|
2114
|
+
aclInit(nullptr);
|
2115
|
+
ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context;
|
2116
|
+
|
2117
|
+
for (int i = 0; i < ggml_cann_info().device_count; i++) {
|
2118
|
+
ggml_backend_cann_device_context* dev_ctx = new ggml_backend_cann_device_context();
|
2119
|
+
dev_ctx->description = aclrtGetSocName();
|
2120
|
+
dev_ctx->device = i;
|
2121
|
+
dev_ctx->name = GGML_CANN_NAME + std::to_string(i);
|
2122
|
+
ggml_cann_set_device(i);
|
2123
|
+
ggml_backend_dev_t dev = new ggml_backend_device {
|
2124
|
+
/* .iface = */ ggml_backend_cann_device_interface,
|
2125
|
+
/* .reg = */ ®,
|
2126
|
+
/* .context = */ dev_ctx
|
2127
|
+
};
|
2128
|
+
ctx->devices.push_back(dev);
|
2129
|
+
}
|
2130
|
+
|
2131
|
+
reg = ggml_backend_reg {
|
2132
|
+
/* .api_version = */ GGML_BACKEND_API_VERSION,
|
2133
|
+
/* .iface = */ ggml_backend_cann_reg_interface,
|
2134
|
+
/* .context = */ ctx
|
2135
|
+
};
|
2136
|
+
}
|
2137
|
+
|
2138
|
+
initialized = true;
|
2139
|
+
}
|
2140
|
+
|
2141
|
+
return ®
|
2142
|
+
}
|
2143
|
+
|
2144
|
+
/**
 * @brief Creates a CANN backend bound to one device.
 *
 * Calls aclInit() (status ignored), validates the device index and creates a
 * ggml_backend_cann_context for it. It then makes the device current and
 * wraps everything in a heap-allocated ggml_backend that is linked to the
 * registry's device entry.
 *
 * @param device ACL device index, in [0, ggml_backend_cann_get_device_count()).
 * @return The new backend, or nullptr (after logging) for an invalid index or
 *         a null context.
 */
ggml_backend_t ggml_backend_cann_init(int32_t device) {
    aclInit(nullptr);

    const bool in_range =
        device >= 0 && device < ggml_backend_cann_get_device_count();
    if (!in_range) {
        GGML_LOG_ERROR("%s: error: invalid device %d\n", __func__, device);
        return nullptr;
    }

    ggml_backend_cann_context* ctx = new ggml_backend_cann_context(device);
    if (!ctx) {
        GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
        return nullptr;
    }
    ggml_cann_set_device(ctx->device);

    ggml_guid_t guid = ggml_backend_cann_guid();
    ggml_backend_dev_t owner = ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device);
    return new ggml_backend{
        /* .guid      = */ guid,
        /* .interface = */ ggml_backend_cann_interface,
        /* .device    = */ owner,
        /* .context   = */ ctx,
    };
}
|
2165
|
+
|
2166
|
+
bool ggml_backend_is_cann(ggml_backend_t backend) {
|
2167
|
+
return backend != NULL &&
|
2168
|
+
ggml_guid_matches(backend->guid, ggml_backend_cann_guid());
|
2169
|
+
}
|
2170
|
+
|
2171
|
+
int32_t ggml_backend_cann_get_device_count() {
|
2172
|
+
return ggml_cann_info().device_count;
|
2173
|
+
}
|
2174
|
+
|
2175
|
+
void ggml_backend_cann_get_device_description(
|
2176
|
+
int32_t device, char* description, size_t description_size) {
|
2177
|
+
ggml_cann_set_device(device);
|
2178
|
+
const char* soc_name = aclrtGetSocName();
|
2179
|
+
snprintf(description, description_size, "%s", soc_name);
|
2180
|
+
}
|
2181
|
+
|
2182
|
+
void ggml_backend_cann_get_device_memory(int32_t device, size_t* free,
|
2183
|
+
size_t* total) {
|
2184
|
+
ggml_cann_set_device(device);
|
2185
|
+
ACL_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, free, total));
|
2186
|
+
}
|
2187
|
+
|
2188
|
+
GGML_BACKEND_DL_IMPL(ggml_backend_cann_reg)
|