@fugood/llama.node 0.3.3 → 0.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (225)
  1. package/CMakeLists.txt +5 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +18 -1
  17. package/package.json +1 -1
  18. package/src/EmbeddingWorker.cpp +15 -5
  19. package/src/EmbeddingWorker.h +2 -1
  20. package/src/LlamaCompletionWorker.cpp +1 -1
  21. package/src/LlamaContext.cpp +81 -18
  22. package/src/LlamaContext.h +2 -0
  23. package/src/llama.cpp/.github/workflows/build.yml +197 -159
  24. package/src/llama.cpp/.github/workflows/docker.yml +5 -8
  25. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  26. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  27. package/src/llama.cpp/CMakeLists.txt +11 -6
  28. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  29. package/src/llama.cpp/cmake/common.cmake +33 -0
  30. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  31. package/src/llama.cpp/common/CMakeLists.txt +6 -2
  32. package/src/llama.cpp/common/arg.cpp +426 -245
  33. package/src/llama.cpp/common/common.cpp +143 -80
  34. package/src/llama.cpp/common/common.h +81 -24
  35. package/src/llama.cpp/common/sampling.cpp +53 -19
  36. package/src/llama.cpp/common/sampling.h +22 -1
  37. package/src/llama.cpp/common/speculative.cpp +274 -0
  38. package/src/llama.cpp/common/speculative.h +28 -0
  39. package/src/llama.cpp/docs/build.md +101 -148
  40. package/src/llama.cpp/examples/CMakeLists.txt +32 -13
  41. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  42. package/src/llama.cpp/examples/batched/batched.cpp +5 -4
  43. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  44. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  45. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  46. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  47. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  48. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  49. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  50. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  51. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  52. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  54. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  55. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  57. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  58. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  59. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  60. package/src/llama.cpp/examples/imatrix/imatrix.cpp +11 -2
  61. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/infill/infill.cpp +1 -1
  63. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  64. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +405 -316
  65. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  66. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  67. package/src/llama.cpp/examples/llava/clip.cpp +262 -66
  68. package/src/llama.cpp/examples/llava/clip.h +8 -2
  69. package/src/llama.cpp/examples/llava/llava-cli.cpp +1 -1
  70. package/src/llama.cpp/examples/llava/llava.cpp +46 -19
  71. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +1 -1
  72. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  73. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
  75. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  76. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -1
  77. package/src/llama.cpp/examples/lookup/lookup.cpp +2 -2
  78. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/main/main.cpp +9 -5
  80. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  81. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
  83. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  84. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  87. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  88. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  89. package/src/llama.cpp/examples/retrieval/retrieval.cpp +4 -4
  90. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  91. package/src/llama.cpp/examples/run/run.cpp +911 -0
  92. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -4
  94. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -7
  95. package/src/llama.cpp/examples/server/server.cpp +1758 -886
  96. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  97. package/src/llama.cpp/examples/server/utils.hpp +94 -304
  98. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  99. package/src/llama.cpp/examples/simple/simple.cpp +4 -0
  100. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +1 -1
  101. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +3 -0
  102. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/speculative/speculative.cpp +16 -15
  104. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  105. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  106. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/tokenize/tokenize.cpp +1 -1
  108. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  109. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  110. package/src/llama.cpp/ggml/CMakeLists.txt +46 -34
  111. package/src/llama.cpp/ggml/include/ggml-backend.h +16 -0
  112. package/src/llama.cpp/ggml/include/ggml-cpu.h +7 -49
  113. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  114. package/src/llama.cpp/ggml/include/ggml.h +106 -24
  115. package/src/llama.cpp/ggml/src/CMakeLists.txt +73 -24
  116. package/src/llama.cpp/ggml/src/ggml-alloc.c +0 -1
  117. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +51 -11
  118. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +379 -22
  119. package/src/llama.cpp/ggml/src/ggml-backend.cpp +4 -4
  120. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -7
  121. package/src/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +5 -2
  122. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +33 -3
  123. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  124. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  125. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +95 -35
  126. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  127. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  128. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  129. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  130. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  131. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  132. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  133. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  134. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  135. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +288 -213
  136. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  137. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  138. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/common.h +19 -22
  139. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.cpp +93 -92
  140. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.h +2 -9
  141. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  142. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +892 -190
  143. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +2 -24
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +15 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +38 -25
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +552 -399
  151. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +101 -136
  152. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +2 -2
  153. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +7 -10
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  155. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -6
  156. package/src/llama.cpp/ggml/src/ggml-impl.h +32 -11
  157. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +13 -9
  158. package/src/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +131 -64
  159. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +3 -6
  160. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +39 -0
  161. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +14 -7
  162. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  163. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  164. package/src/llama.cpp/ggml/src/ggml-opt.cpp +67 -80
  165. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -9
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +3 -5
  167. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +5 -2
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +13 -10
  169. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +2 -11
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +2 -2
  172. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  173. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  174. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +32 -13
  175. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +80 -61
  176. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  177. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +159 -114
  178. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  179. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  180. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +6 -20
  181. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +4 -3
  182. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +8 -8
  183. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  184. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  185. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +4 -1
  187. package/src/llama.cpp/ggml/src/ggml-threading.h +4 -2
  188. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +21 -7
  189. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1718 -399
  190. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +3 -1
  191. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +105 -31
  192. package/src/llama.cpp/ggml/src/ggml.c +367 -207
  193. package/src/llama.cpp/include/llama-cpp.h +25 -0
  194. package/src/llama.cpp/include/llama.h +26 -19
  195. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  196. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  197. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  198. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  199. package/src/llama.cpp/src/CMakeLists.txt +2 -7
  200. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  201. package/src/llama.cpp/src/llama-grammar.h +2 -5
  202. package/src/llama.cpp/src/llama-sampling.cpp +35 -90
  203. package/src/llama.cpp/src/llama-vocab.cpp +6 -1
  204. package/src/llama.cpp/src/llama.cpp +1748 -640
  205. package/src/llama.cpp/src/unicode.cpp +62 -51
  206. package/src/llama.cpp/src/unicode.h +9 -10
  207. package/src/llama.cpp/tests/CMakeLists.txt +48 -37
  208. package/src/llama.cpp/tests/test-arg-parser.cpp +2 -2
  209. package/src/llama.cpp/tests/test-backend-ops.cpp +140 -21
  210. package/src/llama.cpp/tests/test-chat-template.cpp +50 -4
  211. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  212. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  213. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  214. package/src/llama.cpp/tests/test-quantize-fns.cpp +3 -3
  215. package/src/llama.cpp/tests/test-rope.cpp +61 -20
  216. package/src/llama.cpp/tests/test-sampling.cpp +2 -2
  217. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  218. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  219. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  220. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  221. package/src/llama.cpp/ggml/include/ggml-amx.h +0 -25
  222. package/src/llama.cpp/ggml/src/ggml-aarch64.c +0 -129
  223. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -19
  224. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  225. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
