@novastera-oss/llamarn 0.2.1 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (266)
  1. package/README.md +80 -14
  2. package/RNLlamaCpp.podspec +10 -3
  3. package/android/CMakeLists.txt +8 -0
  4. package/android/src/main/cpp/include/llama.h +62 -125
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  13. package/cpp/build-info.cpp +2 -2
  14. package/cpp/llama.cpp/README.md +11 -3
  15. package/cpp/llama.cpp/build-xcframework.sh +1 -0
  16. package/cpp/llama.cpp/common/CMakeLists.txt +8 -2
  17. package/cpp/llama.cpp/common/arg.cpp +153 -113
  18. package/cpp/llama.cpp/common/chat-parser.cpp +379 -0
  19. package/cpp/llama.cpp/common/chat-parser.h +117 -0
  20. package/cpp/llama.cpp/common/chat.cpp +847 -699
  21. package/cpp/llama.cpp/common/chat.h +73 -6
  22. package/cpp/llama.cpp/common/common.cpp +50 -82
  23. package/cpp/llama.cpp/common/common.h +21 -17
  24. package/cpp/llama.cpp/common/json-partial.cpp +255 -0
  25. package/cpp/llama.cpp/common/json-partial.h +37 -0
  26. package/cpp/llama.cpp/common/minja/chat-template.hpp +9 -5
  27. package/cpp/llama.cpp/common/minja/minja.hpp +69 -36
  28. package/cpp/llama.cpp/common/regex-partial.cpp +204 -0
  29. package/cpp/llama.cpp/common/regex-partial.h +56 -0
  30. package/cpp/llama.cpp/common/sampling.cpp +7 -8
  31. package/cpp/llama.cpp/convert_hf_to_gguf.py +453 -118
  32. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +120 -68
  33. package/cpp/llama.cpp/ggml/CMakeLists.txt +2 -1
  34. package/cpp/llama.cpp/ggml/cmake/common.cmake +25 -0
  35. package/cpp/llama.cpp/ggml/include/ggml-opt.h +49 -28
  36. package/cpp/llama.cpp/ggml/include/ggml.h +26 -7
  37. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +16 -10
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +4 -1
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +2 -0
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +604 -0
  42. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +42 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +54 -2
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +50 -51
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -2
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -9
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +779 -19
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +22 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +322 -100
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +117 -1
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +220 -49
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/acc.cu +40 -26
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +1 -1
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +11 -1
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +15 -7
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +266 -64
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +49 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +48 -4
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +2 -1
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +5 -1
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +2 -0
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/quantize.cu +7 -6
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +1 -1
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +10 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-impl.h +1 -1
  71. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +99 -17
  73. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +200 -2
  74. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +8 -2
  75. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +6 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +972 -178
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +373 -190
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -10
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +101 -5
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +31 -33
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +1 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +29 -2
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +4 -5
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  94. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +9 -1
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +84 -72
  96. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  98. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -3
  99. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +324 -129
  100. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +1 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +31 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +95 -68
  103. package/cpp/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +1 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +22 -0
  105. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -4
  107. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +2 -3
  108. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +69 -43
  109. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +2 -14
  110. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -91
  111. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -181
  112. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +17 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  114. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +6 -152
  115. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  116. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  117. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +2 -118
  118. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +1 -1
  119. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +12 -1
  120. package/cpp/llama.cpp/ggml/src/ggml.c +107 -36
  121. package/cpp/llama.cpp/ggml/src/gguf.cpp +33 -33
  122. package/cpp/llama.cpp/gguf-py/gguf/constants.py +100 -15
  123. package/cpp/llama.cpp/gguf-py/gguf/gguf_reader.py +1 -1
  124. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +44 -12
  125. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_editor_gui.py +21 -10
  126. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +5 -2
  127. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +128 -31
  128. package/cpp/llama.cpp/gguf-py/gguf/utility.py +1 -1
  129. package/cpp/llama.cpp/gguf-py/pyproject.toml +1 -1
  130. package/cpp/llama.cpp/include/llama.h +62 -125
  131. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +1 -1
  132. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +1 -1
  133. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.inp +1 -1
  134. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.out +1 -1
  135. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +1 -1
  136. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +1 -1
  137. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +1 -1
  138. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +1 -1
  139. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.inp +1 -1
  140. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.out +1 -1
  141. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +1 -1
  142. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +1 -1
  143. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +1 -1
  144. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +1 -1
  145. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +1 -1
  146. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +1 -1
  147. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.inp +1 -1
  148. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.out +1 -1
  149. package/cpp/llama.cpp/models/ggml-vocab-nomic-bert-moe.gguf +0 -0
  150. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +1 -1
  151. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.out +1 -1
  152. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +1 -1
  153. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.out +1 -1
  154. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.inp +1 -1
  155. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.out +1 -1
  156. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +1 -1
  157. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.out +1 -1
  158. package/cpp/llama.cpp/models/templates/Qwen-QwQ-32B.jinja +62 -0
  159. package/cpp/llama.cpp/models/templates/Qwen-Qwen3-0.6B.jinja +85 -0
  160. package/cpp/llama.cpp/models/templates/README.md +2 -0
  161. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +5 -1
  162. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +5 -1
  163. package/cpp/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt +2 -0
  164. package/cpp/llama.cpp/requirements/requirements-gguf_editor_gui.txt +1 -1
  165. package/cpp/llama.cpp/src/CMakeLists.txt +2 -0
  166. package/cpp/llama.cpp/src/llama-arch.cpp +6 -0
  167. package/cpp/llama.cpp/src/llama-arch.h +2 -0
  168. package/cpp/llama.cpp/src/llama-batch.cpp +3 -1
  169. package/cpp/llama.cpp/src/llama-context.cpp +340 -123
  170. package/cpp/llama.cpp/src/llama-context.h +30 -0
  171. package/cpp/llama.cpp/src/llama-cparams.cpp +4 -0
  172. package/cpp/llama.cpp/src/llama-cparams.h +2 -0
  173. package/cpp/llama.cpp/src/llama-grammar.cpp +12 -2
  174. package/cpp/llama.cpp/src/llama-graph.cpp +157 -247
  175. package/cpp/llama.cpp/src/llama-graph.h +52 -7
  176. package/cpp/llama.cpp/src/llama-hparams.cpp +17 -1
  177. package/cpp/llama.cpp/src/llama-hparams.h +37 -5
  178. package/cpp/llama.cpp/src/llama-kv-cache.cpp +742 -481
  179. package/cpp/llama.cpp/src/llama-kv-cache.h +196 -99
  180. package/cpp/llama.cpp/src/llama-kv-cells.h +379 -0
  181. package/cpp/llama.cpp/src/llama-memory.h +4 -3
  182. package/cpp/llama.cpp/src/llama-model-loader.cpp +22 -17
  183. package/cpp/llama.cpp/src/llama-model-saver.cpp +281 -0
  184. package/cpp/llama.cpp/src/llama-model-saver.h +37 -0
  185. package/cpp/llama.cpp/src/llama-model.cpp +529 -172
  186. package/cpp/llama.cpp/src/llama-model.h +6 -1
  187. package/cpp/llama.cpp/src/llama-quant.cpp +15 -13
  188. package/cpp/llama.cpp/src/llama-sampling.cpp +2 -2
  189. package/cpp/llama.cpp/src/llama-vocab.cpp +35 -8
  190. package/cpp/llama.cpp/src/llama-vocab.h +6 -0
  191. package/cpp/llama.cpp/src/llama.cpp +14 -0
  192. package/cpp/rn-completion.cpp +4 -2
  193. package/ios/include/chat.h +73 -6
  194. package/ios/include/common/minja/chat-template.hpp +9 -5
  195. package/ios/include/common/minja/minja.hpp +69 -36
  196. package/ios/include/common.h +21 -17
  197. package/ios/include/llama.h +62 -125
  198. package/ios/libs/llama.xcframework/Info.plist +19 -19
  199. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  200. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4617 -4487
  201. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  202. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +26 -7
  203. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +62 -125
  204. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  205. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  206. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
  207. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3557 -3435
  208. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  209. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  210. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  211. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  212. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  213. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
  214. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3559 -3437
  215. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +237 -0
  216. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +26 -7
  217. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +62 -125
  218. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +237 -0
  219. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +26 -7
  220. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +62 -125
  221. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  222. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +237 -0
  223. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +26 -7
  224. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +62 -125
  225. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  226. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  227. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  228. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4616 -4487
  229. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  230. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +26 -7
  231. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +62 -125
  232. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  233. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  234. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4637 -4508
  235. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3556 -3435
  236. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  237. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  238. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  239. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  240. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  241. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4653 -4523
  242. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  243. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +26 -7
  244. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +62 -125
  245. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  246. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  247. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4674 -4544
  248. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3587 -3465
  249. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  250. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  251. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  252. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  253. package/package.json +1 -1
  254. package/cpp/llama.cpp/common/stb_image.h +0 -7988
  255. package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +0 -112
  256. package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.out +0 -46
  257. package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +0 -112
  258. package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.out +0 -46
  259. package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +0 -112
  260. package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +0 -46
  261. package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.inp +0 -112
  262. package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.out +0 -46
  263. package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +0 -112
  264. package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.out +0 -46
  265. package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +0 -112
  266. package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +0 -46

package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh

@@ -2,9 +2,9 @@
 #include "fattn-common.cuh"
 
 template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#ifndef GGML_USE_HIP
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // GGML_USE_HIP
 static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -48,6 +48,12 @@ static __global__ void flash_attn_vec_ext_f16(
         NO_DEVICE_CODE;
         return;
     }
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    if (ncols > 1) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
@@ -91,6 +97,13 @@ static __global__ void flash_attn_vec_ext_f16(
             kqsum_shared[j][threadIdx.x] = 0.0f;
         }
     }
+
+    __shared__ half maskh_shared[ncols*D];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        maskh_shared[j*D + tid] = 0.0f;
+    }
+
     __syncthreads();
 
     // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
@@ -175,6 +188,36 @@ static __global__ void flash_attn_vec_ext_f16(
     for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
+        if (mask) {
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid];
+            }
+
+            __syncthreads();
+
+            // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
+            // In such cases, skip the KV slice.
+            // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
+#ifndef GGML_USE_HIP
+            bool skip = true;
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                    const int i = i0 + threadIdx.x;
+
+                    const float2 tmp = __half22float2(((const half2 *) maskh_shared)[j*(D/2) + i]);
+                    skip = skip && isinf(tmp.x) && isinf(tmp.y);
+                }
+            }
+            if (__all_sync(0xFFFFFFFF, skip)) {
+                __syncthreads();
+                continue;
+            }
+#endif // GGML_USE_HIP
+        }
+
         // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
         // see https://github.com/ggerganov/llama.cpp/pull/7061 .
         // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
@@ -202,7 +245,7 @@
                 sum = logit_softcap*tanhf(sum);
             }
 
-            sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+            sum += maskh_shared[j*D + i_KQ];
 
             if (ncols == 1) {
                 kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
@@ -335,7 +378,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
-    if (Q->ne[1] == 1) {
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+    if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
        constexpr int cols_per_block = 1;
        if (logit_softcap == 0.0f) {
            constexpr bool use_logit_softcap = false;
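
The hunks above preload the mask tile into shared memory and then let the whole warp skip KV slices that are fully masked out (all -inf) via a warp-wide vote; the same change is applied to the f32 kernel below. A minimal standalone sketch of the vote pattern, with a hypothetical kernel name and a one-warp launch assumed (not code from this diff):

    #include <cuda_runtime.h>
    #include <math.h>

    // Launch with a single 32-thread warp. Each lane scans its strided share of
    // the mask; `skip` stays true only if every element it saw was -infinity.
    // __all_sync(0xFFFFFFFF, skip) is true iff all 32 lanes report true, i.e.
    // the warp unanimously agrees the slice is fully masked.
    __global__ void skip_fully_masked(const float * mask, int n, int * skipped) {
        bool skip = true;
        for (int i = threadIdx.x; i < n; i += warpSize) {
            skip = skip && isinf(mask[i]);
        }
        if (__all_sync(0xFFFFFFFF, skip)) {
            if (threadIdx.x == 0) {
                *skipped = 1; // the real kernels `continue` to the next KV slice instead
            }
            return;
        }
        // ... otherwise the tile would be processed normally ...
    }

The diff compiles this path only when GGML_USE_HIP is not defined, because the 0xFFFFFFFF lane mask assumes a 32-lane warp while AMD wavefronts are 64 lanes wide.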

package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh

@@ -2,9 +2,9 @@
 #include "fattn-common.cuh"
 
 template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#ifndef GGML_USE_HIP
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // GGML_USE_HIP
 static __global__ void flash_attn_vec_ext_f32(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -60,6 +60,12 @@ static __global__ void flash_attn_vec_ext_f32(
         NO_DEVICE_CODE;
         return;
     }
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    if (ncols > 1) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
@@ -104,6 +110,13 @@ static __global__ void flash_attn_vec_ext_f32(
             kqsum_shared[j][threadIdx.x] = 0.0f;
         }
     }
+
+    __shared__ float maskf_shared[ncols*D];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        maskf_shared[j*D + tid] = 0.0f;
+    }
+
     __syncthreads();
 
     // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
@@ -181,6 +194,35 @@ static __global__ void flash_attn_vec_ext_f32(
     for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
+        if (mask) {
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]);
+            }
+
+            __syncthreads();
+
+            // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
+            // In such cases, skip the KV slice.
+            // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
+#ifndef GGML_USE_HIP
+            bool skip = true;
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+#pragma unroll
+                for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+                    const int i = i0 + threadIdx.x;
+
+                    skip = skip && isinf(maskf_shared[j*D + i]);
+                }
+            }
+            if (__all_sync(0xFFFFFFFF, skip)) {
+                __syncthreads();
+                continue;
+            }
+#endif // GGML_USE_HIP
+        }
+
         float kqmax_new_arr[ncols];
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
@@ -204,7 +246,7 @@
                 sum = logit_softcap*tanhf(sum);
             }
 
-            sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
+            sum += maskf_shared[j*D + i_KQ];
 
             kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
 
@@ -326,7 +368,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
-    if (Q->ne[1] == 1) {
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+    if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
        constexpr int cols_per_block = 1;
        if (logit_softcap == 0.0f) {
            constexpr bool use_logit_softcap = false;

package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu

@@ -10,6 +10,7 @@
 
 template <int DKQ, int DV, int ncols2>
 static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const ggml_tensor * Q = dst->src[0];
 
     if constexpr (ncols2 <= 8) {
@@ -24,7 +25,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
         return;
     }
 
-    if (Q->ne[1] <= 32/ncols2) {
+    if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) {
         ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
         return;
     }
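
Both dispatch changes in this file and in the vec kernels above consult the device's compute capability through ggml's helpers before picking a kernel shape. A rough sketch of the same decision using only the plain CUDA runtime API; the 100*major + 10*minor encoding mirrors ggml's GGML_CUDA_CC_* constants and is an assumption here, not code from the package:

    #include <cuda_runtime.h>
    #include <cstdio>

    int main() {
        cudaDeviceProp prop;
        cudaGetDeviceProperties(&prop, /*device=*/0);
        const int cc = 100*prop.major + 10*prop.minor; // e.g. 750 for Turing (sm_75)
        if (cc == 750) {
            printf("Turing: always take the 32/ncols2 tile\n");
        } else {
            printf("cc %d: fall back to the batch-size heuristic\n", cc);
        }
        return 0;
    }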

package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu

@@ -2192,6 +2192,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_UNARY_OP_SILU:
             ggml_cuda_op_silu(ctx, dst);
             break;
+        case GGML_UNARY_OP_GELU_ERF:
+            ggml_cuda_op_gelu_erf(ctx, dst);
+            break;
         case GGML_UNARY_OP_GELU_QUICK:
             ggml_cuda_op_gelu_quick(ctx, dst);
             break;
@@ -2977,6 +2980,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_UNARY_OP_SIGMOID:
         case GGML_UNARY_OP_HARDSIGMOID:
         case GGML_UNARY_OP_HARDSWISH:
+        case GGML_UNARY_OP_GELU_ERF:
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_TANH:
         case GGML_UNARY_OP_EXP:
@@ -3222,7 +3226,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 #endif // FLASH_ATTN_AVAILABLE
             if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
                 const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
-                if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) {
+                if (!new_mma_available(cc)) {
                     return false;
                 }
                 const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];

package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu

@@ -122,6 +122,7 @@ void ggml_cuda_mul_mat_q(
         const int64_t s13 = src1->nb[3] / ts_src1;
         quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
             ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
+        CUDA_CHECK(cudaGetLastError());
     }
 
     const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
@@ -205,6 +206,7 @@
         const int64_t s13 = src1->nb[2] / ts_src1;
         quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
             ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
+        CUDA_CHECK(cudaGetLastError());
     }
 
     const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
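
Both hunks add CUDA_CHECK(cudaGetLastError()) immediately after the quantization kernel launch. Kernel launches are asynchronous and return no status themselves; launch-time failures such as an invalid configuration only surface through cudaGetLastError(). A minimal illustration with a hypothetical kernel (not from this diff):

    #include <cuda_runtime.h>
    #include <cstdio>

    __global__ void noop() {}

    int main() {
        noop<<<1, 4096>>>();                     // 4096 threads per block exceeds the 1024 limit
        cudaError_t err = cudaGetLastError();    // the launch statement itself reported nothing
        printf("%s\n", cudaGetErrorString(err)); // typically "invalid configuration argument"
        return 0;
    }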

package/cpp/llama.cpp/ggml/src/ggml-cuda/quantize.cu

@@ -56,13 +56,13 @@ static __global__ void quantize_mmq_q8_1(
     constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
     constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
 
-    const int64_t i0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
+    const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4;
 
     if (i0 >= ne0) {
         return;
     }
 
-    const int64_t i1 = blockIdx.y;
+    const int64_t i1 = blockIdx.x;
     const int64_t i2 = blockIdx.z % ne2;
     const int64_t i3 = blockIdx.z / ne2;
 
@@ -75,8 +75,8 @@ static __global__ void quantize_mmq_q8_1(
 
     block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
 
-    const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
-    const int64_t ib  = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.y;                    // block index in channel
+    const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel
+    const int64_t ib  = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x;                    // block index in channel
     const int64_t iqs = i0 % (4*QK8_1);                                             // quant index in block
 
     // Load 4 floats per thread and calculate max. abs. value between them:
@@ -166,8 +166,9 @@
     GGML_ASSERT(ne00 % 4 == 0);
     GGML_ASSERT(ne0 % (4*QK8_1) == 0);
 
-    const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
-    const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
+    // ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid:
+    const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
+    const dim3 num_blocks(ne1, block_num_y, ne2*ne3);
     const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
     switch (mmq_get_q8_1_ds_layout(type_src0)) {
         case MMQ_Q8_1_DS_LAYOUT_D4:
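
The grid-axis swap above matters because CUDA grids are asymmetric: gridDim.x can hold up to 2^31-1 blocks, while gridDim.y and gridDim.z are capped at 65535. Since ne1 grows with the token count, keeping it on the y axis risked hitting that 65535 limit. A quick probe of the limits:

    #include <cuda_runtime.h>
    #include <cstdio>

    int main() {
        cudaDeviceProp prop;
        cudaGetDeviceProperties(&prop, 0);
        // On current GPUs this prints: max grid: x=2147483647 y=65535 z=65535
        printf("max grid: x=%d y=%d z=%d\n",
               prop.maxGridSize[0], prop.maxGridSize[1], prop.maxGridSize[2]);
        return 0;
    }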

package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu

@@ -31,7 +31,7 @@ void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguously_allocated(src0));
 
     const float * src0_d = (const float *) src0->data;
     float * dst_d = (float *) dst->data;

package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu

@@ -23,6 +23,12 @@ static __device__ __forceinline__ float op_gelu(float x) {
     return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
 }
 
+static __device__ __forceinline__ float op_gelu_erf(float x) {
+    const float SQRT_2_INV = 0.70710678118654752440084436210484f;
+
+    return 0.5f*x*(1.0f + erff(x*SQRT_2_INV));
+}
+
 static __device__ __forceinline__ float op_gelu_quick(float x) {
     const float GELU_QUICK_COEF = -1.702f;
 
@@ -134,6 +140,10 @@ void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary<op_gelu>(ctx, dst);
 }
 
+void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_gelu_erf>(ctx, dst);
+}
+
 void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary<op_gelu_quick>(ctx, dst);
 }
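
The new op_gelu_erf computes the exact GELU, 0.5*x*(1 + erf(x/sqrt(2))), while the existing op_gelu above it keeps the usual tanh approximation. A host-side sketch comparing the two; the tanh constants are the standard published values, assumed here rather than copied from this file:

    #include <cmath>
    #include <cstdio>

    static float gelu_erf(float x) {
        return 0.5f*x*(1.0f + erff(x*0.70710678f)); // 0.70710678 = 1/sqrt(2)
    }

    static float gelu_tanh(float x) {
        const float SQRT_2_OVER_PI = 0.79788456f; // sqrt(2/pi)
        const float GELU_COEF_A    = 0.044715f;   // standard approximation coefficient
        return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
    }

    int main() {
        const float xs[] = {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f};
        for (float x : xs) {
            printf("x=% .1f  exact=% .6f  tanh approx=% .6f\n", x, gelu_erf(x), gelu_tanh(x));
        }
        return 0;
    }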

package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh

@@ -30,6 +30,8 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

package/cpp/llama.cpp/ggml/src/ggml-impl.h

@@ -386,7 +386,7 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size);
     return r;
 }
 
-#elif defined(__riscv) && defined(GGML_RV_ZFH)
+#elif defined(__riscv) && defined(__riscv_zfhmin)
 
 static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
     float f;
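
The replaced guard relies on __riscv_zfhmin, a feature-test macro the compiler defines by itself whenever the target ISA string includes the Zfhmin half-precision extension, so the hardware fp16-to-fp32 path no longer needs the custom GGML_RV_ZFH build flag. A compile-time probe (the -march string is an example, not taken from the package's build files):

    // Compile with e.g. -march=rv64gc_zfhmin to take the hardware path.
    #if defined(__riscv) && defined(__riscv_zfhmin)
    #pragma message("Zfhmin detected: hardware fp16 -> fp32 conversion available")
    #else
    #pragma message("no Zfhmin: software fp16 -> fp32 conversion fallback")
    #endif

    int main() { return 0; }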

package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h

@@ -207,6 +207,10 @@ typedef struct {
     float attn_factor;
     float beta_fast;
     float beta_slow;
+    int32_t sect_0;
+    int32_t sect_1;
+    int32_t sect_2;
+    int32_t sect_3;
 } ggml_metal_kargs_rope;
 
 typedef struct {

package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m

@@ -149,6 +149,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SIGMOID,
     GGML_METAL_KERNEL_TYPE_GELU,
     GGML_METAL_KERNEL_TYPE_GELU_4,
+    GGML_METAL_KERNEL_TYPE_GELU_ERF,
+    GGML_METAL_KERNEL_TYPE_GELU_ERF_4,
     GGML_METAL_KERNEL_TYPE_GELU_QUICK,
     GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
     GGML_METAL_KERNEL_TYPE_SILU,
@@ -332,6 +334,10 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
     GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,
     GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
@@ -411,6 +417,13 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
@@ -1092,6 +1105,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF, gelu_erf, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF_4, gelu_erf_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
@@ -1275,6 +1290,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
@@ -1354,6 +1373,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, flash_attn_ext_vec_q4_1_h64, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, flash_attn_ext_vec_q5_0_h64, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, flash_attn_ext_vec_q5_1_h64, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction);
@@ -1591,6 +1617,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             case GGML_UNARY_OP_RELU:
             case GGML_UNARY_OP_SIGMOID:
             case GGML_UNARY_OP_GELU:
+            case GGML_UNARY_OP_GELU_ERF:
             case GGML_UNARY_OP_GELU_QUICK:
             case GGML_UNARY_OP_SILU:
             case GGML_UNARY_OP_ELU:
@@ -1637,16 +1664,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
         case GGML_OP_ROPE:
-            {
-                const int mode = ((const int32_t *) op->op_params)[2];
-                if (mode & GGML_ROPE_TYPE_MROPE) {
-                    return false;
-                }
-                if (mode & GGML_ROPE_TYPE_VISION) {
-                    return false;
-                }
-                return true;
-            }
+            return true;
         case GGML_OP_IM2COL:
             return op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_POOL_1D:
@@ -2238,6 +2256,25 @@ static bool ggml_metal_encode_node(
 
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
+        case GGML_UNARY_OP_GELU_ERF:
+            {
+                int64_t n = ggml_nelements(dst);
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                if (n % 4 == 0) {
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF_4].pipeline;
+                    n /= 4;
+                } else {
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF].pipeline;
+                }
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
         case GGML_UNARY_OP_GELU_QUICK:
             {
                 int64_t n = ggml_nelements(dst);
@@ -3826,6 +3863,7 @@ static bool ggml_metal_encode_node(
             } break;
         case GGML_OP_ROPE:
             {
+
                 // make sure we have one or more position id(ne10) per token(ne02)
                 GGML_ASSERT(ne10 % ne02 == 0);
                 GGML_ASSERT(ne10 >= ne02);
@@ -3852,20 +3890,42 @@ static bool ggml_metal_encode_node(
                 memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
                 memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));
 
-                const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+                const bool is_neox   = mode & GGML_ROPE_TYPE_NEOX;
+                const bool is_mrope  = mode & GGML_ROPE_TYPE_MROPE;
+                const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+                // mrope
+                const int sect_0 = ((const int32_t *) dst->op_params)[11];
+                const int sect_1 = ((const int32_t *) dst->op_params)[12];
+                const int sect_2 = ((const int32_t *) dst->op_params)[13];
+                const int sect_3 = ((const int32_t *) dst->op_params)[14];
 
                 id<MTLComputePipelineState> pipeline = nil;
 
-                if (!is_neox) {
+                if (is_neox) {
                     switch (src0->type) {
-                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
-                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    };
+                } else if (is_mrope && !is_vision) {
+                    GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
+                    switch (src0->type) {
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    };
+                } else if (is_vision) {
+                    GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
+                    switch (src0->type) {
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break;
                         default: GGML_ABORT("fatal error");
                     };
                 } else {
                     switch (src0->type) {
-                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
-                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
                        default: GGML_ABORT("fatal error");
                    };
                }
@@ -3896,6 +3956,10 @@ static bool ggml_metal_encode_node(
                     /*.attn_factor =*/ attn_factor,
                     /*.beta_fast   =*/ beta_fast,
                     /*.beta_slow   =*/ beta_slow,
+                    /* sect_0      =*/ sect_0,
+                    /* sect_1      =*/ sect_1,
+                    /* sect_2      =*/ sect_2,
+                    /* sect_3      =*/ sect_3,
                 };
 
                 [encoder setComputePipelineState:pipeline];
@@ -4332,7 +4396,7 @@ static bool ggml_metal_encode_node(
             // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
            //       for now avoiding mainly to keep the number of templates/kernels a bit lower
            //       these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
-            if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
+            if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
                switch (src1->type) {
                    case GGML_TYPE_F16:
                    {
@@ -4513,6 +4577,24 @@ static bool ggml_metal_encode_node(
                use_vec_kernel = true;

                switch (ne00) {
+                    case 64:
+                    {
+                        switch (src1->type) {
+                            case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline;  break;
+                            case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline; break;
+                            case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline; break;
+                            case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline; break;
+                            case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline; break;
+                            case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline; break;
+                            case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break;
+                            default:
+                            {
+                                GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                GGML_LOG_ERROR("add template specialization for this type\n");
+                                GGML_ABORT("add template specialization for this type");
+                            }
+                        }
+                    } break;
                    case 96:
                    {
                        switch (src1->type) {