@novastera-oss/llamarn 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190) hide show
  1. package/android/build.gradle +2 -1
  2. package/android/proguard-rules.pro +12 -0
  3. package/android/src/main/cpp/include/llama.h +15 -47
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/build-info.cpp +2 -2
  21. package/cpp/llama.cpp/CMakePresets.json +11 -0
  22. package/cpp/llama.cpp/CODEOWNERS +1 -0
  23. package/cpp/llama.cpp/README.md +4 -3
  24. package/cpp/llama.cpp/common/arg.cpp +45 -1
  25. package/cpp/llama.cpp/common/common.cpp +22 -6
  26. package/cpp/llama.cpp/common/common.h +18 -4
  27. package/cpp/llama.cpp/convert_hf_to_gguf.py +500 -32
  28. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +12 -13
  29. package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -1
  30. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
  31. package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  32. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -0
  33. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
  34. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -0
  35. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +8 -20
  36. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +58 -3
  38. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +122 -16
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +5 -2
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +3 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +3 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +14 -4
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +64 -17
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -67
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +45 -62
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +28 -43
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +41 -56
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -47
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +31 -43
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +22 -37
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +73 -23
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -689
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +7 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +13 -1
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
  75. package/cpp/llama.cpp/ggml/src/ggml-impl.h +16 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +13 -3
  77. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +407 -69
  78. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +380 -83
  79. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +2 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +295 -2
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  85. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +131 -46
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +43 -43
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  94. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +287 -22
  95. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +1 -5
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  101. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  102. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
  105. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +8 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  107. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  108. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +71 -16
  109. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  112. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  114. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
  115. package/cpp/llama.cpp/ggml/src/ggml.c +4 -6
  116. package/cpp/llama.cpp/gguf-py/gguf/constants.py +98 -0
  117. package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
  118. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
  119. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +75 -52
  120. package/cpp/llama.cpp/include/llama.h +15 -7
  121. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
  122. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
  123. package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
  124. package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
  125. package/cpp/llama.cpp/src/llama-arch.cpp +106 -0
  126. package/cpp/llama.cpp/src/llama-arch.h +5 -0
  127. package/cpp/llama.cpp/src/llama-batch.cpp +76 -70
  128. package/cpp/llama.cpp/src/llama-batch.h +24 -18
  129. package/cpp/llama.cpp/src/llama-chat.cpp +43 -1
  130. package/cpp/llama.cpp/src/llama-chat.h +2 -0
  131. package/cpp/llama.cpp/src/llama-context.cpp +180 -106
  132. package/cpp/llama.cpp/src/llama-context.h +26 -16
  133. package/cpp/llama.cpp/src/llama-cparams.h +3 -2
  134. package/cpp/llama.cpp/src/llama-graph.cpp +203 -39
  135. package/cpp/llama.cpp/src/llama-graph.h +147 -72
  136. package/cpp/llama.cpp/src/llama-hparams.cpp +40 -0
  137. package/cpp/llama.cpp/src/llama-hparams.h +10 -2
  138. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +11 -5
  139. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +3 -0
  140. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +698 -302
  141. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +89 -31
  142. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +1 -0
  143. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +16 -1
  144. package/cpp/llama.cpp/src/llama-model.cpp +1293 -312
  145. package/cpp/llama.cpp/src/llama-model.h +3 -4
  146. package/cpp/llama.cpp/src/llama-quant.cpp +1 -2
  147. package/cpp/llama.cpp/src/llama-vocab.cpp +363 -8
  148. package/cpp/llama.cpp/src/llama-vocab.h +2 -0
  149. package/cpp/llama.cpp/src/unicode.cpp +207 -0
  150. package/cpp/llama.cpp/src/unicode.h +2 -0
  151. package/ios/include/common.h +18 -4
  152. package/ios/include/llama.h +15 -7
  153. package/ios/libs/llama.xcframework/Info.plist +15 -15
  154. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  155. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
  156. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -7
  157. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  158. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  159. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
  160. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
  161. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
  162. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  163. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  164. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
  165. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3891
  166. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -7
  167. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -7
  168. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  169. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -7
  170. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  171. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  172. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  173. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
  174. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -7
  175. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  176. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  177. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
  178. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
  179. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
  180. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  181. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  182. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -5095
  183. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -7
  184. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  185. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  186. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -5066
  187. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3919
  188. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
  189. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  190. package/package.json +4 -4
@@ -23,6 +23,7 @@
23
23
  #ifndef CANN_ACLNN_OPS
24
24
  #define CANN_ACLNN_OPS
25
25
 
26
+ #include <unordered_set>
26
27
  #include <functional>
27
28
  #include <aclnnop/aclnn_abs.h>
28
29
  #include <aclnnop/aclnn_neg.h>
@@ -1020,6 +1021,37 @@ inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffe
1020
1021
  */
1021
1022
  void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst);
1022
1023
 
