@novastera-oss/llamarn 0.2.1 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. package/README.md +80 -14
  2. package/RNLlamaCpp.podspec +10 -3
  3. package/android/CMakeLists.txt +8 -0
  4. package/android/src/main/cpp/include/llama.h +62 -125
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  13. package/cpp/PureCppImpl.cpp +9 -27
  14. package/cpp/SystemUtils.h +2 -2
  15. package/cpp/build-info.cpp +2 -2
  16. package/cpp/llama.cpp/README.md +11 -3
  17. package/cpp/llama.cpp/build-xcframework.sh +1 -0
  18. package/cpp/llama.cpp/common/CMakeLists.txt +8 -2
  19. package/cpp/llama.cpp/common/arg.cpp +153 -113
  20. package/cpp/llama.cpp/common/chat-parser.cpp +379 -0
  21. package/cpp/llama.cpp/common/chat-parser.h +117 -0
  22. package/cpp/llama.cpp/common/chat.cpp +847 -699
  23. package/cpp/llama.cpp/common/chat.h +73 -6
  24. package/cpp/llama.cpp/common/common.cpp +50 -82
  25. package/cpp/llama.cpp/common/common.h +21 -17
  26. package/cpp/llama.cpp/common/json-partial.cpp +255 -0
  27. package/cpp/llama.cpp/common/json-partial.h +37 -0
  28. package/cpp/llama.cpp/common/minja/chat-template.hpp +9 -5
  29. package/cpp/llama.cpp/common/minja/minja.hpp +69 -36
  30. package/cpp/llama.cpp/common/regex-partial.cpp +204 -0
  31. package/cpp/llama.cpp/common/regex-partial.h +56 -0
  32. package/cpp/llama.cpp/common/sampling.cpp +7 -8
  33. package/cpp/llama.cpp/convert_hf_to_gguf.py +453 -118
  34. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +120 -68
  35. package/cpp/llama.cpp/ggml/CMakeLists.txt +2 -1
  36. package/cpp/llama.cpp/ggml/cmake/common.cmake +25 -0
  37. package/cpp/llama.cpp/ggml/include/ggml-opt.h +49 -28
  38. package/cpp/llama.cpp/ggml/include/ggml.h +26 -7
  39. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +16 -10
  40. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +4 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  42. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +2 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +604 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +42 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +54 -2
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +50 -51
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -2
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -9
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +779 -19
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +22 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +322 -100
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +117 -1
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +220 -49
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/acc.cu +40 -26
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +1 -1
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +11 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +15 -7
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +266 -64
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +49 -4
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +48 -4
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +2 -1
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +5 -1
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +2 -0
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/quantize.cu +7 -6
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +1 -1
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +10 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-impl.h +1 -1
  73. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +4 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +99 -17
  75. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +200 -2
  76. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +8 -2
  77. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +6 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +972 -178
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +373 -190
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -10
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +101 -5
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +31 -33
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +1 -0
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +29 -2
  94. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +4 -5
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  96. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +9 -1
  97. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +84 -72
  98. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  99. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  100. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -3
  101. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +324 -129
  102. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +1 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +31 -2
  104. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +95 -68
  105. package/cpp/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +1 -0
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +22 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -2
  108. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -4
  109. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +2 -3
  110. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +69 -43
  111. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +2 -14
  112. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -91
  113. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -181
  114. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +17 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  116. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +6 -152
  117. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  118. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  119. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +2 -118
  120. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +1 -1
  121. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +12 -1
  122. package/cpp/llama.cpp/ggml/src/ggml.c +107 -36
  123. package/cpp/llama.cpp/ggml/src/gguf.cpp +33 -33
  124. package/cpp/llama.cpp/gguf-py/gguf/constants.py +100 -15
  125. package/cpp/llama.cpp/gguf-py/gguf/gguf_reader.py +1 -1
  126. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +44 -12
  127. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_editor_gui.py +21 -10
  128. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +5 -2
  129. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +128 -31
  130. package/cpp/llama.cpp/gguf-py/gguf/utility.py +1 -1
  131. package/cpp/llama.cpp/gguf-py/pyproject.toml +1 -1
  132. package/cpp/llama.cpp/include/llama.h +62 -125
  133. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +1 -1
  134. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +1 -1
  135. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.inp +1 -1
  136. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.out +1 -1
  137. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +1 -1
  138. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +1 -1
  139. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +1 -1
  140. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +1 -1
  141. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.inp +1 -1
  142. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.out +1 -1
  143. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +1 -1
  144. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +1 -1
  145. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +1 -1
  146. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +1 -1
  147. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +1 -1
  148. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +1 -1
  149. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.inp +1 -1
  150. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.out +1 -1
  151. package/cpp/llama.cpp/models/ggml-vocab-nomic-bert-moe.gguf +0 -0
  152. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +1 -1
  153. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.out +1 -1
  154. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +1 -1
  155. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.out +1 -1
  156. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.inp +1 -1
  157. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.out +1 -1
  158. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +1 -1
  159. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.out +1 -1
  160. package/cpp/llama.cpp/models/templates/Qwen-QwQ-32B.jinja +62 -0
  161. package/cpp/llama.cpp/models/templates/Qwen-Qwen3-0.6B.jinja +85 -0
  162. package/cpp/llama.cpp/models/templates/README.md +2 -0
  163. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +5 -1
  164. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +5 -1
  165. package/cpp/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt +2 -0
  166. package/cpp/llama.cpp/requirements/requirements-gguf_editor_gui.txt +1 -1
  167. package/cpp/llama.cpp/src/CMakeLists.txt +2 -0
  168. package/cpp/llama.cpp/src/llama-arch.cpp +6 -0
  169. package/cpp/llama.cpp/src/llama-arch.h +2 -0
  170. package/cpp/llama.cpp/src/llama-batch.cpp +3 -1
  171. package/cpp/llama.cpp/src/llama-context.cpp +340 -123
  172. package/cpp/llama.cpp/src/llama-context.h +30 -0
  173. package/cpp/llama.cpp/src/llama-cparams.cpp +4 -0
  174. package/cpp/llama.cpp/src/llama-cparams.h +2 -0
  175. package/cpp/llama.cpp/src/llama-grammar.cpp +12 -2
  176. package/cpp/llama.cpp/src/llama-graph.cpp +157 -247
  177. package/cpp/llama.cpp/src/llama-graph.h +52 -7
  178. package/cpp/llama.cpp/src/llama-hparams.cpp +17 -1
  179. package/cpp/llama.cpp/src/llama-hparams.h +37 -5
  180. package/cpp/llama.cpp/src/llama-kv-cache.cpp +742 -481
  181. package/cpp/llama.cpp/src/llama-kv-cache.h +196 -99
  182. package/cpp/llama.cpp/src/llama-kv-cells.h +379 -0
  183. package/cpp/llama.cpp/src/llama-memory.h +4 -3
  184. package/cpp/llama.cpp/src/llama-model-loader.cpp +22 -17
  185. package/cpp/llama.cpp/src/llama-model-saver.cpp +281 -0
  186. package/cpp/llama.cpp/src/llama-model-saver.h +37 -0
  187. package/cpp/llama.cpp/src/llama-model.cpp +529 -172
  188. package/cpp/llama.cpp/src/llama-model.h +6 -1
  189. package/cpp/llama.cpp/src/llama-quant.cpp +15 -13
  190. package/cpp/llama.cpp/src/llama-sampling.cpp +2 -2
  191. package/cpp/llama.cpp/src/llama-vocab.cpp +35 -8
  192. package/cpp/llama.cpp/src/llama-vocab.h +6 -0
  193. package/cpp/llama.cpp/src/llama.cpp +14 -0
  194. package/cpp/rn-completion.cpp +60 -5
  195. package/ios/include/chat.h +73 -6
  196. package/ios/include/common/minja/chat-template.hpp +9 -5
  197. package/ios/include/common/minja/minja.hpp +69 -36
  198. package/ios/include/common.h +21 -17
  199. package/ios/include/llama.h +62 -125
  200. package/ios/libs/llama.xcframework/Info.plist +19 -19
  201. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  202. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4617 -4487
  203. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  204. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +26 -7
  205. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +62 -125
  206. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  207. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  208. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
  209. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3557 -3435
  210. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  211. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  212. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  213. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  214. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  215. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
  216. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3559 -3437
  217. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +237 -0
  218. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +26 -7
  219. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +62 -125
  220. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +237 -0
  221. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +26 -7
  222. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +62 -125
  223. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  224. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +237 -0
  225. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +26 -7
  226. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +62 -125
  227. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  228. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  229. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  230. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4616 -4487
  231. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  232. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +26 -7
  233. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +62 -125
  234. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  235. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  236. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4637 -4508
  237. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3556 -3435
  238. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  239. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  240. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  241. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  242. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  243. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4653 -4523
  244. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  245. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +26 -7
  246. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +62 -125
  247. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  248. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  249. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4674 -4544
  250. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3587 -3465
  251. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  252. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  253. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  254. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  255. package/package.json +1 -1
  256. package/cpp/llama.cpp/common/stb_image.h +0 -7988
  257. package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +0 -112
  258. package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.out +0 -46
  259. package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +0 -112
  260. package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.out +0 -46
  261. package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +0 -112
  262. package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +0 -46
  263. package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.inp +0 -112
  264. package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.out +0 -46
  265. package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +0 -112
  266. package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.out +0 -46
  267. package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +0 -112
  268. package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +0 -46
