@novastera-oss/llamarn 0.2.1 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (266)
  1. package/README.md +80 -14
  2. package/RNLlamaCpp.podspec +10 -3
  3. package/android/CMakeLists.txt +8 -0
  4. package/android/src/main/cpp/include/llama.h +62 -125
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  13. package/cpp/build-info.cpp +2 -2
  14. package/cpp/llama.cpp/README.md +11 -3
  15. package/cpp/llama.cpp/build-xcframework.sh +1 -0
  16. package/cpp/llama.cpp/common/CMakeLists.txt +8 -2
  17. package/cpp/llama.cpp/common/arg.cpp +153 -113
  18. package/cpp/llama.cpp/common/chat-parser.cpp +379 -0
  19. package/cpp/llama.cpp/common/chat-parser.h +117 -0
  20. package/cpp/llama.cpp/common/chat.cpp +847 -699
  21. package/cpp/llama.cpp/common/chat.h +73 -6
  22. package/cpp/llama.cpp/common/common.cpp +50 -82
  23. package/cpp/llama.cpp/common/common.h +21 -17
  24. package/cpp/llama.cpp/common/json-partial.cpp +255 -0
  25. package/cpp/llama.cpp/common/json-partial.h +37 -0
  26. package/cpp/llama.cpp/common/minja/chat-template.hpp +9 -5
  27. package/cpp/llama.cpp/common/minja/minja.hpp +69 -36
  28. package/cpp/llama.cpp/common/regex-partial.cpp +204 -0
  29. package/cpp/llama.cpp/common/regex-partial.h +56 -0
  30. package/cpp/llama.cpp/common/sampling.cpp +7 -8
  31. package/cpp/llama.cpp/convert_hf_to_gguf.py +453 -118
  32. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +120 -68
  33. package/cpp/llama.cpp/ggml/CMakeLists.txt +2 -1
  34. package/cpp/llama.cpp/ggml/cmake/common.cmake +25 -0
  35. package/cpp/llama.cpp/ggml/include/ggml-opt.h +49 -28
  36. package/cpp/llama.cpp/ggml/include/ggml.h +26 -7
  37. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +16 -10
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +4 -1
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +2 -0
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +604 -0
  42. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +42 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +54 -2
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +50 -51
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -2
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -9
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +779 -19
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +22 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +322 -100
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +117 -1
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +220 -49
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/acc.cu +40 -26
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +1 -1
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +11 -1
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +15 -7
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +266 -64
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +49 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +48 -4
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +2 -1
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +5 -1
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +2 -0
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/quantize.cu +7 -6
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +1 -1
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +10 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-impl.h +1 -1
  71. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +99 -17
  73. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +200 -2
  74. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +8 -2
  75. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +6 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +972 -178
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +373 -190
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -10
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +101 -5
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +31 -33
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +1 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +29 -2
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +4 -5
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  94. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +9 -1
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +84 -72
  96. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  98. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -3
  99. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +324 -129
  100. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +1 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +31 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +95 -68
  103. package/cpp/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +1 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +22 -0
  105. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -4
  107. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +2 -3
  108. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +69 -43
  109. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +2 -14
  110. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -91
  111. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -181
  112. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +17 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  114. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +6 -152
  115. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  116. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  117. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +2 -118
  118. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +1 -1
  119. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +12 -1
  120. package/cpp/llama.cpp/ggml/src/ggml.c +107 -36
  121. package/cpp/llama.cpp/ggml/src/gguf.cpp +33 -33
  122. package/cpp/llama.cpp/gguf-py/gguf/constants.py +100 -15
  123. package/cpp/llama.cpp/gguf-py/gguf/gguf_reader.py +1 -1
  124. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +44 -12
  125. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_editor_gui.py +21 -10
  126. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +5 -2
  127. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +128 -31
  128. package/cpp/llama.cpp/gguf-py/gguf/utility.py +1 -1
  129. package/cpp/llama.cpp/gguf-py/pyproject.toml +1 -1
  130. package/cpp/llama.cpp/include/llama.h +62 -125
  131. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +1 -1
  132. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +1 -1
  133. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.inp +1 -1
  134. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.out +1 -1
  135. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +1 -1
  136. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +1 -1
  137. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +1 -1
  138. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +1 -1
  139. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.inp +1 -1
  140. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.out +1 -1
  141. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +1 -1
  142. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +1 -1
  143. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +1 -1
  144. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +1 -1
  145. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +1 -1
  146. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +1 -1
  147. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.inp +1 -1
  148. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.out +1 -1
  149. package/cpp/llama.cpp/models/ggml-vocab-nomic-bert-moe.gguf +0 -0
  150. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +1 -1
  151. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.out +1 -1
  152. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +1 -1
  153. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.out +1 -1
  154. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.inp +1 -1
  155. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.out +1 -1
  156. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +1 -1
  157. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.out +1 -1
  158. package/cpp/llama.cpp/models/templates/Qwen-QwQ-32B.jinja +62 -0
  159. package/cpp/llama.cpp/models/templates/Qwen-Qwen3-0.6B.jinja +85 -0
  160. package/cpp/llama.cpp/models/templates/README.md +2 -0
  161. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +5 -1
  162. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +5 -1
  163. package/cpp/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt +2 -0
  164. package/cpp/llama.cpp/requirements/requirements-gguf_editor_gui.txt +1 -1
  165. package/cpp/llama.cpp/src/CMakeLists.txt +2 -0
  166. package/cpp/llama.cpp/src/llama-arch.cpp +6 -0
  167. package/cpp/llama.cpp/src/llama-arch.h +2 -0
  168. package/cpp/llama.cpp/src/llama-batch.cpp +3 -1
  169. package/cpp/llama.cpp/src/llama-context.cpp +340 -123
  170. package/cpp/llama.cpp/src/llama-context.h +30 -0
  171. package/cpp/llama.cpp/src/llama-cparams.cpp +4 -0
  172. package/cpp/llama.cpp/src/llama-cparams.h +2 -0
  173. package/cpp/llama.cpp/src/llama-grammar.cpp +12 -2
  174. package/cpp/llama.cpp/src/llama-graph.cpp +157 -247
  175. package/cpp/llama.cpp/src/llama-graph.h +52 -7
  176. package/cpp/llama.cpp/src/llama-hparams.cpp +17 -1
  177. package/cpp/llama.cpp/src/llama-hparams.h +37 -5
  178. package/cpp/llama.cpp/src/llama-kv-cache.cpp +742 -481
  179. package/cpp/llama.cpp/src/llama-kv-cache.h +196 -99
  180. package/cpp/llama.cpp/src/llama-kv-cells.h +379 -0
  181. package/cpp/llama.cpp/src/llama-memory.h +4 -3
  182. package/cpp/llama.cpp/src/llama-model-loader.cpp +22 -17
  183. package/cpp/llama.cpp/src/llama-model-saver.cpp +281 -0
  184. package/cpp/llama.cpp/src/llama-model-saver.h +37 -0
  185. package/cpp/llama.cpp/src/llama-model.cpp +529 -172
  186. package/cpp/llama.cpp/src/llama-model.h +6 -1
  187. package/cpp/llama.cpp/src/llama-quant.cpp +15 -13
  188. package/cpp/llama.cpp/src/llama-sampling.cpp +2 -2
  189. package/cpp/llama.cpp/src/llama-vocab.cpp +35 -8
  190. package/cpp/llama.cpp/src/llama-vocab.h +6 -0
  191. package/cpp/llama.cpp/src/llama.cpp +14 -0
  192. package/cpp/rn-completion.cpp +4 -2
  193. package/ios/include/chat.h +73 -6
  194. package/ios/include/common/minja/chat-template.hpp +9 -5
  195. package/ios/include/common/minja/minja.hpp +69 -36
  196. package/ios/include/common.h +21 -17
  197. package/ios/include/llama.h +62 -125
  198. package/ios/libs/llama.xcframework/Info.plist +19 -19
  199. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  200. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4617 -4487
  201. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  202. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +26 -7
  203. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +62 -125
  204. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  205. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  206. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
  207. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3557 -3435
  208. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  209. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  210. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  211. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  212. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  213. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
  214. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3559 -3437
  215. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +237 -0
  216. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +26 -7
  217. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +62 -125
  218. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +237 -0
  219. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +26 -7
  220. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +62 -125
  221. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  222. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +237 -0
  223. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +26 -7
  224. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +62 -125
  225. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  226. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  227. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  228. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4616 -4487
  229. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  230. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +26 -7
  231. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +62 -125
  232. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  233. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  234. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4637 -4508
  235. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3556 -3435
  236. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  237. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  238. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  239. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  240. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  241. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4653 -4523
  242. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  243. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +26 -7
  244. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +62 -125
  245. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  246. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  247. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4674 -4544
  248. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3587 -3465
  249. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  250. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  251. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  252. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  253. package/package.json +1 -1
  254. package/cpp/llama.cpp/common/stb_image.h +0 -7988
  255. package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +0 -112
  256. package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.out +0 -46
  257. package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +0 -112
  258. package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.out +0 -46
  259. package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +0 -112
  260. package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +0 -46
  261. package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.inp +0 -112
  262. package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.out +0 -46
  263. package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +0 -112
  264. package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.out +0 -46
  265. package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +0 -112
  266. package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +0 -46