1024
+ /**
1025
+ * @brief Check whether a tensor is a weight tensor for matrix multiplication.
1026
+ *
1027
+ * @details Checks whether the given tensor serves as weight parameters in matrix multiplication operations,
1028
+ * typically within neural network layers. The function maintains a static set of canonical weight
1029
+ * naming suffixes from Transformer-based architectures. Uses substring matching to identify weight
1030
+ * tensors even with hierarchical naming patterns.
1031
+ *
1032
+ * @param tensor Pointer to the target ggml_tensor object (const-qualified).
1033
+ */
1034
+ static bool is_matmul_weight(const ggml_tensor* tensor) {
1035
+ std::string name = ggml_get_name(tensor);
1036
+ static const std::unordered_set<std::string> weight_suffixes{
1037
+ "output.weight",
1038
+ "attn_q.weight",
1039
+ "attn_k.weight",
1040
+ "attn_v.weight",
1041
+ "attn_output.weight",
1042
+ "ffn_gate.weight",
1043
+ "ffn_up.weight",
1044
+ "ffn_down.weight"
1045
+ };
1046
+
1047
+ for (const auto& suffix : weight_suffixes) {
1048
+ if (name.find(suffix) != std::string::npos) {
1049
+ return true;
1050
+ }
1051
+ }
1052
+ return false;
1053
+ }
1054
+
1023
1055
  /**
1024
1056
  * @brief Applies an element-wise operation to two input tensors using the CANN
1025
1057
  * backend.
@@ -1066,7 +1098,7 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
1066
1098
  * @param dst The destination tensor. Its src[0] is treated as the input tensor.
1067
1099
  */
1068
1100
  template <void unary_op(ggml_backend_cann_context&, aclTensor*, aclTensor*)>
1069
- void ggml_cann_unary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
1101
+ void ggml_cann_op_unary(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
1070
1102
  ggml_tensor* src = dst->src[0];
1071
1103
 
1072
1104
  aclTensor* acl_src = ggml_cann_create_tensor(src);
@@ -1077,49 +1109,125 @@ template <void unary_op(ggml_backend_cann_context&, aclTensor*, aclTensor*)>
1077
1109
  }
1078
1110
 
1079
1111
  /**
1080
- * @brief Applies a unary operation to a ggml tensor using the CANN backend.
1112
+ * @brief Applies a unary operation to a ggml tensor using the CANN backend.
1081
1113
  *
1082
- * @details This function performs a unary operation on the input tensor using
1083
- * a user-provided lambda or callable object `unary_op`, which accepts the CANN
1084
- * context and two ACL tensors (source and destination). Internally, this function
1085
- * creates ACL representations of the ggml tensors and invokes the unary operation.
1086
- * The result is stored in the destination tensor `dst`. This utility abstracts the
1087
- * common boilerplate of tensor conversion and cleanup when implementing unary ops.
1114
+ * @details This function applies a unary operation to the input tensor using
1115
+ * a user-provided lambda or callable `unary_op`. The lambda receives the
1116
+ * CANN backend context and two ACL tensors: the source and the destination.
1088
1117
  *
1089
- * @param unary_op A callable that performs the unary operation using CANN APIs.
1090
- * @param ctx The CANN context used for operations.
1091
- * @param dst The destination tensor where the result will be stored.
1092
- * The source tensor is retrieved from `dst->src[0]`.
1118
+ * Internally, this function handles the conversion from GGML tensors to ACL tensors,
1119
+ * calls the provided unary op, and manages resource cleanup. The input is assumed
1120
+ * to be `dst->src[0]`, and the result is written to `dst`.
1121
+ *
1122
+ * This utility simplifies writing unary op wrappers by abstracting tensor preparation.
1123
+ *
1124
+ * @param unary_op A callable that performs the unary operation using CANN ACL APIs.
1125
+ * @param ctx The CANN context for operation execution.
1126
+ * @param dst The destination ggml_tensor where the result will be stored.
1127
+ * The input tensor is assumed to be `dst->src[0]`.
1128
+ *
1129
+ * @see GGML_CANN_CALL_OP_UNARY
1130
+ */
1131
+ void ggml_cann_op_unary(
1132
+ std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
1133
+ ggml_backend_cann_context& ctx, ggml_tensor* dst);
1134
+
1135
+ /**
1136
+ * @brief Applies a gated (GLU-style) unary operation using the CANN backend.
1137
+ *
1138
+ * @details This function performs a gated activation such as GEGLU or ReGLU.
1139
+ * It supports two input modes:
1140
+ *
1141
+ * 1. **Dual input mode**: `dst->src[0]` and `dst->src[1]` are both valid tensors.
1142
+ * These are used directly as the value and gate tensors.
1143
+ *
1144
+ * 2. **Packed input mode**: Only `dst->src[0]` is valid, and it is assumed to
1145
+ * contain a concatenation of value and gate along the first dimension. This tensor
1146
+ * will be split into two equal halves to form the value and gate inputs.
1147
+ *
1148
+ * The function applies a user-provided unary operation (e.g., GELU) to the value tensor,
1149
+ * then multiplies the result in-place with the gate tensor:
1150
+ *
1151
+ * @code
1152
+ * dst = unary_op(value) * gate;
1153
+ * @endcode
1154
+ *
1155
+ * The `swapped` parameter (from `dst->op_params[1]`) allows flipping the
1156
+ * order of value/gate in the packed input case.
1157
+ *
1158
+ * @param unary_op A callable that performs the unary operation using CANN ACL APIs.
1159
+ * It receives (ctx, acl_value_tensor, acl_output_tensor).
1160
+ * @param ctx The CANN context used for execution.
1161
+ * @param dst The destination ggml_tensor. Source tensors are in `dst->src[0]` and optionally `src[1]`.
1162
+ *
1163
+ * @see GGML_CANN_CALL_OP_UNARY_GATED
1093
1164
  */