@@ -65,6 +65,8 @@
65
65
  #include <aclnnop/aclnn_eq_tensor.h>
66
66
  #include <aclnnop/aclnn_gt_scalar.h>
67
67
  #include <aclnnop/aclnn_pow.h>
68
+ #include <aclnnop/aclnn_grouped_matmul_v2.h>
69
+ #include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
68
70
  #include <float.h>
69
71
 
70
72
  #include <cmath>
@@ -73,11 +75,13 @@
73
75
  #include <vector>
74
76
 
75
77
  #include "ggml-impl.h"
78
+ #include "ggml.h"
76
79
 
77
80
  #define GGML_COMMON_DECL_C
78
81
 
79
82
  #include "../ggml-common.h"
80
83
 
84
+
81
85
  void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclTensor ** acl_src0,
82
86
  aclTensor ** acl_src1, aclTensor ** acl_dst) {
83
87
  GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_can_repeat(src1, src0));
@@ -2587,3 +2591,603 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){
2587
2591
 
2588
2592
  ggml_cann_release_resources(ctx, acl_src, acl_dst, alpha);
2589
2593
  }
2594
+
2595
+ /**
2596
+ * @brief Performs expert-specific matrix multiplication (MoE) with
2597
+ * floating-point precision using the CANN backend.
2598
+ *
2599
+ * This function executes a matrix multiplication operation tailored for
2600
+ * Mixture of Experts (MoE) models, where the input tensor is multiplied
2601
+ * with expert-specific weight matrices. It uses the CANN backend for
2602
+ * efficient computation and stores the result in the destination tensor `dst`.
2603
+ * The operation may leverage identity-based optimizations or routing masks
2604
+ * as part of sparse expert selection.
2605
+ *
2606
+ * @param ctx The context for executing CANN backend operations.
2607
+ * @param dst The destination tensor where the MoE multiplication result
2608
+ * will be stored.
2609
+ *
2610
+ * @note This function assumes floating-point data types and is designed for
2611
+ * MoE architectures, possibly involving sparse expert routing.
2612
+ */
2613
+ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2614
+ //dst [M, K, N, 1]
2615
+ ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1]
2616
+ ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1
2617
+ ggml_tensor * ids = dst->src[2]; //ids [K, N]
2618
+
2619
+ GGML_TENSOR_BINARY_OP_LOCALS
2620
+
2621
+ // copy index from npu to cpu
2622
+ int64_t n_as = ne02; // A
2623
+ int64_t n_ids = ids->ne[0]; // K
2624
+
2625
+ std::vector<char> ids_host(ggml_nbytes(ids));
2626
+ ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids),
2627
+ ACL_MEMCPY_DEVICE_TO_HOST);
2628
+ ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
2629
+
2630
+ char * src0_original = (char *) src0->data;
2631
+ char * src1_original = (char *) src1->data;
2632
+ char * dst_original = (char *) dst->data;
2633
+ size_t ori_src0_nb[4] = {nb00, nb01, nb02, nb03};
2634
+
2635
+ // src0 is F16, src1 is F32, dst is F32
2636
+ ggml_cann_pool_alloc src0_cast_allocator;
2637
+ if (src0->type == GGML_TYPE_F16) {
2638
+ src0_cast_allocator.alloc(ctx.pool(), sizeof(float) * ggml_nelements(src0));
2639
+ void* src0_cast_buf = src0_cast_allocator.get();
2640
+
2641
+ size_t cast_nb[GGML_MAX_DIMS];
2642
+ cast_nb[0] = sizeof(float_t);
2643
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
2644
+ cast_nb[i] = cast_nb[i - 1] * src0->ne[i - 1];
2645
+ }
2646
+
2647
+ aclTensor* acl_src0_f16 = ggml_cann_create_tensor(src0);
2648
+ aclTensor* acl_cast = ggml_cann_create_tensor(src0_cast_buf,
2649
+ ACL_FLOAT, sizeof(float), src0->ne, cast_nb, 4);
2650
+ GGML_CANN_CALL_ACLNN_OP(ctx, Cast, acl_src0_f16, ACL_FLOAT, acl_cast);
2651
+ ggml_cann_release_resources(ctx, acl_cast, acl_src0_f16);
2652
+
2653
+ src0_original = (char *) src0_cast_buf;
2654
+ memcpy(ori_src0_nb, cast_nb, sizeof(ori_src0_nb));
2655
+ }
2656
+
2657
+ std::vector<aclTensor*> src0_tensor_vec;
2658
+ std::vector<aclTensor*> src1_tensor_vec;
2659
+ std::vector<aclTensor*> dst_tensor_vec;
2660
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
2661
+ for (int64_t id = 0; id < n_ids; id++) {
2662
+ // src0_row [M, D] -> weight && permute
2663
+ int64_t src0_ne[2] = {ne01, ne00};
2664
+ size_t src0_nb[2] = {ori_src0_nb[1], ori_src0_nb[0]};
2665
+ // src1_row [D, 1] -> input
2666
+ int64_t src1_ne[2] = {ne10, 1};
2667
+ size_t src1_nb[2] = {nb10, nb11};
2668
+ // dst_row [M, 1] -> out
2669
+ int64_t dst_ne[2] = {ne0, 1};
2670
+ size_t dst_nb[2] = {nb0, nb1};
2671
+
2672
+ // expert index
2673
+ int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
2674
+ GGML_ASSERT(i02 >= 0 && i02 < n_as);
2675
+
2676
+ // If B = 1 (broadcast), always use 0; otherwise, use id.
2677
+ int64_t i11 = (ne11 == 1 ? 0 : id);
2678
+ int64_t i12 = iid1;
2679
+
2680
+ int64_t i1 = id;
2681
+ int64_t i2 = i12;
2682
+
2683
+ void* src0_tmp_ptr = src0_original + i02*ori_src0_nb[2];
2684
+ void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
2685
+ void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;
2686
+
2687
+ aclTensor* acl_src0 = ggml_cann_create_tensor(src0_tmp_ptr,
2688
+ ACL_FLOAT, sizeof(float),
2689
+ src0_ne, src0_nb, 2);
2690
+ aclTensor* acl_src1 = ggml_cann_create_tensor(src1_tmp_ptr,
2691
+ ACL_FLOAT, sizeof(float),
2692
+ src1_ne, src1_nb, 2);
2693
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst_tmp_ptr,
2694
+ ACL_FLOAT, sizeof(float),
2695
+ dst_ne, dst_nb, 2);
2696
+
2697
+ src0_tensor_vec.push_back(acl_src0);
2698
+ src1_tensor_vec.push_back(acl_src1);
2699
+ dst_tensor_vec.push_back(acl_dst);
2700
+ }
2701
+ }
2702
+
2703
+ size_t GROUP_SIZE = 128;
2704
+ // GroupedMatmulV2 required tensor_list.size < 128
2705
+ for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) {
2706
+ // split and call GroupedMatmulV2
2707
+ size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size());
2708
+ std::vector<aclTensor*> src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end);
2709
+ std::vector<aclTensor*> src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end);
2710
+ std::vector<aclTensor*> dst_tensor_vec_split(dst_tensor_vec.begin() + i, dst_tensor_vec.begin() + end);
2711
+
2712
+ aclTensorList* src0_tensor_list = aclCreateTensorList(src0_tensor_vec_split.data(), src0_tensor_vec_split.size());
2713
+ aclTensorList* src1_tensor_list = aclCreateTensorList(src1_tensor_vec_split.data(), src1_tensor_vec_split.size());
2714
+ aclTensorList* dst_tensor_list = aclCreateTensorList(dst_tensor_vec_split.data(), dst_tensor_vec_split.size());
2715
+
2716
+ GGML_CANN_CALL_ACLNN_OP(ctx, GroupedMatmulV2, src1_tensor_list, src0_tensor_list,
2717
+ nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, -1, dst_tensor_list);
2718
+
2719
+ ggml_cann_release_resources(ctx, src0_tensor_list, src1_tensor_list, dst_tensor_list);
2720
+ }
2721
+ return;
2722
+ }
2723
+
2724
+ /**
2725
+ * @brief Performs expert-specific matrix multiplication (MoE) with
2726
+ * quantized precision using the CANN backend.
2727
+ *
2728
+ * This function executes a matrix multiplication operation tailored for
2729
+ * Mixture of Experts (MoE) models, where the input tensor is multiplied
2730
+ * with expert-specific quantized weight matrices. It leverages the CANN
2731
+ * backend to perform efficient low-precision computations and stores the
2732
+ * quantized result in the destination tensor `dst`.
2733
+ *
2734
+ * Quantization techniques reduce memory footprint and improve performance
2735
+ * by using lower-bit representations (e.g., int8) instead of floating-point.
2736
+ * This function is designed to work with such formats and may incorporate
2737
+ * optimizations like identity-based fast paths or routing masks for sparse
2738
+ * expert selection.
2739
+ *
2740
+ * @param ctx The context for executing CANN backend operations.
2741
+ * @param dst The destination tensor where the quantized MoE multiplication result
2742
+ * will be stored.
2743
+ *
2744
+ * @note This function assumes quantized data types and is designed for
2745
+ * MoE architectures with potential sparse expert routing.
2746
+ */
2747
+ static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2748
+ // TODO: Use aclnnGroupedMatMul
2749
+ //dst [M, K, N, 1]
2750
+ ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1]
2751
+ ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1
2752
+ ggml_tensor * ids = dst->src[2]; //ids [K, N]
2753
+
2754
+ GGML_TENSOR_BINARY_OP_LOCALS
2755
+
2756
+ // copy index from npu to cpu
2757
+ int64_t n_as = ne02; // A
2758
+ int64_t n_ids = ids->ne[0]; // K
2759
+
2760
+ std::vector<char> ids_host(ggml_nbytes(ids));
2761
+ ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids),
2762
+ ACL_MEMCPY_DEVICE_TO_HOST);
2763
+ ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
2764
+
2765
+ char * src0_original = (char *) src0->data;
2766
+ char * src1_original = (char *) src1->data;
2767
+ char * dst_original = (char *) dst->data;
2768
+
2769
+ ggml_tensor src0_row = *src0;
2770
+ ggml_tensor src1_row = *src1;
2771
+ ggml_tensor dst_row = *dst;
2772
+
2773
+ const enum ggml_type type = dst->src[0]->type;
2774
+ float weight_elem_size;
2775
+ if (type == GGML_TYPE_Q4_0) {
2776
+ weight_elem_size = float(sizeof(uint8_t)) / 2;
2777
+ } else if (type == GGML_TYPE_Q8_0) {
2778
+ weight_elem_size = float(sizeof(uint8_t));
2779
+ } else {
2780
+ GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0 ");
2781
+ }
2782
+
2783
+ // src0_row [D, M, 1, 1] weight without permute
2784
+ src0_row.ne[2] = 1;
2785
+ src0_row.ne[3] = 1;
2786
+ src0_row.nb[0] = weight_elem_size;
2787
+ src0_row.nb[1] = weight_elem_size * ne00;
2788
+ src0_row.nb[2] = weight_elem_size * ne00;
2789
+ src0_row.nb[3] = weight_elem_size * ne00;
2790
+ size_t weight_stride = ne00 * ne01 * weight_elem_size;
2791
+ size_t weight_size = weight_stride * ne02 * ne03;
2792
+
2793
+ // scale [D, M, 1, 1] -> scale && permute
2794
+ size_t scale_elem_size = sizeof(uint16_t);
2795
+ size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size;
2796
+
2797
+ // src1_row [D, 1, 1, 1] -> input
2798
+ src1_row.ne[1] = 1;
2799
+ src1_row.ne[2] = 1;
2800
+ src1_row.ne[3] = 1;
2801
+ src1_row.nb[2] = nb11;
2802
+ src1_row.nb[3] = nb11;
2803
+
2804
+ // dst_row [M, 1, 1, 1] -> out
2805
+ dst_row.ne[1] = 1;
2806
+ dst_row.ne[2] = 1;
2807
+ dst_row.ne[3] = 1;
2808
+ dst_row.nb[2] = nb1;
2809
+ dst_row.nb[3] = nb1;
2810
+
2811
+ //create weight for one row
2812
+ ggml_cann_pool_alloc weight_allocator(ctx.pool());
2813
+ void* weight_buffer = weight_allocator.alloc(nb02);
2814
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
2815
+ for (int64_t id = 0; id < n_ids; id++) {
2816
+ // expert index
2817
+ int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
2818
+ GGML_ASSERT(i02 >= 0 && i02 < n_as);
2819
+
2820
+ // If B = 1 (broadcast), always use 0; otherwise, use id.
2821
+ int64_t i11 = (ne11 == 1 ? 0 : id);
2822
+ int64_t i12 = iid1;
2823
+
2824
+ int64_t i1 = id;
2825
+ int64_t i2 = i12;
2826
+
2827
+ void* src0_tmp_ptr = src0_original + i02*weight_stride;
2828
+ void* scale_tmp_ptr = src0_original + weight_size + i02*scale_stride;
2829
+ void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
2830
+ void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;
2831
+
2832
+ // mem cpy
2833
+ ggml_cann_async_memcpy(ctx, weight_buffer, src0_tmp_ptr, weight_stride,
2834
+ ACL_MEMCPY_DEVICE_TO_DEVICE);
2835
+ void* scale_buffer = (char*)weight_buffer + weight_stride;
2836
+ ggml_cann_async_memcpy(ctx, scale_buffer, scale_tmp_ptr, scale_stride,
2837
+ ACL_MEMCPY_DEVICE_TO_DEVICE);
2838
+
2839
+ src0_row.data = weight_buffer;
2840
+ src1_row.data = src1_tmp_ptr;
2841
+ dst_row.data = dst_tmp_ptr;
2842
+ dst_row.src[0] = &src0_row;
2843
+ dst_row.src[1] = &src1_row;
2844
+
2845
+ ggml_cann_mul_mat(ctx, &dst_row);
2846
+ }
2847
+ }
2848
+ return;
2849
+ }
2850
+
2851
+ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2852
+ const enum ggml_type type = dst->src[0]->type;
2853
+ switch (type) {
2854
+ case GGML_TYPE_F32:
2855
+ case GGML_TYPE_F16:
2856
+ ggml_cann_mul_mat_id_fp(ctx, dst);
2857
+ break;
2858
+ case GGML_TYPE_Q4_0:
2859
+ case GGML_TYPE_Q8_0:
2860
+ ggml_cann_mul_mat_id_quant(ctx, dst);
2861
+ break;
2862
+ default:
2863
+ GGML_ABORT("Unsupported type for mul_mat_id");
2864
+ break;
2865
+ }
2866
+ }
2867
+
2868
+ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
2869
+
2870
+ ggml_tensor* src0 = dst->src[0]; // q, fp32
2871
+ ggml_tensor* src1 = dst->src[1]; // k, fp16
2872
+ ggml_tensor* src2 = dst->src[2]; // v, fp16
2873
+ ggml_tensor* src3 = dst->src[3]; // mask, fp16
2874
+
2875
+ float maxBias = 0.0f;
2876
+ float scaleValue = 1.0f;
2877
+ float logitSoftcap = 0.0f;
2878
+ memcpy(&scaleValue, (float*)dst->op_params + 0, sizeof(float));
2879
+ memcpy(&maxBias, (float*)dst->op_params + 1, sizeof(float));
2880
+ memcpy(&logitSoftcap, (float*)dst->op_params + 2, sizeof(float));
2881
+
2882
+ if(logitSoftcap == 0.0f){
2883
+ size_t faElemSize = sizeof(uint16_t);
2884
+ auto faDataType = ACL_FLOAT16; //ACL_BF16;
2885
+
2886
+ aclTensor* acl_src0_f16_tensor = nullptr;
2887
+ aclTensor* acl_src1_f16_tensor = nullptr;
2888
+ aclTensor* acl_src2_f16_tensor = nullptr;
2889
+ aclTensor* acl_dst_f16_tensor = nullptr;
2890
+
2891
+ // Step 1: cast the src0 (Query) to fp16 if needed
2892
+ ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
2893
+ void* src0_f16_buffer = nullptr;
2894
+
2895
+ if(ggml_cann_type_mapping(src0->type) != faDataType){
2896
+ aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0);
2897
+ src0_f16_buffer = src0_f16_allocator.alloc(
2898
+ ggml_nelements(src0) * faElemSize);
2899
+
2900
+ int64_t* src0_f16_ne = src0->ne;
2901
+ size_t src0_f16_nb[GGML_MAX_DIMS];
2902
+ src0_f16_nb[0] = sizeof(uint16_t);
2903
+ for(int i = 1; i < GGML_MAX_DIMS; ++i){
2904
+ src0_f16_nb[i] = src0_f16_nb[i - 1] * src0_f16_ne[i - 1];
2905
+ }
2906
+
2907
+ acl_src0_f16_tensor = ggml_cann_create_tensor(
2908
+ src0_f16_buffer, faDataType, faElemSize,
2909
+ src0_f16_ne, src0_f16_nb, GGML_MAX_DIMS
2910
+ );
2911
+ aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType);
2912
+ ggml_cann_release_resources(ctx, acl_src0_f32_tensor);
2913
+ }else{
2914
+ acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
2915
+ }
2916
+
2917
+ // Step 2: create the acl tensors for src1 (Key), src2 (Value),
2918
+ // and the direct output from FusedInferAttention
2919
+
2920
+ acl_src1_f16_tensor = ggml_cann_create_tensor(src1);
2921
+ acl_src2_f16_tensor = ggml_cann_create_tensor(src2);
2922
+
2923
+ ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
2924
+ void* out_f16_buffer = out_f16_allocator.alloc(
2925
+ ggml_nelements(dst) * faElemSize);
2926
+
2927
+ int64_t* out_f16_ne = src0->ne;
2928
+ size_t out_f16_nb[GGML_MAX_DIMS];
2929
+ out_f16_nb[0] = faElemSize;
2930
+ for(int i = 1; i < GGML_MAX_DIMS; ++i){
2931
+ out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
2932
+ }
2933
+
2934
+ acl_dst_f16_tensor = ggml_cann_create_tensor(
2935
+ out_f16_buffer, faDataType, faElemSize,
2936
+ out_f16_ne, out_f16_nb, GGML_MAX_DIMS
2937
+ );
2938
+
2939
+ // Step 3: create the PSEShift tensor if needed
2940
+ // this tensor is considered as mask (f16) in the llama.cpp
2941
+
2942
+ aclTensor* bcast_pse_tensor = nullptr;
2943
+ int64_t bcast_pse_ne[GGML_MAX_DIMS];
2944
+ size_t bcast_pse_nb[GGML_MAX_DIMS];
2945
+ ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());
2946
+ void* bcast_pse_buffer = nullptr;
2947
+
2948
+ if(src3 != nullptr){
2949
+ bcast_pse_buffer = bcast_pse_allocator.alloc(
2950
+ ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t));
2951
+
2952
+ if(src0->ne[1] > 1){
2953
+ // Case 1: broadcast pse for prefill stage with multiple head
2954
+ aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
2955
+ bcast_pse_ne[0] = src3->ne[0];
2956
+ bcast_pse_ne[1] = src3->ne[1];
2957
+ bcast_pse_ne[2] = src0->ne[2];
2958
+ bcast_pse_ne[3] = src3->ne[3];
2959
+
2960
+ bcast_pse_nb[0] = sizeof(uint16_t);
2961
+ for(int i = 1; i < GGML_MAX_DIMS; ++i){
2962
+ bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
2963
+ }
2964
+
2965
+ bcast_pse_tensor = ggml_cann_create_tensor(
2966
+ bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
2967
+ bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
2968
+
2969
+ int64_t repeats[] = {1, src0->ne[2], 1, 1};
2970
+ aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats);
2971
+
2972
+ ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
2973
+ }else{
2974
+ // Case 2: trunc the first row and broadcast pse for decode stage with multiple head
2975
+ int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]};
2976
+ size_t* trunc_pse_nb = src3->nb;
2977
+
2978
+ aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
2979
+ src3->data, ACL_FLOAT16, sizeof(uint16_t),
2980
+ trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS);
2981
+
2982
+ bcast_pse_ne[0] = src3->ne[0];
2983
+ bcast_pse_ne[1] = src0->ne[1];
2984
+ bcast_pse_ne[2] = src0->ne[2];
2985
+ bcast_pse_ne[3] = src3->ne[3];
2986
+
2987
+ bcast_pse_nb[0] = sizeof(uint16_t);
2988
+ for(int i = 1; i < GGML_MAX_DIMS; ++i){
2989
+ bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
2990
+ }
2991
+
2992
+ bcast_pse_tensor = ggml_cann_create_tensor(
2993
+ bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
2994
+ bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
2995
+
2996
+ int64_t repeats[] = {1, src0->ne[2], 1, 1};
2997
+ aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats);
2998
+
2999
+ ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
3000
+ }
3001
+
3002
+ // Compute the slope if needed. Derived from ggml_cann_softmax().
3003
+ if(maxBias != 0.0f){
3004
+ // alibi
3005
+ const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3];
3006
+ const int64_t n_head = src0->ne[2];
3007
+ const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
3008
+ float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor);
3009
+ float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor);
3010
+ // init arange
3011
+ ggml_cann_pool_alloc arange_allocator(ctx.pool(),
3012
+ ne2_ne3 * faElemSize);
3013
+ void* tmp_arange_buffer = arange_allocator.get();
3014
+
3015
+ // arange1: [1, ..., n_heads_log2_floor+1)
3016
+ float start = 1;
3017
+ float stop = n_heads_log2_floor + 1;
3018
+ float step = 1;
3019
+ int64_t n_elements_arange = n_heads_log2_floor;
3020
+
3021
+ int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
3022
+ size_t tmp_arange1_nb[] = {faElemSize};
3023
+ aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
3024
+ tmp_arange_buffer, faDataType, faElemSize,
3025
+ tmp_arange1_ne, tmp_arange1_nb,
3026
+ GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
3027
+
3028
+ aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
3029
+
3030
+ aclTensor* tmp_arange2_tensor = nullptr;
3031
+ if (n_heads_log2_floor < ne2_ne3) {
3032
+ // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
3033
+ start = 1;
3034
+ stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
3035
+ step = 2;
3036
+ n_elements_arange = ne2_ne3 - n_heads_log2_floor;
3037
+ int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
3038
+ size_t tmp_arange2_nb[] = {faElemSize};
3039
+
3040
+ aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor(
3041
+ (char*)tmp_arange_buffer +
3042
+ n_heads_log2_floor * faElemSize,
3043
+ faDataType, faElemSize,
3044
+ tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
3045
+ aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
3046
+ n_elements_arange);
3047
+ }
3048
+
3049
+ // init mk_base
3050
+ ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
3051
+ ne2_ne3 * faElemSize);
3052
+ void* tmp_mk_base_buffer = mk_base_allocator.get();
3053
+ int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
3054
+ size_t tmp_mk_base1_nb[] = {faElemSize};
3055
+ aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
3056
+ tmp_mk_base_buffer, faDataType, faElemSize,
3057
+ tmp_mk_base1_ne, tmp_mk_base1_nb,
3058
+ GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
3059
+
3060
+ aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
3061
+
3062
+ aclTensor* tmp_mk_base2_tensor = nullptr;
3063
+ if (n_heads_log2_floor < ne2_ne3) {
3064
+ int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
3065
+ size_t tmp_mk_base2_nb[] = {faElemSize};
3066
+ aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor(
3067
+ (char*)tmp_mk_base_buffer +
3068
+ n_heads_log2_floor * faElemSize,
3069
+ faDataType, faElemSize,
3070
+ tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
3071
+ aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
3072
+ }
3073
+
3074
+ // init mk
3075
+ int64_t tmp_mk_base_ne[] = {ne2_ne3};
3076
+ size_t tmp_mk_base_nb[] = {faElemSize};
3077
+ aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
3078
+ tmp_mk_base_buffer, faDataType, faElemSize,
3079
+ tmp_mk_base_ne, tmp_mk_base_nb,
3080
+ GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
3081
+ aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
3082
+ tmp_arange_buffer, faDataType, faElemSize,
3083
+ tmp_mk_base_ne, tmp_mk_base_nb,
3084
+ GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
3085
+ aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
3086
+
3087
+ // reshape mk
3088
+ int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]};
3089
+ size_t tmp_mk_nb[GGML_MAX_DIMS];
3090
+ tmp_mk_nb[0] = faElemSize;
3091
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
3092
+ tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
3093
+ }
3094
+ aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
3095
+ tmp_mk_base_buffer, faDataType, faElemSize,
3096
+ tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
3097
+ ACL_FORMAT_ND);
3098
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor);
3099
+
3100
+ ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
3101
+ tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
3102
+ tmp_arange_tensor, tmp_mk_tensor);
3103
+ }
3104
+ }
3105
+
3106
+ // Step 4: set the inputs for FusedInferAttention.
3107
+ int kvTensorNum = 1;
3108
+ aclTensor* acl_q_tensor = acl_src0_f16_tensor;
3109
+ aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor};
3110
+ aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor};
3111
+ auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
3112
+ auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
3113
+
3114
+ int64_t numHeads = src0->ne[2]; // N
3115
+ int64_t numKeyValueHeads = src1->ne[2];
3116
+ // double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d)
3117
+ int64_t preTokens = 65535;
3118
+ int64_t nextTokens = 65535;
3119
+ char layout[5] = {'B', 'N', 'S', 'D', 0};
3120
+ int64_t sparseMode = 0;
3121
+ int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2;
3122
+ int64_t blockSize = 0;
3123
+ int64_t antiquantMode = 0;
3124
+ bool softmaxLseFlag = false;
3125
+ int64_t keyAntiquantMode = 0;
3126
+ int64_t valueAntiquantMode = 0;
3127
+
3128
+ // Step 5: launch the FusedInferAttentionScoreV2 kernel.
3129
+ // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
3130
+
3131
+ GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
3132
+ acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
3133
+ bcast_pse_tensor, nullptr, // pse, mask
3134
+ nullptr, nullptr, // actSeqLen, actSeqLenkv
3135
+ nullptr, nullptr, // deqScale1, quantScale1
3136
+ nullptr, nullptr, nullptr, // deqScale2, quantScale2, quantOffset2
3137
+ nullptr, nullptr, // antiquantScale, antiquantOffset
3138
+ nullptr, // blockTable
3139
+ nullptr, nullptr, // qPadSize, kvPadSize
3140
+ nullptr, nullptr, // kAntiquantScale, kAntiQuantOffset
3141
+ nullptr, nullptr, // vAntiquantScale, vAntiQuantOffset
3142
+ nullptr, nullptr, nullptr, // kSharedPrefix, vSharedPrefix, actSharedLen
3143
+ numHeads, scaleValue, // heads, scaleValue
3144
+ preTokens, nextTokens, // preTokens, nextTokens
3145
+ layout, // inputLayout
3146
+ numKeyValueHeads, // numKVHeads
3147
+ sparseMode, innerPrecise, // sparseMode, innerPrecise
3148
+ blockSize, antiquantMode, // blockSize, antiquantMode
3149
+ softmaxLseFlag, // softmaxLseFlag
3150
+ keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
3151
+ acl_dst_f16_tensor, // attentionOut
3152
+ nullptr // softmaxLse
3153
+ );
3154
+
3155
+ // Step 6: post-processing, permute and cast to f32
3156
+
3157
+ int64_t new_dim[] = {0, 2, 1, 3};
3158
+ aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
3159
+
3160
+ if(ggml_cann_type_mapping(dst->type) != faDataType){
3161
+ ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool());
3162
+ perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
3163
+ void* perm_out_f16_buffer = perm_out_f16_allocator.get();
3164
+
3165
+ int64_t* perm_out_f16_ne = dst->ne;
3166
+ size_t perm_out_f16_nb[GGML_MAX_DIMS];
3167
+ perm_out_f16_nb[0] = faElemSize;
3168
+ for(int i = 1; i < GGML_MAX_DIMS; ++i){
3169
+ perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1];
3170
+ }
3171
+ aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor(
3172
+ perm_out_f16_buffer, faDataType, faElemSize,
3173
+ perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS);
3174
+ aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS);
3175
+ aclnn_cast(ctx,
3176
+ acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
3177
+ ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor);
3178
+ }else{
3179
+ // only need to permute
3180
+ aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS);
3181
+ }
3182
+ ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
3183
+ acl_src1_f16_tensor,
3184
+ acl_src2_f16_tensor,
3185
+ acl_dst_f16_tensor,
3186
+ acl_dst_tensor);
3187
+ if(src3 != nullptr){
3188
+ ggml_cann_release_resources(ctx, bcast_pse_tensor);
3189
+ }
3190
+ }else{
3191
+ GGML_ABORT("Function is not implemented.");
3192
+ }
3193
+ }
@@ -714,6 +714,21 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
714
714
  */
