@novastera-oss/llamarn 0.2.1 → 0.2.2

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (266)
  1. package/README.md +80 -14
  2. package/RNLlamaCpp.podspec +10 -3
  3. package/android/CMakeLists.txt +8 -0
  4. package/android/src/main/cpp/include/llama.h +62 -125
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  13. package/cpp/build-info.cpp +2 -2
  14. package/cpp/llama.cpp/README.md +11 -3
  15. package/cpp/llama.cpp/build-xcframework.sh +1 -0
  16. package/cpp/llama.cpp/common/CMakeLists.txt +8 -2
  17. package/cpp/llama.cpp/common/arg.cpp +153 -113
  18. package/cpp/llama.cpp/common/chat-parser.cpp +379 -0
  19. package/cpp/llama.cpp/common/chat-parser.h +117 -0
  20. package/cpp/llama.cpp/common/chat.cpp +847 -699
  21. package/cpp/llama.cpp/common/chat.h +73 -6
  22. package/cpp/llama.cpp/common/common.cpp +50 -82
  23. package/cpp/llama.cpp/common/common.h +21 -17
  24. package/cpp/llama.cpp/common/json-partial.cpp +255 -0
  25. package/cpp/llama.cpp/common/json-partial.h +37 -0
  26. package/cpp/llama.cpp/common/minja/chat-template.hpp +9 -5
  27. package/cpp/llama.cpp/common/minja/minja.hpp +69 -36
  28. package/cpp/llama.cpp/common/regex-partial.cpp +204 -0
  29. package/cpp/llama.cpp/common/regex-partial.h +56 -0
  30. package/cpp/llama.cpp/common/sampling.cpp +7 -8
  31. package/cpp/llama.cpp/convert_hf_to_gguf.py +453 -118
  32. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +120 -68
  33. package/cpp/llama.cpp/ggml/CMakeLists.txt +2 -1
  34. package/cpp/llama.cpp/ggml/cmake/common.cmake +25 -0
  35. package/cpp/llama.cpp/ggml/include/ggml-opt.h +49 -28
  36. package/cpp/llama.cpp/ggml/include/ggml.h +26 -7
  37. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +16 -10
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +4 -1
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +2 -0
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +604 -0
  42. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +42 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +54 -2
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +50 -51
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -2
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -9
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +779 -19
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +22 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +322 -100
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +117 -1
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +220 -49
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/acc.cu +40 -26
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +1 -1
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +11 -1
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +15 -7
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +266 -64
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +49 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +48 -4
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +2 -1
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +5 -1
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +2 -0
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/quantize.cu +7 -6
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +1 -1
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +10 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-impl.h +1 -1
  71. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +99 -17
  73. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +200 -2
  74. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +8 -2
  75. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +6 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +972 -178
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +373 -190
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -10
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +101 -5
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +31 -33
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +1 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +29 -2
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +4 -5
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  94. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +9 -1
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +84 -72
  96. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  98. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -3
  99. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +324 -129
  100. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +1 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +31 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +95 -68
  103. package/cpp/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +1 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +22 -0
  105. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -4
  107. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +2 -3
  108. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +69 -43
  109. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +2 -14
  110. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -91
  111. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -181
  112. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +17 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  114. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +6 -152
  115. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  116. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  117. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +2 -118
  118. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +1 -1
  119. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +12 -1
  120. package/cpp/llama.cpp/ggml/src/ggml.c +107 -36
  121. package/cpp/llama.cpp/ggml/src/gguf.cpp +33 -33
  122. package/cpp/llama.cpp/gguf-py/gguf/constants.py +100 -15
  123. package/cpp/llama.cpp/gguf-py/gguf/gguf_reader.py +1 -1
  124. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +44 -12
  125. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_editor_gui.py +21 -10
  126. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +5 -2
  127. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +128 -31
  128. package/cpp/llama.cpp/gguf-py/gguf/utility.py +1 -1
  129. package/cpp/llama.cpp/gguf-py/pyproject.toml +1 -1
  130. package/cpp/llama.cpp/include/llama.h +62 -125
  131. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +1 -1
  132. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +1 -1
  133. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.inp +1 -1
  134. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.out +1 -1
  135. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +1 -1
  136. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +1 -1
  137. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +1 -1
  138. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +1 -1
  139. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.inp +1 -1
  140. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.out +1 -1
  141. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +1 -1
  142. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +1 -1
  143. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +1 -1
  144. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +1 -1
  145. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +1 -1
  146. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +1 -1
  147. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.inp +1 -1
  148. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.out +1 -1
  149. package/cpp/llama.cpp/models/ggml-vocab-nomic-bert-moe.gguf +0 -0
  150. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +1 -1
  151. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.out +1 -1
  152. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +1 -1
  153. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.out +1 -1
  154. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.inp +1 -1
  155. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.out +1 -1
  156. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +1 -1
  157. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.out +1 -1
  158. package/cpp/llama.cpp/models/templates/Qwen-QwQ-32B.jinja +62 -0
  159. package/cpp/llama.cpp/models/templates/Qwen-Qwen3-0.6B.jinja +85 -0
  160. package/cpp/llama.cpp/models/templates/README.md +2 -0
  161. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +5 -1
  162. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +5 -1
  163. package/cpp/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt +2 -0
  164. package/cpp/llama.cpp/requirements/requirements-gguf_editor_gui.txt +1 -1
  165. package/cpp/llama.cpp/src/CMakeLists.txt +2 -0
  166. package/cpp/llama.cpp/src/llama-arch.cpp +6 -0
  167. package/cpp/llama.cpp/src/llama-arch.h +2 -0
  168. package/cpp/llama.cpp/src/llama-batch.cpp +3 -1
  169. package/cpp/llama.cpp/src/llama-context.cpp +340 -123
  170. package/cpp/llama.cpp/src/llama-context.h +30 -0
  171. package/cpp/llama.cpp/src/llama-cparams.cpp +4 -0
  172. package/cpp/llama.cpp/src/llama-cparams.h +2 -0
  173. package/cpp/llama.cpp/src/llama-grammar.cpp +12 -2
  174. package/cpp/llama.cpp/src/llama-graph.cpp +157 -247
  175. package/cpp/llama.cpp/src/llama-graph.h +52 -7
  176. package/cpp/llama.cpp/src/llama-hparams.cpp +17 -1
  177. package/cpp/llama.cpp/src/llama-hparams.h +37 -5
  178. package/cpp/llama.cpp/src/llama-kv-cache.cpp +742 -481
  179. package/cpp/llama.cpp/src/llama-kv-cache.h +196 -99
  180. package/cpp/llama.cpp/src/llama-kv-cells.h +379 -0
  181. package/cpp/llama.cpp/src/llama-memory.h +4 -3
  182. package/cpp/llama.cpp/src/llama-model-loader.cpp +22 -17
  183. package/cpp/llama.cpp/src/llama-model-saver.cpp +281 -0
  184. package/cpp/llama.cpp/src/llama-model-saver.h +37 -0
  185. package/cpp/llama.cpp/src/llama-model.cpp +529 -172
  186. package/cpp/llama.cpp/src/llama-model.h +6 -1
  187. package/cpp/llama.cpp/src/llama-quant.cpp +15 -13
  188. package/cpp/llama.cpp/src/llama-sampling.cpp +2 -2
  189. package/cpp/llama.cpp/src/llama-vocab.cpp +35 -8
  190. package/cpp/llama.cpp/src/llama-vocab.h +6 -0
  191. package/cpp/llama.cpp/src/llama.cpp +14 -0
  192. package/cpp/rn-completion.cpp +4 -2
  193. package/ios/include/chat.h +73 -6
  194. package/ios/include/common/minja/chat-template.hpp +9 -5
  195. package/ios/include/common/minja/minja.hpp +69 -36
  196. package/ios/include/common.h +21 -17
  197. package/ios/include/llama.h +62 -125
  198. package/ios/libs/llama.xcframework/Info.plist +19 -19
  199. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  200. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4617 -4487
  201. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  202. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +26 -7
  203. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +62 -125
  204. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  205. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  206. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
  207. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3557 -3435
  208. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  209. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  210. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  211. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  212. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  213. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
  214. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3559 -3437
  215. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +237 -0
  216. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +26 -7
  217. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +62 -125
  218. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +237 -0
  219. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +26 -7
  220. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +62 -125
  221. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  222. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +237 -0
  223. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +26 -7
  224. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +62 -125
  225. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  226. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  227. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  228. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4616 -4487
  229. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  230. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +26 -7
  231. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +62 -125
  232. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  233. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  234. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4637 -4508
  235. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3556 -3435
  236. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  237. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  238. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  239. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  240. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  241. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4653 -4523
  242. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  243. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +26 -7
  244. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +62 -125
  245. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  246. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  247. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4674 -4544
  248. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3587 -3465
  249. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  250. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  251. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  252. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  253. package/package.json +1 -1
  254. package/cpp/llama.cpp/common/stb_image.h +0 -7988
  255. package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +0 -112
  256. package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.out +0 -46
  257. package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +0 -112
  258. package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.out +0 -46
  259. package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +0 -112
  260. package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +0 -46
  261. package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.inp +0 -112
  262. package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.out +0 -46
  263. package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +0 -112
  264. package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.out +0 -46
  265. package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +0 -112
  266. package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +0 -46
package/cpp/llama.cpp/src/llama-context.h
@@ -7,6 +7,7 @@
 #include "llama-adapter.h"
 
 #include "ggml-cpp.h"
+#include "ggml-opt.h"
 
 #include <map>
 #include <vector>
@@ -133,6 +134,32 @@ struct llama_context {
     llama_perf_context_data perf_get_data() const;
     void perf_reset();
 
+    //
+    // training
+    //
+
+    void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
+
+    void opt_epoch(
+            ggml_opt_dataset_t      dataset,
+            ggml_opt_result_t       result_train,
+            ggml_opt_result_t       result_eval,
+            int64_t                 idata_split,
+            ggml_opt_epoch_callback callback_train,
+            ggml_opt_epoch_callback callback_eval);
+
+    void opt_epoch_iter(
+            ggml_opt_dataset_t               dataset,
+            ggml_opt_result_t                result,
+            const std::vector<llama_token> & tokens,
+            const std::vector<llama_token> & labels_sparse,
+            llama_batch                    & batch,
+            ggml_opt_epoch_callback          callback,
+            bool                             train,
+            int64_t                          idata_in_loop,
+            int64_t                          ndata_in_loop,
+            int64_t                          t_loop_start);
+
 private:
     //
     // output
@@ -212,6 +239,9 @@ private:
 
     ggml_context_ptr ctx_compute;
 
+    // training
+    ggml_opt_context_t opt_ctx = nullptr;
+
     ggml_threadpool_t threadpool       = nullptr;
     ggml_threadpool_t threadpool_batch = nullptr;
 
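The three hunks above wire initial training support into the internal llama_context. Below is a minimal sketch of how the new entry points might compose; the driver function, the default-constructed llama_opt_params, and the null callbacks are illustrative assumptions, not code from this package:

// Hypothetical driver (assumes the internal llama-context.h and the newly
// shipped ggml-opt.h are in scope): opt_init() once, then one opt_epoch()
// per pass, training on dataset indices [0, idata_split) and evaluating the rest.
#include <cstdint>

void run_training(llama_context & lctx, llama_model * model,
                  ggml_opt_dataset_t dataset, int64_t idata_split, int n_epochs) {
    lctx.opt_init(model, llama_opt_params{}); // assumed default-constructible

    for (int e = 0; e < n_epochs; e++) {
        // result accumulators from the ggml-opt API
        ggml_opt_result_t result_train = ggml_opt_result_init();
        ggml_opt_result_t result_eval  = ggml_opt_result_init();

        lctx.opt_epoch(dataset, result_train, result_eval, idata_split,
                       /*callback_train=*/nullptr,   // null callbacks assumed accepted
                       /*callback_eval=*/nullptr);

        ggml_opt_result_free(result_train);
        ggml_opt_result_free(result_eval);
    }
}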
package/cpp/llama.cpp/src/llama-cparams.cpp
@@ -1 +1,5 @@
 #include "llama-cparams.h"
+
+size_t llama_max_parallel_sequences(void) {
+    return LLAMA_MAX_PARALLEL_SEQUENCES;
+}
package/cpp/llama.cpp/src/llama-cparams.h
@@ -4,6 +4,8 @@
 
 #include <cstdint>
 
+#define LLAMA_MAX_PARALLEL_SEQUENCES 64
+
 struct llama_cparams {
     uint32_t n_ctx;    // context size used during inference
     uint32_t n_batch;
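On the caller side, the new accessor lets code clamp a requested sequence count before building a context. A hypothetical helper (the declaration is repeated so the sketch stands alone; llama-cparams.h is assumed to declare the function its .cpp defines above):

#include <algorithm>
#include <cstddef>

size_t llama_max_parallel_sequences(void); // assumed declared in llama-cparams.h

// clamp a user-supplied n_seq_max to [1, LLAMA_MAX_PARALLEL_SEQUENCES]
size_t clamp_n_seq_max(size_t requested) {
    return std::max<size_t>(1, std::min(requested, llama_max_parallel_sequences()));
}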
package/cpp/llama.cpp/src/llama-grammar.cpp
@@ -1177,8 +1177,18 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
         for (const auto & trigger_pattern : grammar.trigger_patterns) {
             if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
                 grammar.awaiting_trigger = false;
-                // get from the first match to the end of the string
-                auto constrained_str = grammar.trigger_buffer.substr(match.position(1));
+                // get from the first matched capturing group to the end of the string
+                size_t start = std::string::npos;
+                for (auto i = 1u; i < match.size(); i++) {
+                    if (match.length(i) > 0) {
+                        start = match.position(i);
+                        break;
+                    }
+                }
+                if (start == std::string::npos) {
+                    start = match.position(0);
+                }
+                auto constrained_str = grammar.trigger_buffer.substr(start);
                 // std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
                 grammar.trigger_buffer.clear();
                 llama_grammar_accept_str(grammar, constrained_str);
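The fix matters for trigger patterns built from alternations: when a later alternative fires, earlier capture groups do not participate, so match.position(1) no longer points at the trigger. A standalone illustration with a hypothetical pattern and buffer (std::regex only, mirroring the scan above):

#include <iostream>
#include <regex>
#include <string>

int main() {
    // two alternative trigger tokens; only one capture group participates per match
    const std::regex trigger(R"([\s\S]*?(?:(<tool_call>)|(```json))[\s\S]*)");
    const std::string buffer = "model chatter before the trigger ```json {\"a\": 1}";

    std::smatch match;
    if (std::regex_match(buffer, match, trigger)) {
        size_t start = std::string::npos;
        for (auto i = 1u; i < match.size(); i++) {
            if (match.length(i) > 0) { // skip groups that did not participate
                start = match.position(i);
                break;
            }
        }
        if (start == std::string::npos) {
            start = match.position(0);
        }
        // prints from the matched trigger onward: ```json {"a": 1}
        std::cout << buffer.substr(start) << "\n";
    }
}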
package/cpp/llama.cpp/src/llama-graph.cpp
@@ -9,33 +9,6 @@
 #include <cmath>
 #include <cstring>
 
-static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
-    // TODO move to hparams if a T5 variant appears that uses a different value
-    const int64_t max_distance = 128;
-
-    if (bidirectional) {
-        n_buckets >>= 1;
-    }
-
-    const int64_t max_exact = n_buckets >> 1;
-
-    int32_t relative_position = x - y;
-    int32_t relative_bucket = 0;
-
-    if (bidirectional) {
-        relative_bucket += (relative_position > 0) * n_buckets;
-        relative_position = abs(relative_position);
-    } else {
-        relative_position = -std::min<int32_t>(relative_position, 0);
-    }
-
-    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
-    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
-
-    return relative_bucket;
-}
-
 void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     if (ubatch->token) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -110,22 +83,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
     if (pos_bucket) {
-        const int64_t n_tokens = ubatch->n_tokens;
-
-        GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
-        GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
-
-        int32_t * data = (int32_t *) pos_bucket->data;
-
-        const int64_t n_kv = kv_self->n;
-
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_kv; ++i) {
-                    data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
-                }
-            }
-        }
+        kv_self->set_input_pos_bucket(pos_bucket, ubatch);
     }
 }
 
@@ -403,99 +361,18 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask || self_kq_mask_swa) {
-        const int64_t n_kv         = kv_self->n;
-        const int64_t n_tokens     = ubatch->n_tokens;
-        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-        const int64_t n_seqs       = ubatch->n_seqs;
-
-        float * data     = nullptr;
-        float * data_swa = nullptr;
-
-        if (self_kq_mask) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-            data = (float *) self_kq_mask->data;
-        }
-
-        if (self_kq_mask_swa) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
-            data_swa = (float *) self_kq_mask_swa->data;
-        }
-
-        // Use only the previous KV cells of the correct sequence for each token of the ubatch.
-        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-        // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
-        //   Causal mask:
-        //      xxx-------
-        //      xxxx------
-        //      xxxxx-----
-        //   Non-causal mask:
-        //      xxxxx-----
-        //      xxxxx-----
-        //      xxxxx-----
-        // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-        for (int h = 0; h < 1; ++h) {
-            for (int s = 0; s < n_seqs; ++s) {
-                const llama_seq_id seq_id = ubatch->seq_id[s][0];
-
-                for (int j = 0; j < n_seq_tokens; ++j) {
-                    const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
-                    for (int i = 0; i < n_kv; ++i) {
-                        float f;
-                        // mask the token if:
-                        if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
-                            || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
-                        ) {
-                            f = -INFINITY;
-                        } else {
-                            if (hparams.use_alibi) {
-                                f = -std::abs(kv_self->cells[i].pos - pos);
-                            } else {
-                                f = 0.0f;
-                            }
-                        }
-
-                        if (data) {
-                            data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                        }
-
-                        // may need to cut off old tokens for sliding window
-                        // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
-                        if (data_swa) {
-                            if (hparams.n_attn_chunk) {
-                                llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
-                                if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
-                                    f = -INFINITY;
-                                }
-                            } else {
-                                if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
-                                    f = -INFINITY;
-                                }
-                            }
-                            data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                        }
-                    }
-                }
-            }
+    if (self_kq_mask) {
+        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
+}
 
-            // mask padded tokens
-            if (data) {
-                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                    }
-                }
-            }
+void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+    if (self_kq_mask) {
+        kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
 
-            // mask padded tokens
-            if (data_swa) {
-                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                    }
-                }
-            }
-        }
+    if (self_kq_mask_swa) {
+        kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
     }
 }
 
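The per-cell masking rule itself is unchanged by the move into llama_kv_cache_unified::set_input_kq_mask. Restated standalone as a sketch (a KV cell simplified here to a position plus a sequence-membership flag), grounded in the removed loop above:

#include <cstdlib>
#include <limits>

// additive attention bias for one (KV cell, batch token) pair
float kq_mask_value(bool cell_in_seq, int cell_pos, int tok_pos,
                    bool causal_attn, bool use_alibi) {
    // wrong sequence, or a future token under causal attention: fully masked
    if (!cell_in_seq || (causal_attn && cell_pos > tok_pos)) {
        return -std::numeric_limits<float>::infinity();
    }
    // ALiBi biases by distance; otherwise the pair attends without penalty
    return use_alibi ? -static_cast<float>(std::abs(cell_pos - tok_pos)) : 0.0f;
}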
@@ -545,7 +422,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     n_layer          (hparams.n_layer),
     n_rot            (hparams.n_rot),
     n_ctx            (cparams.n_ctx),
-    n_ctx_per_seq    (cparams.n_ctx / cparams.n_seq_max),
     n_head           (hparams.n_head()),
     n_head_kv        (hparams.n_head_kv()),
     n_embd_head_k    (hparams.n_embd_head_k),
@@ -579,7 +455,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
 }
 
 int64_t llm_graph_context::n_pos_per_embd() const {
-    return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
+    return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
 }
 
 void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
@@ -971,6 +847,7 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
         inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
         //cb(inp->tokens, "inp_tokens", -1);
         ggml_set_input(inp->tokens);
+        res->t_tokens = inp->tokens;
 
         cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
 
@@ -1152,7 +1029,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
 
-    const auto n_kv = kv_self->n;
+    const auto n_kv = kv_self->get_n();
 
     auto & cur = inp->pos_bucket;
 
@@ -1187,16 +1064,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_tensor * kq_b,
         ggml_tensor * kq_mask,
         ggml_tensor * v_mla,
-        bool          v_trans,
         float         kq_scale) const {
-    //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+    const bool v_trans = v->nb[1] > v->nb[2];
 
-    //const int64_t n_head    = hparams.n_head(il);
-    //const int64_t n_head_kv = hparams.n_head_kv(il);
-
-    //const auto & n_embd_head_k = hparams.n_embd_head_k;
-    //const auto & n_embd_head_v = hparams.n_embd_head_v;
+    q = ggml_permute(ctx0, q, 0, 2, 1, 3);
+    k = ggml_permute(ctx0, k, 0, 2, 1, 3);
+    v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
     const auto n_tokens = q->ne[1];
     const auto n_head   = q->ne[2];
@@ -1335,17 +1208,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask();
 
-    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
-    //cb(k, "k", il);
-
-    ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-    //cb(k, "v", il);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = k_cur;
+    ggml_tensor * v = v_cur;
 
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1368,22 +1235,16 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
-    const auto n_kv = kv_self->n;
-
-    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
-    ggml_set_input(inp->self_kq_mask);
-
-    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    {
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-    if (hparams.n_swa_pattern > 1) {
-        GGML_ASSERT(hparams.n_swa > 0);
+        const auto n_kv = kv_self->get_n();
 
-        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
-        ggml_set_input(inp->self_kq_mask_swa);
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
 
-        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
     return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
@@ -1408,81 +1269,104 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, v_cur);
 
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
-    const auto & n_ctx = cparams.n_ctx;
 
-    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
 
-    const auto n_tokens = q_cur->ne[2];
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv_self->get_k(ctx0, il);
+    ggml_tensor * v = kv_self->get_v(ctx0, il);
 
-    const bool v_trans = !cparams.flash_attn;
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
 
-    // store to KV cache
-    {
-        const auto kv_head = kv_self->head;
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
+    }
 
-        GGML_ASSERT(kv_self->size == n_ctx);
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
 
-        ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
-        //cb(k_cache_view, "k_cache_view", il);
+    return cur;
+}
 
-        // note: storing RoPE-ed version of K in the KV cache
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
 
-        v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
 
-        ggml_tensor * v_cache_view = nullptr;
+    {
+        const auto n_kv = kv_self->get_kv_base()->get_n();
 
-        if (!v_trans) {
-            v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
-        } else {
-            // note: the V cache is transposed when not using flash attention
-            v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
-                    (  n_ctx)*ggml_element_size(kv_self->v_l[il]),
-                    (kv_head)*ggml_element_size(kv_self->v_l[il]));
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
 
-            v_cur = ggml_transpose(ctx0, v_cur);
-        }
-        //cb(v_cache_view, "v_cache_view", il);
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
+
+        const auto n_kv = kv_self->get_kv_swa()->get_n();
+
+        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(inp->self_kq_mask_swa);
 
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
     }
 
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_unified_iswa * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
     const bool is_swa = hparams.is_swa(il);
 
+    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
+
+    const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
+
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
+    }
+
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
-    const auto n_kv = kv_self->n;
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv->get_k(ctx0, il);
+    ggml_tensor * v = kv->get_v(ctx0, il);
 
-    const int64_t n_head_kv = hparams.n_head_kv(il);
-
-    const auto & n_embd_head_k = hparams.n_embd_head_k;
-    const auto & n_embd_head_v = hparams.n_embd_head_v;
-
-    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    ggml_tensor * k =
-        ggml_view_3d(ctx0, kv_self->k_l[il],
-                n_embd_head_k, n_kv, n_head_kv,
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
-                0);
-    //cb(k, "k", il);
-
-    ggml_tensor * v = !v_trans ?
-        ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_embd_head_v, n_kv, n_head_kv,
-                ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
-                0) :
-        ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_kv, n_embd_head_v, n_head_kv,
-                ggml_element_size(kv_self->v_l[il])*n_ctx,
-                ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
-                0);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1533,17 +1417,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask_cross();
 
-    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
-    //cb(k, "k", il);
-
-    ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-    //cb(k, "v", il);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = k_cur;
+    ggml_tensor * v = v_cur;
 
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1684,20 +1562,25 @@ void llm_graph_context::build_pooling(
                 ggml_tensor * inp_cls = build_inp_cls();
                 inp = ggml_get_rows(ctx0, inp, inp_cls);
 
-                // classification head
-                // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
-                GGML_ASSERT(cls   != nullptr);
-                GGML_ASSERT(cls_b != nullptr);
-
-                cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
-                cur = ggml_tanh(ctx0, cur);
-
-                // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
-                // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
-                if (cls_out) {
+                if (cls != nullptr && cls_b != nullptr) {
+                    // classification head
+                    // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
+                    cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
+                    cur = ggml_tanh(ctx0, cur);
+
+                    // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+                    // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
+                    if (cls_out) {
+                        GGML_ASSERT(cls_out_b != nullptr);
+                        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
+                    }
+                } else if (cls_out) {
+                    // Single layer classification head (direct projection)
+                    // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
                     GGML_ASSERT(cls_out_b != nullptr);
-
-                    cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
+                    cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
+                } else {
+                    GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
                 }
             } break;
         default:
@@ -1711,3 +1594,30 @@
 
     ggml_build_forward_expand(gf, cur);
 }
+
+int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
+    // TODO move to hparams if a T5 variant appears that uses a different value
+    const int64_t max_distance = 128;
+
+    if (bidirectional) {
+        n_buckets >>= 1;
+    }
+
+    const int64_t max_exact = n_buckets >> 1;
+
+    int32_t relative_position = x - y;
+    int32_t relative_bucket = 0;
+
+    if (bidirectional) {
+        relative_bucket += (relative_position > 0) * n_buckets;
+        relative_position = abs(relative_position);
+    } else {
+        relative_position = -std::min<int32_t>(relative_position, 0);
+    }
+
+    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
+    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+
+    return relative_bucket;
+}
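For reference, the relocated helper (now shared with the KV cache code instead of being file-local) behaves as follows when driven directly. This standalone copy reuses the hunk's logic verbatim, renamed to avoid clashing with the real symbol; it buckets a few causal key/query distances, showing that small distances each get their own bucket while larger ones share logarithmically spaced buckets up to max_distance:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

// verbatim logic from the hunk above, renamed for the demo
static int32_t bucket_demo(int32_t x, int32_t y, uint64_t n_buckets, bool bidirectional) {
    const int64_t max_distance = 128;
    if (bidirectional) {
        n_buckets >>= 1;
    }
    const int64_t max_exact = n_buckets >> 1;

    int32_t relative_position = x - y;
    int32_t relative_bucket = 0;
    if (bidirectional) {
        relative_bucket += (relative_position > 0) * n_buckets;
        relative_position = abs(relative_position);
    } else {
        relative_position = -std::min<int32_t>(relative_position, 0);
    }

    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
    return relative_bucket;
}

int main() {
    // causal case: key at position 0, queries at increasing distances;
    // distances below max_exact (16 here) map to their own buckets
    for (int d : {1, 2, 8, 15, 16, 32, 64, 100}) {
        printf("distance %3d -> bucket %d\n", d, bucket_demo(/*x=*/0, /*y=*/d, /*n_buckets=*/32, /*bidirectional=*/false));
    }
}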