1094
- void ggml_cann_unary_op(
1165
+ void ggml_cann_op_unary_gated(
1095
1166
  std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
1096
1167
  ggml_backend_cann_context& ctx, ggml_tensor* dst);
1097
1168
 
1098
1169
  /**
1099
- * @brief Helper macro to invoke a unary ACL operation using ggml_cann_unary_op.
1170
+ * @brief Helper macro to call a unary ACL operator via ggml_cann_op_unary.
1100
1171
  *
1101
- * This macro defines an inline lambda wrapping a specific ACL operation name,
1102
- * and passes it to the templated ggml_cann_unary_op function. It simplifies
1103
- * calling unary ops by hiding the lambda boilerplate.
1172
+ * This macro wraps the specified ACLNN unary operator name into a lambda expression,
1173
+ * and passes it to `ggml_cann_op_unary`, which handles the common logic for executing
1174
+ * unary ops in the CANN backend.
1104
1175
  *
1105
- * Internally, the lambda will call:
1176
+ * Internally, this macro expands to a lambda like:
1106
1177
  * @code
1107
- * GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
1178
+ * [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) {
1179
+ * GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
1180
+ * };
1108
1181
  * @endcode
1109
1182
  *
1183
+ * This lambda is then passed to `ggml_cann_op_unary`, which applies the operation.
1184
+ *
1110
1185
  * @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP.
1111
1186
  *
1112
- * @see ggml_cann_unary_op
1187
+ * @see ggml_cann_op_unary
1113
1188
  * @see GGML_CANN_CALL_ACLNN_OP
1114
1189
  */
1115
- #define GGML_CANN_CALL_UNARY_OP(OP_NAME) \
1190
+ #define GGML_CANN_CALL_OP_UNARY(OP_NAME) \
1116
1191
  do { \
1117
1192
  auto lambda = [](ggml_backend_cann_context& ctx, \
1118
1193
  aclTensor* acl_src, \
1119
1194
  aclTensor* acl_dst) { \
1120
1195
  GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \
1121
1196
  }; \
1122
- ggml_cann_unary_op(lambda, ctx, dst); \
1197
+ ggml_cann_op_unary(lambda, ctx, dst); \
1123
1198
  } \
1124
1199
  while (0)
1200
+
1201
+ /**
1202
+ * @brief Helper macro to call a gated unary ACL operator via ggml_cann_op_unary_gated.
1203
+ *
1204
+ * This macro wraps the specified ACLNN unary operator name into a lambda expression,
1205
+ * and passes it to `ggml_cann_op_unary_gated`, which handles the common logic for
1206
+ * executing gated unary ops in the CANN backend.
1207
+ *
1208
+ * Internally, this macro expands to a lambda like:
1209
+ * @code
1210
+ * [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) {
1211
+ * GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
1212
+ * };
1213
+ * @endcode
1214
+ *
1215
+ * This lambda is then passed to `ggml_cann_op_unary_gated`, which applies the operation.
1216
+ *
1217
+ * @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP.
1218
+ *
1219
+ * @see ggml_cann_op_unary_gated
1220
+ * @see GGML_CANN_CALL_ACLNN_OP
1221
+ */
1222
+ #define GGML_CANN_CALL_OP_UNARY_GATED(OP_NAME) \
1223
+ do { \
1224
+ auto lambda = [](ggml_backend_cann_context& ctx, \
1225
+ aclTensor* acl_src, \
1226
+ aclTensor* acl_dst) { \
1227
+ GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \
1228
+ }; \
1229
+ ggml_cann_op_unary_gated(lambda, ctx, dst); \
1230
+ } \
1231
+ while (0)
1232
+
1125
1233
  #endif // CANN_ACLNN_OPS
@@ -24,6 +24,7 @@
24
24
 
25
25
  #include <acl/acl.h>
26
26
  #include <stdarg.h>
27
+ #include <aclnnop/aclnn_trans_matmul_weight.h>
27
28
 
28
29
  #include <cmath>
29
30
  #include <cstdio>
@@ -1115,6 +1116,63 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(
1115
1116
  return GGML_STATUS_SUCCESS;
1116
1117
  }
1117
1118
 
1119
+ static int CreateAclTensorWeight(const void *hostData, const std::vector<int64_t> &shape, void **deviceAddr,
1120
+ aclDataType dataType, aclTensor **tensor)
1121
+ {
1122
+ uint64_t size = 1;
1123
+ for (auto i : shape) {
1124
+ size *= i;
1125
+ }
1126
+
1127
+ const aclIntArray *mat2Size = aclCreateIntArray(shape.data(), shape.size());
1128
+ ACL_CHECK(aclnnCalculateMatmulWeightSizeV2(mat2Size, dataType, &size));
1129
+
1130
+ size *= sizeof(int16_t);
1131
+
1132
+ ACL_CHECK(aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST));
1133
+ aclrtMemcpy(*deviceAddr, size, hostData, size, ACL_MEMCPY_HOST_TO_DEVICE);
1134
+
1135
+ std::vector<int64_t> strides(shape.size(), 1);
1136
+ for (int64_t i = shape.size() - 2; i >= 0; i--) {
1137
+ strides[i] = shape[i + 1] * strides[i + 1];
1138
+ }
1139
+
1140
+ *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
1141
+ shape.data(), shape.size(), *deviceAddr);
1142
+ return 0;
1143
+ }
1144
+
1145
+ static void weight_format_to_nz(ggml_tensor *tensor, const void *data, size_t offset) {
1146
+ aclrtStream stream;
1147
+ ACL_CHECK(aclrtCreateStream(&stream));
1148
+
1149
+ std::vector<int64_t> weightTransposedShape = {tensor->ne[1], tensor->ne[0]};
1150
+ void *weightTransposedDeviceAddr = nullptr;
1151
+ aclTensor *weightTransposed = nullptr;
1152
+ CreateAclTensorWeight(data, weightTransposedShape, &weightTransposedDeviceAddr,
1153
+ ggml_cann_type_mapping(tensor->type), &weightTransposed);
1154
+
1155
+ uint64_t workspaceSize = 0;
1156
+ aclOpExecutor *executor;
1157
+ void *workspaceAddr = nullptr;
1158
+
1159
+ // TransMatmulWeight
1160
+ ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed, &workspaceSize, &executor));
1161
+ std::unique_ptr<void, aclError (*)(void *)> workspaceAddrPtrTrans(nullptr, aclrtFree);
1162
+ if (workspaceSize > 0) {
1163
+ ACL_CHECK(aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST));
1164
+ workspaceAddrPtrTrans.reset(workspaceAddr);
1165
+ }
1166
+ ACL_CHECK(aclnnTransMatmulWeight(workspaceAddr, workspaceSize, executor, stream));
1167
+
1168
+ size_t size = ggml_nelements(tensor) * ggml_element_size(tensor);
1169
+
1170
+ aclrtMemcpy((char *)tensor->data + offset, size,
1171
+ weightTransposedDeviceAddr, size, ACL_MEMCPY_HOST_TO_DEVICE);
1172
+ ACL_CHECK(aclDestroyTensor(weightTransposed));
1173
+ aclrtFree(weightTransposedDeviceAddr);
1174
+ }
1175
+
1118
1176
  // TODO: need to handle tensors that have padding.