715
715
  void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
716
716
 
717
+ /**
718
+ * @brief Performs the Flash Attention extended operator using the CANN backend.
719
+ *
720
+ * @details This function implements the memory-efficient Flash Attention algorithm
721
+ * for computing scaled dot-product attention with hardware acceleration.
722
+ * The result is stored in the destination tensor `dst`.
723
+ *
724
+ * This operation is accelerated using the CANN backend to improve runtime performance.
725
+ *
726
+ * @param ctx The CANN context used for operations.
727
+ * @param dst The destination tensor where the result will be stored.
728
+ * dst->op is expected to be `GGML_OP_FLASH_ATTN_EXT`.
729
+ */
730
+ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst);
731
+
717
732
  /*
718
733
  * @brief A generic wrapper for ACL resources with custom deleter support.
719
734
  */
@@ -978,6 +993,33 @@ inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffe
978
993
  }
979
994
  }
980
995
 
996
+ /**
997
+ * @brief Performs sparse expert-based matrix multiplication using the CANN backend.
998
+ *
999
+ * @details This function implements a MoE-style batched matrix multiplication, where each input token
1000
+ * is routed to one or more experts, and each expert corresponds to a specific [D, M] weight matrix
1001
+ * in the source tensor `src0`. The routing indices are provided via the `ids` tensor.
1002
+ *
1003
+ * For each token (from `src1`), the function selects the corresponding expert(s) as specified by `ids`,
1004
+ * performs the matrix multiplication with the selected expert's weight submatrix (from `src0`),
1005
+ * and stores the results in `dst`. This operation is optimized and executed on the CANN backend.
1006
+ *
1007
+ * Dimensions:
1008
+ * - src0: [D, M, A, 1], where A is the number of experts
1009
+ * - src1: [D, B, N, 1], where N is batch size and B is the slot count per sample
1010
+ * - ids : [K, N], where K is the number of experts each token is routed to
1011
+ * - dst : [M, K, N, 1], output tensor storing the result of expert × token multiplication
1012
+ *
1013
+ * The function handles two main modes:
1014
+ * - If `ne12 == 1`, a simpler per-token loop is used.
1015
+ * - TODO: If `ne12 > 1`, grouped multiplication and memory copying is used for efficiency.
1016
+ *
1017
+ * @param ctx The CANN context used for operations.
1018
+ * @param dst The destination tensor where the expert-weighted token outputs are stored.
1019
+ * Expected to be of shape [M, K, N, 1].
1020
+ */
1021
+ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst);
1022
+
981
1023
  /**
982
1024
  * @brief Applies an element-wise operation to two input tensors using the CANN
983
1025
  * backend.