@@ -2691,6 +2691,109 @@ static void ggml_compute_forward_gelu(
2691
2691
  }
2692
2692
  }
2693
2693
 
2694
+ // ggml_compute_forward_gelu_erf
2695
+
2696
+ static void ggml_compute_forward_gelu_erf_f32(
2697
+ const ggml_compute_params * params,
2698
+ ggml_tensor * dst) {
2699
+
2700
+ const ggml_tensor * src0 = dst->src[0];
2701
+
2702
+ assert(ggml_is_contiguous_1(src0));
2703
+ assert(ggml_is_contiguous_1(dst));
2704
+ assert(ggml_are_same_shape(src0, dst));
2705
+
2706
+ const int ith = params->ith;
2707
+ const int nth = params->nth;
2708
+
2709
+ const int nc = src0->ne[0];
2710
+ const int nr = ggml_nrows(src0);
2711
+
2712
+ // rows per thread
2713
+ const int dr = (nr + nth - 1)/nth;
2714
+
2715
+ // row range for this thread
2716
+ const int ir0 = dr*ith;
2717
+ const int ir1 = MIN(ir0 + dr, nr);
2718
+
2719
+ for (int i1 = ir0; i1 < ir1; i1++) {
2720
+ ggml_vec_gelu_erf_f32(nc,
2721
+ (float *) ((char *) dst->data + i1*( dst->nb[1])),
2722
+ (float *) ((char *) src0->data + i1*(src0->nb[1])));
2723
+
2724
+ #ifndef NDEBUG
2725
+ for (int k = 0; k < nc; k++) {
2726
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2727
+ GGML_UNUSED(x);
2728
+ assert(!isnan(x));
2729
+ assert(!isinf(x));
2730
+ }
2731
+ #endif
2732
+ }
2733
+ }
2734
+
2735
+ static void ggml_compute_forward_gelu_erf_f16(
2736
+ const ggml_compute_params * params,
2737
+ ggml_tensor * dst) {
2738
+
2739
+ const ggml_tensor * src0 = dst->src[0];
2740
+
2741
+ assert(ggml_is_contiguous_1(src0));
2742
+ assert(ggml_is_contiguous_1(dst));
2743
+ assert(ggml_are_same_shape(src0, dst));
2744
+
2745
+ const int ith = params->ith;
2746
+ const int nth = params->nth;
2747
+
2748
+ const int nc = src0->ne[0];
2749
+ const int nr = ggml_nrows(src0);
2750
+
2751
+ // rows per thread
2752
+ const int dr = (nr + nth - 1)/nth;
2753
+
2754
+ // row range for this thread
2755
+ const int ir0 = dr*ith;
2756
+ const int ir1 = MIN(ir0 + dr, nr);
2757
+
2758
+ for (int i1 = ir0; i1 < ir1; i1++) {
2759
+ ggml_vec_gelu_erf_f16(nc,
2760
+ (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
2761
+ (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2762
+
2763
+ #ifndef NDEBUG
2764
+ for (int k = 0; k < nc; k++) {
2765
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2766
+ const float v = GGML_FP16_TO_FP32(x);
2767
+ GGML_UNUSED(v);
2768
+ assert(!isnan(v));
2769
+ assert(!isinf(v));
2770
+ }
2771
+ #endif
2772
+ }
2773
+ }
2774
+
2775
+ static void ggml_compute_forward_gelu_erf(
2776
+ const ggml_compute_params * params,
2777
+ ggml_tensor * dst) {
2778
+
2779
+ const ggml_tensor * src0 = dst->src[0];
2780
+
2781
+ switch (src0->type) {
2782
+ case GGML_TYPE_F32:
2783
+ {
2784
+ ggml_compute_forward_gelu_erf_f32(params, dst);
2785
+ } break;
2786
+ case GGML_TYPE_F16:
2787
+ {
2788
+ ggml_compute_forward_gelu_erf_f16(params, dst);
2789
+ } break;
2790
+ default:
2791
+ {
2792
+ GGML_ABORT("fatal error");
2793
+ }
2794
+ }
2795
+ }
2796
+
2694
2797
  // ggml_compute_forward_gelu_quick
2695
2798
 
2696
2799
  static void ggml_compute_forward_gelu_quick_f32(
@@ -7530,39 +7633,83 @@ static void ggml_compute_forward_ssm_scan_f32(
7530
7633
  const int ir1 = MIN(ir0 + dr, nr);
7531
7634
  const int ir = ir1 - ir0;
7532
7635
 
7533
- for (int i3 = 0; i3 < n_s; ++i3) {
7534
- for (int i2 = 0; i2 < n_t; ++i2) {
7535
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
7536
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7537
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
7538
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
7539
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
7540
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
7541
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7542
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
7543
-
7544
- // use the output as the source for the next token-wise iterations
7545
- if (i2 > 0) { s0 = s; }
7546
-
7547
- // d_inner
7548
- for (int i1 = 0; i1 < ir; ++i1) {
7549
- // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
7550
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
7551
- float x_dt = x[i1] * dt_soft_plus;
7552
- float sumf = 0.0f;
7553
- // d_state
7554
- for (int i0 = 0; i0 < nc; ++i0) {
7555
- int i = i0 + i1*nc;
7556
- // state = prev_state * dA + dB * x
7557
- float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
7558
- // y = rowwise_dotprod(state, C)
7559
- sumf += state * C[i0];
7560
- s[i] = state;
7636
+ #ifdef __ARM_FEATURE_SVE
7637
+ for (int i3 = 0; i3 < n_s; ++i3) {
7638
+ for (int i2 = 0; i2 < n_t; ++i2) {
7639
+ const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
7640
+ const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7641
+ const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
7642
+ const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
7643
+ const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
7644
+ const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
7645
+ float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7646
+ float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
7647
+
7648
+ // use the output as the source for the next token-wise iterations
7649
+ if (i2 > 0) { s0 = s; }
7650
+
7651
+ // d_inner
7652
+ for (int i1 = 0; i1 < ir; ++i1) {
7653
+ float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
7654
+ float x_dt = x[i1] * dt_soft_plus;
7655
+ svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
7656
+ svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
7657
+ svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
7658
+
7659
+ for (int64_t k = 0; k < nc; k += svcntw()) {
7660
+ svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
7661
+ svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]);
7662
+ svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]);
7663
+ svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
7664
+
7665
+ svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
7666
+ t1 = exp_ps_sve(svptrue_b32(), t1);
7667
+ svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
7668
+
7669
+ vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
7670
+ r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
7671
+
7672
+ GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
7673
+ }
7674
+ y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
7561
7675
  }
7562
- y[i1] = sumf;
7563
7676
  }
7564
7677
  }
7565
- }
7678
+ #else
7679
+ for (int i3 = 0; i3 < n_s; ++i3) {
7680
+ for (int i2 = 0; i2 < n_t; ++i2) {
7681
+ const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
7682
+ const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7683
+ const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
7684
+ const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
7685
+ const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
7686
+ const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
7687
+ float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
7688
+ float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
7689
+
7690
+ // use the output as the source for the next token-wise iterations
7691
+ if (i2 > 0) { s0 = s; }
7692
+
7693
+ // d_inner
7694
+ for (int i1 = 0; i1 < ir; ++i1) {
7695
+ // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
7696
+ float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
7697
+ float x_dt = x[i1] * dt_soft_plus;
7698
+ float sumf = 0.0f;
7699
+ // d_state
7700
+ for (int i0 = 0; i0 < nc; ++i0) {
7701
+ int i = i0 + i1*nc;
7702
+ // state = prev_state * dA + dB * x
7703
+ float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
7704
+ // y = rowwise_dotprod(state, C)
7705
+ sumf += state * C[i0];
7706
+ s[i] = state;
7707
+ }
7708
+ y[i1] = sumf;
7709
+ }
7710
+ }
7711
+ }
7712
+ #endif
7566
7713
  }
7567
7714
 
7568
7715
  void ggml_compute_forward_ssm_scan(
@@ -7749,6 +7896,10 @@ void ggml_compute_forward_unary(
7749
7896
  {
7750
7897
  ggml_compute_forward_gelu(params, dst);
7751
7898
  } break;
7899
+ case GGML_UNARY_OP_GELU_ERF:
7900
+ {
7901
+ ggml_compute_forward_gelu_erf(params, dst);
7902
+ } break;
7752
7903
  case GGML_UNARY_OP_GELU_QUICK:
7753
7904
  {
7754
7905
  ggml_compute_forward_gelu_quick(params, dst);
@@ -7963,6 +8114,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
7963
8114
  #define GGML_F32X_MUL GGML_F32x16_MUL
7964
8115
  #define GGML_F32X_FMA GGML_F32x16_FMA
7965
8116
  #define WKV_VECTOR_SIZE 16
8117
+ #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
8118
+ #define GGML_F32X GGML_F32xt
8119
+ #define GGML_F32X_SET1 GGML_F32xt_SET1
8120
+ #define GGML_F32X_LOAD GGML_F32xt_LOAD
8121
+ #define GGML_F32X_STORE GGML_F32xt_STORE
8122
+ #define GGML_F32X_MUL GGML_F32xt_MUL
8123
+ #define GGML_F32X_FMA GGML_F32xt_FMA
8124
+ #define WKV_VECTOR_SIZE 8
7966
8125
  #elif defined(__ARM_NEON) && defined(__aarch64__)
7967
8126
  #define GGML_F32X GGML_F32x4
7968
8127
  #define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -7973,8 +8132,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
7973
8132
  #define WKV_VECTOR_SIZE 4
7974
8133
  #endif
7975
8134
 
8135
+ int wkv_vector_size;
7976
8136
  #ifdef WKV_VECTOR_SIZE
7977
- const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
8137
+ #if defined(__ARM_FEATURE_SVE)
8138
+ wkv_vector_size = svcntw();
8139
+ #else
8140
+ wkv_vector_size = WKV_VECTOR_SIZE;
8141
+ #endif
8142
+ const int64_t vec_count = head_size / wkv_vector_size;
7978
8143
 
7979
8144
  for (int64_t t = 0; t < T; t++) {
7980
8145
  size_t t_offset = t * t_stride;
@@ -8004,7 +8169,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
8004
8169
  GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
8005
8170
 
8006
8171
  for (int64_t j = 0; j < vec_count; j++) {
8007
- size_t base_j = j * WKV_VECTOR_SIZE;
8172
+ size_t base_j = j * wkv_vector_size;
8008
8173
  size_t t_h_j_offset = t_h_offset + base_j;
8009
8174
  size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
8010
8175
 
@@ -8029,7 +8194,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
8029
8194
  }
8030
8195
 
8031
8196
  // Handle remaining elements, this will not be used.
8032
- for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
8197
+ for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {
8033
8198
  size_t t_h_j_offset = t_h_offset + j;
8034
8199
  size_t h_2d_i_j_offset = h_2d_i_offset + j;
8035
8200
  float v_val = v[t_h_j_offset];
@@ -8165,6 +8330,14 @@ static void ggml_compute_forward_gla_f32(
8165
8330
  #define GGML_F32X_MUL GGML_F32x16_MUL
8166
8331
  #define GGML_F32X_FMA GGML_F32x16_FMA
8167
8332
  #define GLA_VECTOR_SIZE 16
8333
+ #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
8334
+ #define GGML_F32X GGML_F32xt
8335
+ #define GGML_F32X_SET1 GGML_F32xt_SET1
8336
+ #define GGML_F32X_LOAD GGML_F32xt_LOAD
8337
+ #define GGML_F32X_STORE GGML_F32xt_STORE
8338
+ #define GGML_F32X_MUL GGML_F32xt_MUL
8339
+ #define GGML_F32X_FMA GGML_F32xt_FMA
8340
+ #define GLA_VECTOR_SIZE 8
8168
8341
  #elif defined(__ARM_NEON) && defined(__aarch64__)
8169
8342
  #define GGML_F32X GGML_F32x4
8170
8343
  #define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -8175,8 +8348,14 @@ static void ggml_compute_forward_gla_f32(
8175
8348
  #define GLA_VECTOR_SIZE 4
8176
8349
  #endif
8177
8350
 
8351
+ int gla_vector_size;
8178
8352
  #ifdef GLA_VECTOR_SIZE
8179
- const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
8353
+ #if defined(__ARM_FEATURE_SVE)
8354
+ gla_vector_size = svcntw();
8355
+ #else
8356
+ gla_vector_size = GLA_VECTOR_SIZE;
8357
+ #endif
8358
+ const int64_t vec_count = head_size / gla_vector_size;
8180
8359
 
8181
8360
  for (int64_t t = 0; t < T; t++) {
8182
8361
  size_t t_offset = t * t_stride;
@@ -8203,7 +8382,7 @@ static void ggml_compute_forward_gla_f32(
8203
8382
  GGML_F32X g_vec = GGML_F32X_SET1(g_val);
8204
8383
 
8205
8384
  for (int64_t j = 0; j < vec_count; j++) {
8206
- size_t base_j = j * GLA_VECTOR_SIZE;
8385
+ size_t base_j = j * gla_vector_size;
8207
8386
  size_t t_h_j_offset = t_h_offset + base_j;
8208
8387
  size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
8209
8388
 
@@ -8227,7 +8406,7 @@ static void ggml_compute_forward_gla_f32(
8227
8406
  }
8228
8407
 
8229
8408
  // Handle remaining elements, this will not be used.
8230
- for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
8409
+ for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {
8231
8410
  size_t t_h_j_offset = t_h_offset + j;
8232
8411
  size_t h_2d_i_j_offset = h_2d_i_offset + j;
8233
8412
  float v_val = v[t_h_j_offset];
@@ -8336,83 +8515,126 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
8336
8515
  int64_t h_stride_2d = head_size * head_size;
8337
8516
 
8338
8517
  #if defined(GGML_SIMD)
8339
- for (int64_t t = 0; t < T; t++) {
8340
- int64_t t_offset = t * t_stride;
8341
- int64_t state_offset = head_size * C * (t / (T / n_seqs));
8342
- float * state_cur = state + state_offset;
8343
- float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
8344
-
8345
- for (int64_t h = h_start; h < h_end; h++) {
8346
- int64_t h_offset = h * h_stride;
8347
- int64_t t_h_offset = t_offset + h_offset;
8348
- int64_t h_2d_offset = h * h_stride_2d;
8349
-
8350
- for (int64_t ii = 0; ii < head_size; ii++) {
8351
- int64_t t_h_i_offset = t_h_offset + ii;
8352
- int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
8353
-
8354
- GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
8518
+ #if defined(__ARM_FEATURE_SVE)
8519
+ // scalar Route to scalar implementation //TODO: Write SVE code
8520
+ for (int64_t t = 0; t < T; t++) {
8521
+ int64_t t_offset = t * t_stride;
8522
+ int64_t state_offset = head_size * C * (t / (T / n_seqs));
8523
+ float * state_cur = state + state_offset;
8524
+ float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
8525
+
8526
+ for (int64_t h = h_start; h < h_end; h++) {
8527
+ int64_t h_offset = h * h_stride;
8528
+ int64_t t_h_offset = t_offset + h_offset;
8529
+ int64_t h_2d_offset = h * h_stride_2d;
8530
+
8531
+ for (int64_t i = 0; i < head_size; i++) {
8532
+ int64_t t_h_i_offset = t_h_offset + i;
8533
+ int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
8534
+
8535
+ float v_val = v[t_h_i_offset];
8536
+
8537
+ float sa = 0, result = 0;
8538
+ for (int64_t j = 0; j < head_size; j++) {
8539
+ sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
8540
+ }
8355
8541
 
8356
- float sa = 0;
8357
- {
8358
- GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
8359
- GGML_F32_VEC ax[GGML_F32_ARR];
8360
- GGML_F32_VEC ay[GGML_F32_ARR];
8361
- for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
8362
- for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
8363
- ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
8364
- ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
8365
- sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
8366
- }
8542
+ for (int64_t j = 0; j < head_size; j++) {
8543
+ int64_t t_h_j_offset = t_h_offset + j;
8544
+ int64_t h_2d_i_j_offset = h_2d_i_offset + j;
8545
+
8546
+ float r_val = r[t_h_j_offset];
8547
+ float w_val = w[t_h_j_offset];
8548
+ float k_val = k[t_h_j_offset];
8549
+ float b_val = b[t_h_j_offset];
8550
+ float kv_val = v_val * k_val;
8551
+ float prev_state_val = state_prev[h_2d_i_j_offset];
8552
+ state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
8553
+ result += state_cur[h_2d_i_j_offset] * r_val;
8367
8554
  }
8368
- GGML_F32_VEC_REDUCE(sa, sum);
8555
+ dst_data[t_h_i_offset] = result;
8369
8556
  }
8557
+ }
8558
+ }
8559
+ #else
8560
+ for (int64_t t = 0; t < T; t++) {
8561
+ int64_t t_offset = t * t_stride;
8562
+ int64_t state_offset = head_size * C * (t / (T / n_seqs));
8563
+ float * state_cur = state + state_offset;
8564
+ float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
8565
+
8566
+ for (int64_t h = h_start; h < h_end; h++) {
8567
+ int64_t h_offset = h * h_stride;
8568
+ int64_t t_h_offset = t_offset + h_offset;
8569
+ int64_t h_2d_offset = h * h_stride_2d;
8570
+
8571
+ for (int64_t ii = 0; ii < head_size; ii++) {
8572
+ int64_t t_h_i_offset = t_h_offset + ii;
8573
+ int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
8574
+
8575
+ GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
8576
+
8577
+ float sa = 0;
8578
+ {
8579
+ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
8580
+ GGML_F32_VEC ax[GGML_F32_ARR];
8581
+ GGML_F32_VEC ay[GGML_F32_ARR];
8582
+ for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
8583
+ for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
8584
+ ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
8585
+ ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
8586
+ sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
8587
+ }
8588
+ }
8589
+ GGML_F32_VEC_REDUCE(sa, sum);
8590
+ }
8370
8591
 
8371
- GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
8592
+ GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
8372
8593
 
8373
- int64_t j = 0;
8374
- GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
8375
- for (; j < head_size; j += GGML_F32_STEP) {
8376
- for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
8377
- int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
8378
- int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
8594
+ int64_t j = 0;
8595
+ GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
8596
+ for (; j < head_size; j += GGML_F32_STEP) {
8597
+ for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
8598
+ int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
8599
+ int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
8379
8600
 
8380
- GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
8381
- GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
8382
- GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
8383
- GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
8601
+ GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
8602
+ GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
8603
+ GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
8604
+ GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
8384
8605
 
8385
- k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
8606
+ k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
8386
8607
 
8387
- GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
8388
- // kv + s * decay + sa * b
8389
- state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
8390
- state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
8391
- GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
8608
+ GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
8609
+ // kv + s * decay + sa * b
8610
+ state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
8611
+ state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
8612
+ GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
8392
8613
 
8393
- result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
8614
+ result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
8615
+ }
8616
+ }
8617
+ GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
8618
+
8619
+ // There shouldn't be left-overs though.
8620
+ for (; j < head_size; j++) {
8621
+ int64_t t_h_j_offset = t_h_offset + j;
8622
+ int64_t h_2d_i_j_offset = h_2d_i_offset + j;
8623
+
8624
+ float r_val = r[t_h_j_offset];
8625
+ float w_val = w[t_h_j_offset];
8626
+ float k_val = k[t_h_j_offset];
8627
+ float b_val = b[t_h_j_offset];
8628
+ float kv_val = v[t_h_i_offset] * k_val;
8629
+
8630
+ float prev_state_val = state_prev[h_2d_i_j_offset];
8631
+ state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
8632
+ dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
8394
8633
  }
8395
- }
8396
- GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
8397
-
8398
- // There shouldn't be left-overs though.
8399
- for (; j < head_size; j++) {
8400
- int64_t t_h_j_offset = t_h_offset + j;
8401
- int64_t h_2d_i_j_offset = h_2d_i_offset + j;
8402
-
8403
- float r_val = r[t_h_j_offset];
8404
- float w_val = w[t_h_j_offset];
8405
- float k_val = k[t_h_j_offset];
8406
- float b_val = b[t_h_j_offset];
8407
- float kv_val = v[t_h_i_offset] * k_val;
8408
-
8409
- float prev_state_val = state_prev[h_2d_i_j_offset];
8410
- state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
8411
- dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
8412
8634
  }
8413
8635
  }
8414
8636
  }
8415
- }
8637
+ #endif
8416
8638
  #else
8417
8639
  for (int64_t t = 0; t < T; t++) {
8418
8640
  int64_t t_offset = t * t_stride;
@@ -17,7 +17,123 @@
17
17
  // number of elements to fit in a single register
18
18
  //
19
19
 
20
- #if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
20
+ #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_FMA)
21
+
22
+ #define GGML_SIMD
23
+
24
+ // F32 SVE
25
+ #define GGML_F32_EPR 8
26
+ #define DEFAULT_PG svptrue_b32()
27
+
28
+ #define GGML_F32xt svfloat32_t
29
+ #define GGML_F32xt_ZERO svdup_n_f32(0.0f)
30
+ #define GGML_F32xt_SET1(x) svdup_n_f32(x)
31
+ #define GGML_F32xt_LOAD_IMPL(pg, a, ...) svld1_f32(pg, a)
32
+ #define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
33
+ #define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b)
34
+ #define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
35
+ #define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, a, b, c)
36
+ #define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
37
+ #define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
38
+ #define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
39
+ #define GGML_F32xt_MUL_IMPL(pg, a, b) svmul_f32_m(pg, a, b)
40
+ #define GGML_F32xt_MUL(...) GGML_F32xt_MUL_IMPL(DEFAULT_PG, __VA_ARGS__)
41
+ #define GGML_F32xt_REDUCE_ONE_IMPL(pg, a) svaddv(pg, a)
42
+ #define GGML_F32xt_REDUCE_ONE(...) GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, __VA_ARGS__)
43
+ #define GGML_F32xt_REDUCE_IMPL(pg, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \
44
+ { \
45
+ sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum2); \
46
+ sum3 = svadd_f32_m(DEFAULT_PG, sum3, sum4); \
47
+ sum5 = svadd_f32_m(DEFAULT_PG, sum5, sum6); \
48
+ sum7 = svadd_f32_m(DEFAULT_PG, sum7, sum8); \
49
+ sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum3); \
50
+ sum5 = svadd_f32_m(DEFAULT_PG, sum5, sum7); \
51
+ sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum5); \
52
+ (res) = (ggml_float) GGML_F32xt_REDUCE_ONE(sum1); \
53
+ }
54
+ #define GGML_F32xt_REDUCE(...) GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, __VA_ARGS__)
55
+
56
+ #define GGML_F32_VEC GGML_F32xt
57
+ #define GGML_F32_VEC_ZERO GGML_F32xt_ZERO
58
+ #define GGML_F32_VEC_SET1 GGML_F32xt_SET1
59
+ #define GGML_F32_VEC_LOAD GGML_F32xt_LOAD
60
+ #define GGML_F32_VEC_STORE GGML_F32xt_STORE
61
+ #define GGML_F32_VEC_FMA GGML_F32xt_FMA
62
+ #define GGML_F32_VEC_ADD GGML_F32xt_ADD
63
+ #define GGML_F32_VEC_MUL GGML_F32xt_MUL
64
+ #define GGML_F32_VEC_REDUCE GGML_F32xt_REDUCE
65
+
66
+ // F16 NEON
67
+
68
+ #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
69
+ #define GGML_F16_STEP 32
70
+ #define GGML_F16_EPR 8
71
+
72
+ #define GGML_F16x8 float16x8_t
73
+ #define GGML_F16x8_ZERO vdupq_n_f16(0.0f)
74
+ #define GGML_F16x8_SET1(x) vdupq_n_f16(x)
75
+ #define GGML_F16x8_LOAD(x) vld1q_f16((const __fp16 *)(x))
76
+ #define GGML_F16x8_STORE vst1q_f16
77
+ #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
78
+ #define GGML_F16x8_ADD vaddq_f16
79
+ #define GGML_F16x8_MUL vmulq_f16
80
+ #define GGML_F16x8_REDUCE(res, x) \
81
+ do { \
82
+ int offset = GGML_F16_ARR >> 1; \
83
+ for (int i = 0; i < offset; ++i) { \
84
+ (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \
85
+ } \
86
+ offset >>= 1; \
87
+ for (int i = 0; i < offset; ++i) { \
88
+ (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \
89
+ } \
90
+ offset >>= 1; \
91
+ for (int i = 0; i < offset; ++i) { \
92
+ (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \
93
+ } \
94
+ const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 ((x)[0])); \
95
+ const float32x4_t t1 = vcvt_f32_f16(vget_high_f16((x)[0])); \
96
+ (res) = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \
97
+ } while (0)
98
+
99
+ #define GGML_F16_VEC GGML_F16x8
100
+ #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
101
+ #define GGML_F16_VEC_SET1 GGML_F16x8_SET1
102
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p)
103
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((__fp16 *)(p), (r)[i])
104
+ #define GGML_F16_VEC_FMA GGML_F16x8_FMA
105
+ #define GGML_F16_VEC_ADD GGML_F16x8_ADD
106
+ #define GGML_F16_VEC_MUL GGML_F16x8_MUL
107
+ #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE
108
+ #else
109
+ // if FP16 vector arithmetic is not supported, we use FP32 instead
110
+ // and take advantage of the vcvt_ functions to convert to/from FP16
111
+
112
+ #define GGML_F16_STEP 16
113
+ #define GGML_F16_EPR 4
114
+
115
+ #define GGML_F32Cx4 float32x4_t
116
+ #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f)
117
+ #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x)
118
+ #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16((const __fp16 *)(x)))
119
+ #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y))
120
+ #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
121
+ #define GGML_F32Cx4_ADD vaddq_f32
122
+ #define GGML_F32Cx4_MUL vmulq_f32
123
+ #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
124
+
125
+ #define GGML_F16_VEC GGML_F32Cx4
126
+ #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
127
+ #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
128
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
129
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((__fp16 *)(p), r[i])
130
+ #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
131
+ #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
132
+ #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
133
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
134
+ #endif
135
+
136
+ #elif defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
21
137
 
22
138
  #define GGML_SIMD
23
139