1119
1177
  /**
1120
1178
  * @brief Set tensor data in a CANN buffer.
@@ -1139,9 +1197,16 @@ static void ggml_backend_cann_buffer_set_tensor(
1139
1197
  // For acl, synchronous functions use this default stream.
1140
1198
  // Why aclrtSynchronizeDevice?
1141
1199
 
1200
+ bool weightToNZ = false;
1201
+ #ifdef ASCEND_310P
1202
+ weightToNZ = (getenv("GGML_CANN_WEIGHT_NZ") != nullptr);
1203
+ #endif
1142
1204
  if (!need_transform(tensor->type)) {
1143
1205
  ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size,
1144
1206
  ACL_MEMCPY_HOST_TO_DEVICE));
1207
+ if (weightToNZ && is_matmul_weight((const ggml_tensor*)tensor)) {
1208
+ weight_format_to_nz(tensor, data, offset);
1209
+ }
1145
1210
  } else {
1146
1211
  void *transform_buffer = malloc(size);
1147
1212
  ggml_backend_cann_transform(tensor, data, transform_buffer);
@@ -1616,16 +1681,18 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1616
1681
  case GGML_OP_UNARY:
1617
1682
  switch (ggml_get_unary_op(dst)) {
1618
1683
  case GGML_UNARY_OP_ABS:
1619
- GGML_CANN_CALL_UNARY_OP(Abs);
1684
+ GGML_CANN_CALL_OP_UNARY(Abs);
1620
1685
  break;
1621
1686
  case GGML_UNARY_OP_NEG:
1622
- GGML_CANN_CALL_UNARY_OP(Neg);
1687
+ GGML_CANN_CALL_OP_UNARY(Neg);
1623
1688
  break;
1624
1689
  case GGML_UNARY_OP_GELU:
1625
- GGML_CANN_CALL_UNARY_OP(Gelu);
1690
+ case GGML_UNARY_OP_GELU_ERF:
1691
+ // aclnnGelu internally uses the erf-based approximation.
1692
+ GGML_CANN_CALL_OP_UNARY(Gelu);
1626
1693
  break;
1627
1694
  case GGML_UNARY_OP_SILU:
1628
- GGML_CANN_CALL_UNARY_OP(Silu);
1695
+ GGML_CANN_CALL_OP_UNARY(Silu);
1629
1696
  break;
1630
1697
  case GGML_UNARY_OP_GELU_QUICK: {
1631
1698
  auto lambda = [](ggml_backend_cann_context& ctx,
@@ -1633,31 +1700,31 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1633
1700
  aclTensor* acl_dst) {
1634
1701
  GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
1635
1702
  };
1636
- ggml_cann_unary_op(lambda, ctx, dst);
1703
+ ggml_cann_op_unary(lambda, ctx, dst);
1637
1704
  } break;
1638
1705
  case GGML_UNARY_OP_TANH:
1639
- GGML_CANN_CALL_UNARY_OP(Tanh);
1706
+ GGML_CANN_CALL_OP_UNARY(Tanh);
1640
1707
  break;
1641
1708
  case GGML_UNARY_OP_RELU:
1642
- GGML_CANN_CALL_UNARY_OP(Relu);
1709
+ GGML_CANN_CALL_OP_UNARY(Relu);
1643
1710
  break;
1644
1711
  case GGML_UNARY_OP_SIGMOID:
1645
- GGML_CANN_CALL_UNARY_OP(Sigmoid);
1712
+ GGML_CANN_CALL_OP_UNARY(Sigmoid);
1646
1713
  break;
1647
1714
  case GGML_UNARY_OP_HARDSIGMOID:
1648
- GGML_CANN_CALL_UNARY_OP(Hardsigmoid);
1715
+ GGML_CANN_CALL_OP_UNARY(Hardsigmoid);
1649
1716
  break;
1650
1717
  case GGML_UNARY_OP_HARDSWISH:
1651
- GGML_CANN_CALL_UNARY_OP(Hardswish);
1718
+ GGML_CANN_CALL_OP_UNARY(Hardswish);
1652
1719
  break;
1653
1720
  case GGML_UNARY_OP_EXP:
1654
- GGML_CANN_CALL_UNARY_OP(Exp);
1721
+ GGML_CANN_CALL_OP_UNARY(Exp);
1655
1722
  break;
1656
1723
  case GGML_UNARY_OP_ELU:
1657
1724
  ggml_cann_elu(ctx, dst);
1658
1725
  break;
1659
1726
  case GGML_UNARY_OP_SGN:
1660
- GGML_CANN_CALL_UNARY_OP(Sign);
1727
+ GGML_CANN_CALL_OP_UNARY(Sign);
1661
1728
  break;
1662
1729
  case GGML_UNARY_OP_STEP:
1663
1730
  ggml_cann_step(ctx, dst);
@@ -1666,6 +1733,31 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1666
1733
  return false;
1667
1734
  }
1668
1735
  break;
1736
+ case GGML_OP_GLU:
1737
+ switch (ggml_get_glu_op(dst)) {
1738
+ case GGML_GLU_OP_REGLU:
1739
+ GGML_CANN_CALL_OP_UNARY_GATED(Relu);
1740
+ break;
1741
+ case GGML_GLU_OP_GEGLU:
1742
+ case GGML_GLU_OP_GEGLU_ERF:
1743
+ // aclnnGelu internally uses the erf-based approximation.
1744
+ GGML_CANN_CALL_OP_UNARY_GATED(Gelu);
1745
+ break;
1746
+ case GGML_GLU_OP_SWIGLU:
1747
+ GGML_CANN_CALL_OP_UNARY_GATED(Silu);
1748
+ break;
1749
+ case GGML_GLU_OP_GEGLU_QUICK: {
1750
+ auto lambda = [](ggml_backend_cann_context& ctx,
1751
+ aclTensor* acl_src,
1752
+ aclTensor* acl_dst) {
1753
+ GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
1754
+ };
1755
+ ggml_cann_op_unary_gated(lambda, ctx, dst);
1756
+ } break;
1757
+ default:
1758
+ return false;
1759
+ }
1760
+ break;
1669
1761
  case GGML_OP_NORM:
1670
1762
  ggml_cann_norm(ctx, dst);
1671
1763
  break;
@@ -1708,7 +1800,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1708
1800
  ggml_cann_binary_op<aclnn_mul>(ctx, dst);
1709
1801
  break;
1710
1802
  case GGML_OP_SQRT:
1711
- GGML_CANN_CALL_UNARY_OP(Sqrt);
1803
+ GGML_CANN_CALL_OP_UNARY(Sqrt);
1712
1804
  break;
1713
1805
  case GGML_OP_CLAMP:
1714
1806
  ggml_cann_clamp(ctx, dst);
@@ -1753,16 +1845,16 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1753
1845
  ggml_cann_argmax(ctx, dst);
1754
1846
  break;
1755
1847
  case GGML_OP_COS:
1756
- ggml_cann_unary_op<aclnn_cos>(ctx, dst);
1848
+ ggml_cann_op_unary<aclnn_cos>(ctx, dst);
1757
1849
  break;
1758
1850
  case GGML_OP_SIN:
1759
- ggml_cann_unary_op<aclnn_sin>(ctx, dst);
1851
+ ggml_cann_op_unary<aclnn_sin>(ctx, dst);
1760
1852
  break;
1761
1853
  case GGML_OP_CONV_TRANSPOSE_1D:
1762
1854
  ggml_cann_conv_transpose_1d(ctx, dst);
1763
1855
  break;
1764
1856
  case GGML_OP_LOG:
1765
- GGML_CANN_CALL_UNARY_OP(Log);
1857
+ GGML_CANN_CALL_OP_UNARY(Log);
1766
1858
  break;
1767
1859
  case GGML_OP_MEAN:
1768
1860
  ggml_cann_mean(ctx, dst);
@@ -2036,10 +2128,23 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2036
2128
  case GGML_UNARY_OP_ELU:
2037
2129
  case GGML_UNARY_OP_SGN:
2038
2130
  case GGML_UNARY_OP_STEP:
2131
+ case GGML_UNARY_OP_GELU_ERF:
2039
2132
  return true;
2040
2133
  default:
2041
2134
  return false;
2042
2135
  }
2136
+ case GGML_OP_GLU:
2137
+ switch (ggml_get_glu_op(op)) {
2138
+ case GGML_GLU_OP_REGLU:
2139
+ case GGML_GLU_OP_GEGLU:
2140
+ case GGML_GLU_OP_SWIGLU:
2141
+ case GGML_GLU_OP_GEGLU_ERF:
2142
+ case GGML_GLU_OP_GEGLU_QUICK:
2143
+ return true;
2144
+ default:
2145
+ return false;
2146
+ }
2147
+ break;
2043
2148
  case GGML_OP_MUL_MAT: {
2044
2149
  switch (op->src[0]->type) {
2045
2150
  case GGML_TYPE_F16:
@@ -2090,6 +2195,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2090
2195
  {
2091
2196
  // TODO: add support
2092
2197
  // ref: https://github.com/ggml-org/llama.cpp/pull/14274
2198
+ #pragma message("TODO: implement F32, F16, BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
2093
2199
  return false;
2094
2200
  } break;
2095
2201
  case GGML_OP_CPY: {
@@ -70,10 +70,12 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
70
70
  if (GGML_OPENMP)
71
71
  find_package(OpenMP)
72
72
  if (OpenMP_FOUND)
73
+ set(GGML_OPENMP_ENABLED "ON" CACHE INTERNAL "")
73
74
  target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_OPENMP)
74
75
 
75
76
  target_link_libraries(${GGML_CPU_NAME} PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
76
77
  else()
78
+ set(GGML_OPENMP_ENABLED "OFF" CACHE INTERNAL "")
77
79
  message(WARNING "OpenMP not found")
78
80
  endif()
79
81
  endif()
@@ -456,6 +458,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
456
458
  list(APPEND ARCH_FLAGS -march=z16)
457
459
  elseif (${S390X_M} MATCHES "9175|9176")
458
460
  # NOTE: Only available from GCC 15.1.0 onwards. Any z17 machine with compile issues must first verify their GCC version.
461
+ # binutils must also be updated to the latest for the -march=z17 flag to work. Otherwise, use -march=arch15.
459
462
  message(STATUS "z17 target")
460
463
  list(APPEND ARCH_FLAGS -march=z17)
461
464
  else()
@@ -494,9 +497,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
494
497
 
495
498
  # Fetch KleidiAI sources:
496
499
  include(FetchContent)
497
- set(KLEIDIAI_COMMIT_TAG "v1.9.0")
500
+ set(KLEIDIAI_COMMIT_TAG "v1.11.0")
498
501
  set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
499
- set(KLEIDIAI_ARCHIVE_MD5 "2a8e1bb55d201557553545536489a017")
502
+ set(KLEIDIAI_ARCHIVE_MD5 "3fe9e5ab964c375c53839296eb71eaa2")
500
503
 
501
504
  if (POLICY CMP0135)
502
505
  cmake_policy(SET CMP0135 NEW)
@@ -544,7 +544,7 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
544
544
  __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs, 0) );
545
545
  max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
546
546
  __m128 tmp = max4;
547
- max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 ));
547
+ max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x1 ));
548
548
  const float max_scalar = ((v4f32)max4)[0];
549
549
 
550
550
  // Quantize these floats
@@ -22,9 +22,94 @@
22
22
 
23
23
  #include "kai_common.h"
24
24
 
25
+ #include "simd-mappings.h"
26
+
25
27
  #include "kernels.h"
26
28
 
27
29
  #define NELEMS(x) sizeof(x) / sizeof(*x)
30
+
31
+ static const size_t INT4_PER_BYTE = 2; // number of 4-bit values held in one byte; divides kr to get bytes per segment
32
+ static const size_t INT4_BITS = 4; // width of one 4-bit value; used as the shift that exposes a byte's high nibble
33
+ static const int Q4_0_ZERO_POINT = 8; // subtracted from every extracted nibble before it is multiplied by the block scale
34
+ const size_t INT4_PER_UINT16 = 4; // number of 4-bit values held in one uint16_t. NOTE(review): not declared static, unlike the three constants above - confirm the intended linkage
35
+
36
// Dequantizes one row of a packed 4-bit buffer into floats.
//
// The buffer is addressed as groups of nr_pack rows, each group starting
// packed_row_stride bytes after the previous one. Inside a group the data is
// walked block by block: every block begins with one 16-bit scale per row of
// the group, followed by bl / kr segments that each hold
// kr / INT4_PER_BYTE bytes for every row of the group.
//
// Only nc / bl whole blocks are processed (integer division), so `out` must
// have room for at least (nc / bl) * bl floats.
// NOTE(review): presumably this mirrors the layout produced by the matching
// KleidiAI rhs packing routine - confirm against that routine.
+ static void dequantize_row_qsi4c32pscalef16(
37
+ const void *packed_data, // start of the packed buffer (group 0)
38
+ int32_t row_idx, // row to extract; converted to size_t for the division below, so expected to be non-negative
39
+ int64_t nc, // number of values in the row; only whole blocks of bl are produced
40
+ float *out, // destination for the dequantized floats
41
+ size_t nr_pack, // number of rows interleaved in one packed group
42
+ size_t packed_row_stride, // byte distance between consecutive groups
43
+ size_t kr, // number of 4-bit values per row in one segment
44
+ size_t bl, // number of values that share one scale (block length)
45
+ size_t num_bytes_multiplier // byte size of one row's scale entry inside a block
46
+ ) {
47
+ size_t group_idx = row_idx / nr_pack; // which packed group holds this row
48
+ size_t row_in_group = row_idx % nr_pack; // position of the row inside that group
49
+ const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
50
+ size_t num_blocks = nc / bl;
51
+ const uint8_t *block_ptr = packed_group; // cursor that advances one block per iteration
52
+
53
+ for (size_t b = 0; b < num_blocks; ++b) {
54
+ uint16_t scale_f16 = *((const uint16_t *)(block_ptr + row_in_group * num_bytes_multiplier)); // this row's scale sits at the front of the block
55
+ float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
56
+
57
+ const uint8_t *segment_ptr = block_ptr + nr_pack * num_bytes_multiplier; // quantized bytes start after the scales of all rows in the group
58
+ size_t num_segments = bl / kr;
59
+ size_t num_bytes_per_segment = kr / INT4_PER_BYTE;
60
+
61
+ for (size_t s = 0; s < num_segments; ++s) {
62
+ const uint8_t *seg_base = segment_ptr + s * nr_pack * num_bytes_per_segment; // each segment stores the bytes of every row of the group back to back
63
+ const uint8_t *qbytes = seg_base + row_in_group * num_bytes_per_segment;
64
+ for (size_t k = 0; k < num_bytes_per_segment; ++k) {
65
+ uint8_t byte = qbytes[k] ^ 0x88; // flips the top bit of both nibbles
66
+ int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT; // low nibble; flip + subtract 8 yields a value in [-8, 7]
67
+ int x1 = (byte >> INT4_BITS) - Q4_0_ZERO_POINT; // high nibble, same mapping
68
+ out[b * bl + s * num_bytes_per_segment + k] = x0 * scale; // low nibbles fill the front of the block's output
69
+ out[b * bl + s * num_bytes_per_segment + k + bl/2] = x1 * scale; // high nibbles land bl/2 elements further on
70
+ }
71
+ }
72
+ block_ptr += nr_pack * num_bytes_multiplier + num_segments * nr_pack * num_bytes_per_segment; // skip this block's scales and all of its segments
73
+ }
74
+ }
75
+
76
// Dequantizes one row of a packed 4-bit buffer into floats, for a layout in
// which the quantized values and the scales of a group are stored in two
// separate regions.
//
// The buffer is addressed as groups of nr rows, each group starting
// packed_row_stride bytes after the previous one. A group is read as an array
// of uint16_t words (four 4-bit values each) from its start, and its scales
// are read from the last nr * num_blocks * num_bytes_multiplier bytes of the
// group's packed_row_stride bytes.
//
// Only k / bl whole blocks are processed (integer division), and within a
// block only bl / INT4_PER_UINT16 whole words are read.
// NOTE(review): presumably this mirrors the layout produced by the matching
// KleidiAI rhs packing routine - confirm against that routine.
+ static void dequantize_row_qsi4c32ps1s0scalef16(
77
+ const void *packed_data, // start of the packed buffer (group 0)
78
+ int32_t row_idx, // row to extract; converted to size_t for the division below, so expected to be non-negative
79
+ int64_t k, // number of values in the row; only whole blocks of bl are produced
80
+ float *out, // destination for the dequantized floats
81
+ size_t nr, // number of rows interleaved in one packed group
82
+ size_t packed_row_stride, // byte distance between consecutive groups
83
+ size_t kr, // not used by this routine (see GGML_UNUSED at the end)
84
+ size_t bl, // number of values that share one scale (block length)
85
+ size_t num_bytes_multiplier // byte size of one scale entry; used to locate the scale region
86
+ ) {
87
+ const size_t num_blocks = k / bl;
88
+ const size_t bl4 = bl / INT4_PER_UINT16; // number of 16-bit words per block for one row
89
+
90
+ size_t group_idx = row_idx / nr; // which packed group holds this row
91
+ size_t row_in_group = row_idx % nr; // position of the row inside that group
92
+
93
+ const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
94
+ const uint16_t *qdata = (const uint16_t *)packed_group; // quantized words start at the beginning of the group
95
+ const uint16_t *scales = (const uint16_t *)(packed_group + packed_row_stride - (nr * num_blocks * num_bytes_multiplier)); // scales occupy the tail of the group
96
+
97
+ for (size_t block_idx = 0; block_idx < num_blocks; ++block_idx) {
98
+ uint16_t scale_f16 = scales[row_in_group + block_idx * nr]; // scales are stored block-major, nr entries per block
99
+ float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
100
+
101
+ for (size_t bl4_idx = 0; bl4_idx < bl4; ++bl4_idx) {
102
+ uint16_t q = qdata[(block_idx * bl4 + bl4_idx) * nr + row_in_group]; // words of the nr rows are interleaved: one word per row, then the next word index
103
+
104
+ for (size_t qidx = 0; qidx < INT4_PER_UINT16; ++qidx) {
105
+ int v = ((q >> (qidx * 4)) & 0xF) - Q4_0_ZERO_POINT; // take nibbles from least to most significant and re-centre around zero
106
+ out[block_idx * bl + bl4_idx * INT4_BITS + qidx] = v * scale; // NOTE(review): INT4_BITS serves as the per-word element stride here; it has the same value (4) as INT4_PER_UINT16
107
+ }
108
+ }
109
+ }
110
+ GGML_UNUSED(kr);
111
+ }
112
+
28
113
  static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
29
114
  #if defined(__ARM_FEATURE_SME)
30
115
  {
@@ -63,8 +148,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
63
148
  /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
64
149
  },
65
150
  /* .rhs_info = */ {
66
- /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
67
- /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
151
+ /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
152
+ /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
153
+ /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
154
+ /* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16,
68
155
  },
69
156
  /* .required_cpu = */ CPU_FEATURE_SME,
70
157
  /* .lhs_type = */ GGML_TYPE_F32,
@@ -107,8 +194,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
107
194
  /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
108
195
  },