@@ -211,17 +211,20 @@ struct ggml_cann_pool_alloc {
211
211
  struct ggml_backend_cann_context {
212
212
  int32_t device; /**< Device ID. */
213
213
  std::string name; /**< Name of the device. */
214
+ std::string description; /**< Description of the device. */
214
215
  aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
215
216
 
216
- aclrtStream streams[GGML_CANN_MAX_STREAMS] = {
217
- {nullptr}}; /**< Array of streams for the device. */
217
+ aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
218
218
 
219
219
  /**
220
220
  * @brief Constructor for initializing the context with a given device.
221
221
  * @param device Device ID.
222
222
  */
223
223
  explicit ggml_backend_cann_context(int device)
224
- : device(device), name("CANN" + std::to_string(device)) {}
224
+ : device(device), name("CANN" + std::to_string(device)) {
225
+ ggml_cann_set_device(device);
226
+ description = aclrtGetSocName();
227
+ }
225
228
 
226
229
  /**
227
230
  * @brief Destructor for cleaning up resources.
@@ -122,6 +122,10 @@ static ggml_cann_device_info ggml_cann_init() {
122
122
  ACL_CHECK(aclrtMemGetAllocationGranularity(
123
123
  &prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
124
124
  &info.devices[id].vmm_granularity));
125
+
126
+ size_t free, total;
127
+ ggml_backend_cann_get_device_memory(id, &free, &total);
128
+ info.devices[id].total_vram = free;
125
129
  }
126
130
 
127
131
  // TODO: add more device info later.
@@ -208,6 +212,11 @@ struct ggml_cann_pool_leg : public ggml_cann_pool {
208
212
  * @return A pointer to the allocated buffer.
209
213
  */
210
214
  void* alloc(size_t size, size_t* actual_size) override {
215
+ const size_t alignment = 128;
216
+ size = GGML_PAD(size, alignment);
217
+ if (size == 0) {
218
+ size = alignment;
219
+ }
211
220
  #ifdef DEBUG_CANN_MALLOC
212
221
  int nnz = 0;
213
222
  size_t max_size = 0;
@@ -246,13 +255,11 @@ struct ggml_cann_pool_leg : public ggml_cann_pool {
246
255
  return ptr;
247
256
  }
248
257
  void* ptr;
249
- size_t look_ahead_size = (size_t)(1.05 * size);
250
- look_ahead_size = 256 * ((look_ahead_size + 255) / 256);
251
258
  ggml_cann_set_device(device);
252
259
  ACL_CHECK(
253
- aclrtMalloc(&ptr, look_ahead_size, ACL_MEM_MALLOC_HUGE_FIRST));
254
- *actual_size = look_ahead_size;
255
- pool_size += look_ahead_size;
260
+ aclrtMalloc(&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST));
261
+ *actual_size = size;
262
+ pool_size += size;
256
263
  #ifdef DEBUG_CANN_MALLOC
257
264
  GGML_LOG_INFO(
258
265
  "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, "
@@ -296,7 +303,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
296
303
  /**
297
304
  * @brief The maximum size of the virtual memory pool (32 GB).
298
305
  */
299
- static const size_t CANN_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
306
+ size_t max_size;
300
307
 
301
308
  /**
302
309
  * @brief The device ID associated with this buffer pool.
@@ -341,7 +348,11 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
341
348
  */
342
349
  explicit ggml_cann_pool_vmm(int device)
343
350
  : device(device),
344
- granularity(ggml_cann_info().devices[device].vmm_granularity) {}
351
+ granularity(ggml_cann_info().devices[device].vmm_granularity) {
352
+ auto dev = ggml_cann_info().devices[device];
353
+ granularity = dev.vmm_granularity;
354
+ max_size = dev.total_vram;
355
+ }
345
356
 
346
357
  /**
347
358
  * @brief Destructor to free all buffers in the virtual memory pool.
@@ -370,17 +381,19 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
370
381
  // round up the allocation size to the alignment to ensure that all
371
382
  // allocations are aligned for all data types
372
383
  const size_t alignment = 128;
373
- size = alignment * ((size + alignment - 1) / alignment);
384
+ size = GGML_PAD(size, alignment);
385
+ if (size == 0) {
386
+ size = alignment;
387
+ }
374
388
 
375
389
  size_t avail = pool_size - pool_used;
376
390
 
377
391
  if (size > avail) {
378
392
  // round up to the next multiple of the granularity
379
393
  size_t reserve_size = size - avail;
380
- reserve_size =
381
- granularity * ((reserve_size + granularity - 1) / granularity);
394
+ reserve_size = GGML_PAD(reserve_size, granularity);
382
395
 
383
- GGML_ASSERT(pool_size + reserve_size <= CANN_POOL_VMM_MAX_SIZE);
396
+ GGML_ASSERT(pool_size + reserve_size <= max_size);
384
397
 
385
398
  // allocate more physical memory
386
399
  aclrtPhysicalMemProp prop = {};
@@ -396,7 +409,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
396
409
  // reserve virtual address space (if not already reserved)
397
410
  if (pool_addr == 0) {
398
411
  ACL_CHECK(aclrtReserveMemAddress(
399
- &pool_addr, CANN_POOL_VMM_MAX_SIZE, 0, NULL, 1));
412
+ &pool_addr, max_size, 0, NULL, 1));
400
413
  }
401
414
 
402
415
  // map at the end of the pool
@@ -409,10 +422,11 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
409
422
  // add to the pool
410
423
  pool_size += reserve_size;
411
424
 
412
- // GGML_LOG_INFO("cann pool[%d]: size increased to %llu MB (
413
- // reserved %llu MB)\n",
414
- // device, (unsigned long long) (pool_size/1024/1024),
415
- // (unsigned long long) (reserve_size/1024/1024));
425
+ #ifdef DEBUG_CANN_MALLOC
426
+ GGML_LOG_INFO("cann pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
427
+ device, (unsigned long long) (pool_size/1024/1024),
428
+ (unsigned long long) (reserve_size/1024/1024));
429
+ #endif
416
430
  }
417
431
 
418
432
  GGML_ASSERT(pool_addr != 0);
@@ -457,7 +471,6 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
457
471
  */
458
472
  std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
459
473
  int device) {
460
- // return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_leg(device));
461
474
  return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
462
475
  }
463
476
 
@@ -1130,10 +1143,10 @@ ggml_backend_cann_buffer_type(int32_t device) {
1130
1143
  static bool ggml_backend_cann_buffer_type_initialized = false;
1131
1144
 
1132
1145
  if (!ggml_backend_cann_buffer_type_initialized) {
1133
- for (int32_t i = 0; i < GGML_CANN_MAX_DEVICES; i++) {
1146
+ for (int32_t i = 0; i < ggml_cann_info().device_count; i++) {
1134
1147
  ggml_backend_cann_buffer_types[i] = {
1135
1148
  /* .iface = */ ggml_backend_cann_buffer_type_interface,
1136
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
1149
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), i),
1137
1150
  /* .context = */
1138
1151
  new ggml_backend_cann_buffer_type_context{
1139
1152
  i, "CANN" + std::to_string(i)},
@@ -1199,10 +1212,15 @@ static void * ggml_cann_host_malloc(size_t size) {
1199
1212
  return nullptr;
1200
1213
  }
1201
1214
 
1215
+ const size_t alignment = 128;
1216
+ size = GGML_PAD(size, alignment);
1217
+ if (size == 0) {
1218
+ size = alignment;
1219
+ }
1220
+
1202
1221
  void * hostPtr = nullptr;
1203
1222
  aclError err = aclrtMallocHost((void **) &hostPtr, size);
1204
1223
  if (err != ACL_SUCCESS) {
1205
-
1206
1224
  GGML_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
1207
1225
  size / 1024.0 / 1024.0, aclGetRecentErrMsg());
1208
1226
  return nullptr;
@@ -1669,12 +1687,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
1669
1687
  }
1670
1688
  case GGML_OP_MUL_MAT: {
1671
1689
  switch (op->src[0]->type) {
1672
- case GGML_TYPE_F16:
1673
- case GGML_TYPE_F32:
1674
1690
  case GGML_TYPE_Q8_0:
1675
- // TODO: fix me
1676
1691
  // Current groupsize should not be greater than k-1 in
1677
- // aclnnWeightQuantBatchMatmulV2GetWorkspaceSize().
1692
+ // aclnnWeightQuantBatchMatmulV2GetWorkspaceSize
1693
+ if (op->src[0]->ne[0] <= QK8_0) {
1694
+ return false;
1695
+ }
1696
+ case GGML_TYPE_F16:
1697
+ case GGML_TYPE_F32:
1678
1698
  case GGML_TYPE_Q4_0:
1679
1699
  return true;
1680
1700
  default:
@@ -1706,9 +1726,50 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
1706
1726
  return false;
1707
1727
  }
1708
1728
  }
1729
+ case GGML_OP_CONT: {
1730
+ // TODO: support GGML_TYPE_BF16
1731
+ switch (op->src[0]->type) {
1732
+ case GGML_TYPE_F32:
1733
+ case GGML_TYPE_F16:
1734
+ return true;
1735
+ default:
1736
+ return false;
1737
+ }
1738
+ }
1739
+ case GGML_OP_ROPE: {
1740
+ // TODO: with ops-test v == 1
1741
+ float * ext_factor = (float*)((int32_t*)op->op_params + 7);
1742
+ // TODO: n_dims <= ne0
1743
+ if (op->src[0]->ne[0] != op->op_params[1]) {
1744
+ return false;
1745
+ }
1746
+ // TODO: ext_factor != 0
1747
+ if (*ext_factor != 0) {
1748
+ return false;
1749
+ }
1750
+
1751
+ const int mode = ((const int32_t *) op->op_params)[2];
1752
+ if (mode & GGML_ROPE_TYPE_MROPE) {
1753
+ return false;
1754
+ }
1755
+ if (mode & GGML_ROPE_TYPE_VISION) {
1756
+ return false;
1757
+ }
1758
+
1759
+ return true;
1760
+ }
1761
+ case GGML_OP_UPSCALE: {
1762
+ // aclnnUpsampleNearest2dGetWorkspaceSize not support
1763
+ // selfDimN[2]/outDimN[2] or selfDimC[3]/outDimC[3] not equal
1764
+ if (op->src[0]->ne[2] * op->ne[3] != op->src[0]->ne[3] * op->ne[2]) {
1765
+ return false;
1766
+ }
1767
+ return true;
1768
+ }
1769
+ case GGML_OP_IM2COL:
1770
+ case GGML_OP_CONCAT:
1709
1771
  case GGML_OP_DUP:
1710
1772
  case GGML_OP_REPEAT:
1711
- case GGML_OP_CONCAT:
1712
1773
  case GGML_OP_NONE:
1713
1774
  case GGML_OP_RESHAPE:
1714
1775
  case GGML_OP_VIEW:
@@ -1722,17 +1783,13 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
1722
1783
  case GGML_OP_SCALE:
1723
1784
  case GGML_OP_SQR:
1724
1785
  case GGML_OP_CLAMP:
1725
- case GGML_OP_CONT:
1726
1786
  case GGML_OP_DIAG_MASK_INF:
1727
1787
  case GGML_OP_SOFT_MAX:
1728
- case GGML_OP_ROPE:
1729
- case GGML_OP_IM2COL:
1730
1788
  case GGML_OP_POOL_2D:
1731
1789
  case GGML_OP_SUM_ROWS:
1732
1790
  case GGML_OP_ARGSORT:
1733
1791
  case GGML_OP_ACC:
1734
1792
  case GGML_OP_GROUP_NORM:
1735
- case GGML_OP_UPSCALE:
1736
1793
  case GGML_OP_PAD:
1737
1794
  case GGML_OP_ARANGE:
1738
1795
  case GGML_OP_TIMESTEP_EMBEDDING:
@@ -2041,7 +2098,7 @@ static void * ggml_backend_cann_reg_get_proc_address(ggml_backend_reg_t reg, con
2041
2098
  static const ggml_backend_reg_i ggml_backend_cann_reg_interface = {
2042
2099
  /* .get_name = */ ggml_backend_cann_reg_get_name,
2043
2100
  /* .get_device_count = */ ggml_backend_cann_reg_get_device_count,
2044
- /* .get_device_get = */ ggml_backend_cann_reg_get_device,
2101
+ /* .get_device = */ ggml_backend_cann_reg_get_device,
2045
2102
  /* .get_proc_address = */ ggml_backend_cann_reg_get_proc_address,
2046
2103
  };
2047
2104
 
@@ -2064,16 +2121,17 @@ ggml_backend_reg_t ggml_backend_cann_reg() {
2064
2121
  dev_ctx->name = GGML_CANN_NAME + std::to_string(i);
2065
2122
  ggml_cann_set_device(i);
2066
2123
  ggml_backend_dev_t dev = new ggml_backend_device {
2067
- /* .interface = */ ggml_backend_cann_device_interface,
2068
- /* .reg = */ &reg,
2069
- /* .context = */ dev_ctx
2124
+ /* .iface = */ ggml_backend_cann_device_interface,
2125
+ /* .reg = */ &reg,
2126
+ /* .context = */ dev_ctx
2070
2127
  };
2071
2128
  ctx->devices.push_back(dev);
2072
2129
  }
2073
2130
 
2074
2131
  reg = ggml_backend_reg {
2075
- /* .interface = */ ggml_backend_cann_reg_interface,
2076
- /* .context = */ ctx
2132
+ /* .api_version = */ GGML_BACKEND_API_VERSION,
2133
+ /* .iface = */ ggml_backend_cann_reg_interface,
2134
+ /* .context = */ ctx
2077
2135
  };
2078
2136
  }
2079
2137
 
@@ -2126,3 +2184,5 @@ void ggml_backend_cann_get_device_memory(int32_t device, size_t* free,
2126
2184
  ggml_cann_set_device(device);
2127
2185
  ACL_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, free, total));
2128
2186
  }
2187
+
2188
+ GGML_BACKEND_DL_IMPL(ggml_backend_cann_reg)
@@ -1,7 +1,3 @@
1
- if (NOT SOC_TYPE)
2
- set (SOC_TYPE "Ascend910B3")
3
- endif()
4
-
5
1
  file(GLOB SRC_FILES
6
2
  get_row_f32.cpp
7
3
  get_row_f16.cpp
@@ -13,7 +9,6 @@ file(GLOB SRC_FILES
13
9
  dup.cpp
14
10
  )
15
11
 
16
- string(TOLOWER ${SOC_TYPE} SOC_VERSION)
17
12
  set(ASCEND_CANN_PACKAGE_PATH ${CANN_INSTALL_DIR})
18
13
  set(RUN_MODE "npu" CACHE STRING "run mode: npu/sim")
19
14
 
@@ -30,4 +25,6 @@ ascendc_library(ascendc_kernels STATIC
30
25
  ${SRC_FILES}
31
26
  )
32
27
 
28
+ message(STATUS "CANN: compile ascend kernels with SOC_TYPE:${SOC_TYPE}, SOC_VERSION:${SOC_VERSION}, compile macro:-D${SOC_TYPE_COMPILE_OPTION}.")
29
+ ascendc_compile_definitions(ascendc_kernels PRIVATE "-D${SOC_TYPE_COMPILE_OPTION}")
33
30
  # ascendc_compile_definitions(ascendc_kernels PRIVATE -DASCENDC_DUMP)
@@ -5,6 +5,7 @@
5
5
  using namespace AscendC;
6
6
 
7
7
  #define BUFFER_NUM 2
8
+ const int64_t SUPPORTED_MAX_DIM = 65535; // currently the limit of max block dim supported by dup kernel is 65535
8
9
 
9
10
  template <typename SRC_T, typename DST_T>
10
11
  class DupByRows {
@@ -51,24 +52,36 @@ class DupByRows {
51
52
 
52
53
  __aicore__ inline void copy_in() {
53
54
  LocalTensor<SRC_T> src_local = src_queue.AllocTensor<SRC_T>();
54
-
55
- DataCopyExtParams dataCopyParams;
56
- dataCopyParams.blockCount = 1;
57
- dataCopyParams.blockLen = num_elem * sizeof(SRC_T);
58
- DataCopyPadExtParams<SRC_T> padParams;
59
- DataCopyPad(src_local, src_gm, dataCopyParams, padParams);
60
-
55
+ const size_t elem_per_block = 32 / sizeof(SRC_T);
56
+ size_t tail = num_elem % elem_per_block;
57
+ size_t cpy_elements_len = tail > 0 ? num_elem + 1 : num_elem;
58
+ DataCopy(src_local, src_gm, cpy_elements_len);
61
59
  src_queue.EnQue(src_local);
62
60
  }
63
61
 
64
62
  __aicore__ inline void copy_out() {
65
63
  LocalTensor<DST_T> dst_local = dst_queue.DeQue<DST_T>();
66
-
64
+ #ifdef ASCEND_310P
65
+ const size_t elem_per_block = 32 / sizeof(DST_T);
66
+ size_t tail = num_elem % elem_per_block;
67
+ size_t len = num_elem & ~(elem_per_block - 1);
68
+ if (len > 0) {
69
+ DataCopy(dst_gm, dst_local, len);
70
+ }
71
+ if(tail != 0) {
72
+ for (size_t i = tail; i < elem_per_block; i++) {
73
+ dst_local[len + i].SetValue(0, 0);
74
+ }
75
+ SetAtomicAdd<float>();
76
+ DataCopy(dst_gm[len], dst_local[len], elem_per_block);
77
+ SetAtomicNone();
78
+ }
79
+ #else
67
80
  DataCopyExtParams dataCopyParams;
68
81
  dataCopyParams.blockCount = 1;
69
82
  dataCopyParams.blockLen = num_elem * sizeof(DST_T);
70
83
  DataCopyPad(dst_gm, dst_local, dataCopyParams);
71
-
84
+ #endif
72
85
  dst_queue.FreeTensor(dst_local);
73
86
  }
74
87
 
@@ -14,7 +14,7 @@ class GET_ROW_F16 {
14
14
  int64_t *output_ne_ub, size_t *output_nb_ub) {
15
15
  // TODO, use template for F16/f32
16
16
  int64_t op_block_num = GetBlockNum();
17
- int64_t op_block_idx = GetBlockIdx();
17
+ op_block_idx = GetBlockIdx();
18
18
 
19
19
  for (int i = 0; i < 4; i++) {
20
20
  input_ne[i] = input_ne_ub[i];
@@ -59,32 +59,42 @@ class GET_ROW_F16 {
59
59
  }
60
60
 
61
61
  __aicore__ inline void copy_in(uint32_t offset, size_t len) {
62
+ size_t origin_len = len;
62
63
  LocalTensor<half> input_local = input_queue.AllocTensor<half>();
63
- size_t tail = len % 32;
64
- len = len & ~31;
65
- DataCopy(input_local, input_gm[offset], len);
64
+ const size_t elem_per_block = 32 / sizeof(half);
65
+ size_t tail = len % elem_per_block;
66
+ len = len & ~(elem_per_block - 1);
66
67
  if(tail != 0) {
67
- DataCopyExtParams dataCopyParams;
68
- dataCopyParams.blockCount = 1;
69
- dataCopyParams.blockLen = tail * sizeof(half);
70
- DataCopyPadExtParams<half> padParams;
71
- DataCopyPad(input_local[len], input_gm[offset + len],
72
- dataCopyParams, padParams);
68
+ len += elem_per_block;
73
69
  }
70
+ DataCopy(input_local, input_gm[offset], len);
74
71
  input_queue.EnQue(input_local);
75
72
  }
76
73
 
77
74
  __aicore__ inline void copy_out(uint32_t offset, size_t len) {
78
75
  LocalTensor<float> output_local = output_queue.DeQue<float>();
79
- size_t tail = len % 32;
80
- len = len & ~31;
81
- DataCopy(output_gm[offset], output_local, len);
76
+ const size_t elem_per_block = 32 / sizeof(float);
77
+ size_t tail = len % elem_per_block;
78
+ len = len & ~(elem_per_block - 1);
79
+ if (len > 0) {
80
+ DataCopy(output_gm[offset], output_local, len);
81
+ }
82
+
82
83
  if(tail != 0) {
84
+ #ifdef ASCEND_310P
85
+ for (size_t i = tail; i < elem_per_block; i++) {
86
+ output_local[len + i].SetValue(0, 0);
87
+ }
88
+ SetAtomicAdd<float>();
89
+ DataCopy(output_gm[offset + len], output_local[len], elem_per_block);
90
+ SetAtomicNone();
91
+ #else
83
92
  DataCopyExtParams dataCopyParams;
84
93
  dataCopyParams.blockCount = 1;
85
94
  dataCopyParams.blockLen = tail * sizeof(float);
86
95
  DataCopyPad(output_gm[offset + len], output_local[len],
87
96
  dataCopyParams);
97
+ #endif
88
98
  }
89
99
  output_queue.FreeTensor(output_local);
90
100
  }
@@ -150,6 +160,7 @@ class GET_ROW_F16 {
150
160
  GlobalTensor<float> output_gm;
151
161
  TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
152
162
  TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
163
+ int64_t op_block_idx;
153
164
  };
154
165
 
155
166
  template <typename T>
@@ -13,7 +13,7 @@ class GET_ROW_F32 {
13
13
  int64_t *indices_ne_ub, size_t *indices_nb_ub,
14
14
  int64_t *output_ne_ub, size_t *output_nb_ub) {
15
15
  int64_t op_block_num = GetBlockNum();
16
- int64_t op_block_idx = GetBlockIdx();
16
+ op_block_idx = GetBlockIdx();
17
17
 
18
18
  for (int i = 0; i < 4; i++) {
19
19
  input_ne[i] = input_ne_ub[i];
@@ -55,31 +55,40 @@ class GET_ROW_F32 {
55
55
 
56
56
  __aicore__ inline void copy_in(uint32_t offset, size_t len) {
57
57
  LocalTensor<float> input_local = input_queue.AllocTensor<float>();
58
- size_t tail = len % 32;
59
- len = len & ~31;
60
- DataCopy(input_local, input_gm[offset], len);
58
+ const size_t elem_per_block = 32 / sizeof(float);
59
+ size_t tail = len % elem_per_block;
60
+ len = len & ~(elem_per_block - 1);
61
61
  if(tail != 0) {
62
- DataCopyExtParams dataCopyParams;
63
- dataCopyParams.blockCount = 1;
64
- dataCopyParams.blockLen = tail * sizeof(float);
65
- DataCopyPadExtParams<float> padParams;
66
- DataCopyPad(input_local[len], input_gm[offset + len],
67
- dataCopyParams, padParams);
62
+ len += elem_per_block;
68
63
  }
64
+ DataCopy(input_local, input_gm[offset], len);
69
65
  input_queue.EnQue(input_local);
70
66
  }
71
67
 
72
68
  __aicore__ inline void copy_out(uint32_t offset, size_t len) {
73
69
  LocalTensor<float> output_local = output_queue.DeQue<float>();
74
- size_t tail = len % 32;
75
- len = len & ~31;
76
- DataCopy(output_gm[offset], output_local, len);
70
+ const size_t elem_per_block = 32 / sizeof(float);
71
+ size_t tail = len % elem_per_block;
72
+ len = len & ~(elem_per_block - 1);
73
+ if (len > 0) {
74
+ DataCopy(output_gm[offset], output_local, len);
75
+ }
76
+
77
77
  if(tail != 0) {
78
+ #ifdef ASCEND_310P
79
+ for (size_t i = tail; i < elem_per_block; i++) {
80
+ output_local[len + i].SetValue(0, 0);
81
+ }
82
+ SetAtomicAdd<float>();
83
+ DataCopy(output_gm[offset + len], output_local[len], elem_per_block);
84
+ SetAtomicNone();
85
+ #else
78
86
  DataCopyExtParams dataCopyParams;
79
87
  dataCopyParams.blockCount = 1;
80
88
  dataCopyParams.blockLen = tail * sizeof(float);
81
89
  DataCopyPad(output_gm[offset + len], output_local[len],
82
90
  dataCopyParams);
91
+ #endif
83
92
  }
84
93
  output_queue.FreeTensor(output_local);
85
94
  }
@@ -144,6 +153,7 @@ class GET_ROW_F32 {
144
153
  GlobalTensor<float> output_gm;
145
154
  TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
146
155
  TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
156
+ int64_t op_block_idx;
147
157
  };
148
158
 
149
159
  template <typename T>
@@ -2,6 +2,15 @@
2
2
 
3
3
  // optimize me. Use template to avoid copy code.
4
4
  using namespace AscendC;
5
+ #ifdef ASCEND_310P // 310P not support 4bit get row
6
+ extern "C" __global__ __aicore__ void ascendc_get_row_q4_0(
7
+ GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm,
8
+ GM_ADDR input_ne_gm, GM_ADDR indices_ne_gm, GM_ADDR indices_nb_gm,
9
+ GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) {
10
+ // let following test cases can continue run, here just print error information. Of Cource the test case that call this operator is failed.
11
+ printf("Ascend310P not support 4bit get row.\n");
12
+ }
13
+ #else
5
14
 
6
15
  #define BUFFER_NUM 2
7
16
 
@@ -191,3 +200,5 @@ extern "C" __global__ __aicore__ void ascendc_get_row_q4_0(
191
200
  indices_nb_ub, output_ne_ub, output_nb_ub);
192
201
  op.calculate();
193
202
  }
203
+
204
+ #endif // #ifdef ASCEND_310P
@@ -1,6 +1,14 @@
1
1
  #include "kernel_operator.h"
2
2
 
3
3
  using namespace AscendC;
4
+ #ifdef ASCEND_310P
5
+ extern "C" __global__ __aicore__ void ascendc_quantize_f16_q8_0(
6
+ GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
7
+ GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
8
+ // let following test cases can continue run, here just print error information. Of Cource the test case that call this operator is failed.
9
+ printf("Ascend310P not support f16->8bit quantization.\n");
10
+ }
11
+ #else
4
12
 
5
13
  #define BUFFER_NUM 2
6
14
  #define QK8_0 32
@@ -206,3 +214,5 @@ extern "C" __global__ __aicore__ void ascendc_quantize_f16_q8_0(
206
214
  op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub);
207
215
  op.calculate();
208
216
  }
217
+
218
+ #endif // #ifdef ASCEND_310P
@@ -1,6 +1,14 @@
1
1
  #include "kernel_operator.h"
2
2
 
3
3
  using namespace AscendC;
4
+ #ifdef ASCEND_310P // 310P not support f32->8bit quantization
5
+ extern "C" __global__ __aicore__ void ascendc_quantize_f32_q8_0(
6
+ GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
7
+ GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
8
+ // let following test cases can continue run, here just print error information. Of Cource the test case that call this operator is failed.
9
+ printf("Ascend310P not support f32->8bit quantization.\n");
10
+ }
11
+ #else
4
12
 
5
13
  #define BUFFER_NUM 2
6
14
  #define QK8_0 32
@@ -204,3 +212,5 @@ extern "C" __global__ __aicore__ void ascendc_quantize_f32_q8_0(
204
212
  op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub);
205
213
  op.calculate();
206
214
  }
215
+
216
+ #endif // #ifdef ASCEND_310P
@@ -1,6 +1,21 @@
1
1
  #include "kernel_operator.h"
2
2
 
3
3
  using namespace AscendC;
4
+ #ifdef ASCEND_310P // 310P not support float->4bit quantization
5
+ extern "C" __global__ __aicore__ void ascendc_quantize_f32_to_q4_0(
6
+ GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
7
+ GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
8
+ // let following test cases can continue run, here just print error information. Of Cource the test case that call this operator is failed.
9
+ printf("Ascend310P not support f32->4bit quantization.\n");
10
+ }
11
+
12
+ extern "C" __global__ __aicore__ void ascendc_quantize_f16_to_q4_0(
13
+ GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
14
+ GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
15
+ // let following test cases can continue run, here just print error information. Of Cource the test case that call this operator is failed.
16
+ printf("Ascend310P not support f16->4bit quantization.\n");
17
+ }
18
+ #else
4
19
 
5
20
  #define BUFFER_NUM 2
6
21
  #define Group_Size 32
@@ -276,3 +291,5 @@ extern "C" __global__ __aicore__ void ascendc_quantize_f32_to_q4_0(
276
291
  op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub);
277
292
  op.calculate();
278
293
  }
294
+
295
+ #endif // #ifdef ASCEND_310P