109
196
  /* .rhs_info = */ {
110
- /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
111
- /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
197
+ /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
198
+ /* .packed_stride = */ NULL,
199
+ /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
200
+ /* .to_float = */ NULL,
112
201
  },
113
202
  /* .required_cpu = */ CPU_FEATURE_SME,
114
203
  /* .lhs_type = */ GGML_TYPE_F32,
@@ -154,8 +243,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
154
243
  /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
155
244
  },
156
245
  /* .rhs_info = */ {
157
- /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
158
- /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
246
+ /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
247
+ /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
248
+ /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
249
+ /* .to_float = */ dequantize_row_qsi4c32pscalef16,
159
250
  },
160
251
  /* .required_cpu = */ CPU_FEATURE_DOTPROD,
161
252
  /* .lhs_type = */ GGML_TYPE_F32,
@@ -200,8 +291,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
200
291
  /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
201
292
  },
202
293
  /* .rhs_info = */ {
203
- /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
204
- /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
294
+ /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
295
+ /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
296
+ /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
297
+ /* .to_float = */ dequantize_row_qsi4c32pscalef16,
205
298
  },
206
299
  /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
207
300
  /* .lhs_type = */ GGML_TYPE_F32,
@@ -247,8 +340,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
247
340
  /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
248
341
  },
249
342
  /* .rhs_info = */ {
250
- /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
251
- /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
343
+ /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
344
+ /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
345
+ /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
346
+ /* .to_float = */ dequantize_row_qsi4c32pscalef16,
252
347
  },
253
348
  /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
254
349
  /* .lhs_type = */ GGML_TYPE_F32,
@@ -293,8 +388,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
293
388
  /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
294
389
  },
295
390
  /* .rhs_info = */ {
296
- /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
297
- /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
391
+ /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
392
+ /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
393
+ /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
394
+ /* .to_float = */ dequantize_row_qsi4c32pscalef16,
298
395
  },
299
396
  /* .required_cpu = */ CPU_FEATURE_DOTPROD,
300
397
  /* .lhs_type = */ GGML_TYPE_F32,