@novastera-oss/llamarn 0.2.1 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (266)
  1. package/README.md +80 -14
  2. package/RNLlamaCpp.podspec +10 -3
  3. package/android/CMakeLists.txt +8 -0
  4. package/android/src/main/cpp/include/llama.h +62 -125
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  13. package/cpp/build-info.cpp +2 -2
  14. package/cpp/llama.cpp/README.md +11 -3
  15. package/cpp/llama.cpp/build-xcframework.sh +1 -0
  16. package/cpp/llama.cpp/common/CMakeLists.txt +8 -2
  17. package/cpp/llama.cpp/common/arg.cpp +153 -113
  18. package/cpp/llama.cpp/common/chat-parser.cpp +379 -0
  19. package/cpp/llama.cpp/common/chat-parser.h +117 -0
  20. package/cpp/llama.cpp/common/chat.cpp +847 -699
  21. package/cpp/llama.cpp/common/chat.h +73 -6
  22. package/cpp/llama.cpp/common/common.cpp +50 -82
  23. package/cpp/llama.cpp/common/common.h +21 -17
  24. package/cpp/llama.cpp/common/json-partial.cpp +255 -0
  25. package/cpp/llama.cpp/common/json-partial.h +37 -0
  26. package/cpp/llama.cpp/common/minja/chat-template.hpp +9 -5
  27. package/cpp/llama.cpp/common/minja/minja.hpp +69 -36
  28. package/cpp/llama.cpp/common/regex-partial.cpp +204 -0
  29. package/cpp/llama.cpp/common/regex-partial.h +56 -0
  30. package/cpp/llama.cpp/common/sampling.cpp +7 -8
  31. package/cpp/llama.cpp/convert_hf_to_gguf.py +453 -118
  32. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +120 -68
  33. package/cpp/llama.cpp/ggml/CMakeLists.txt +2 -1
  34. package/cpp/llama.cpp/ggml/cmake/common.cmake +25 -0
  35. package/cpp/llama.cpp/ggml/include/ggml-opt.h +49 -28
  36. package/cpp/llama.cpp/ggml/include/ggml.h +26 -7
  37. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +16 -10
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +4 -1
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +2 -0
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +604 -0
  42. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +42 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +54 -2
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +50 -51
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -2
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -9
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +779 -19
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +22 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +322 -100
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +117 -1
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +220 -49
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/acc.cu +40 -26
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +1 -1
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +11 -1
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +15 -7
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +266 -64
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +49 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +48 -4
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +2 -1
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +5 -1
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +2 -0
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/quantize.cu +7 -6
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +1 -1
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +10 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-impl.h +1 -1
  71. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +99 -17
  73. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +200 -2
  74. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +8 -2
  75. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +6 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +972 -178
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +373 -190
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -10
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +101 -5
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +31 -33
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +1 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +29 -2
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +4 -5
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  94. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +9 -1
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +84 -72
  96. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  98. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -3
  99. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +324 -129
  100. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +1 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +31 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +95 -68
  103. package/cpp/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +1 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +22 -0
  105. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -4
  107. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +2 -3
  108. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +69 -43
  109. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +2 -14
  110. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -91
  111. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -181
  112. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +17 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  114. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +6 -152
  115. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  116. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  117. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +2 -118
  118. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +1 -1
  119. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +12 -1
  120. package/cpp/llama.cpp/ggml/src/ggml.c +107 -36
  121. package/cpp/llama.cpp/ggml/src/gguf.cpp +33 -33
  122. package/cpp/llama.cpp/gguf-py/gguf/constants.py +100 -15
  123. package/cpp/llama.cpp/gguf-py/gguf/gguf_reader.py +1 -1
  124. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +44 -12
  125. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_editor_gui.py +21 -10
  126. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +5 -2
  127. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +128 -31
  128. package/cpp/llama.cpp/gguf-py/gguf/utility.py +1 -1
  129. package/cpp/llama.cpp/gguf-py/pyproject.toml +1 -1
  130. package/cpp/llama.cpp/include/llama.h +62 -125
  131. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +1 -1
  132. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +1 -1
  133. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.inp +1 -1
  134. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.out +1 -1
  135. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +1 -1
  136. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +1 -1
  137. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +1 -1
  138. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +1 -1
  139. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.inp +1 -1
  140. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.out +1 -1
  141. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +1 -1
  142. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +1 -1
  143. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +1 -1
  144. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +1 -1
  145. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +1 -1
  146. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +1 -1
  147. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.inp +1 -1
  148. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.out +1 -1
  149. package/cpp/llama.cpp/models/ggml-vocab-nomic-bert-moe.gguf +0 -0
  150. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +1 -1
  151. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.out +1 -1
  152. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +1 -1
  153. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.out +1 -1
  154. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.inp +1 -1
  155. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.out +1 -1
  156. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +1 -1
  157. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.out +1 -1
  158. package/cpp/llama.cpp/models/templates/Qwen-QwQ-32B.jinja +62 -0
  159. package/cpp/llama.cpp/models/templates/Qwen-Qwen3-0.6B.jinja +85 -0
  160. package/cpp/llama.cpp/models/templates/README.md +2 -0
  161. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +5 -1
  162. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +5 -1
  163. package/cpp/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt +2 -0
  164. package/cpp/llama.cpp/requirements/requirements-gguf_editor_gui.txt +1 -1
  165. package/cpp/llama.cpp/src/CMakeLists.txt +2 -0
  166. package/cpp/llama.cpp/src/llama-arch.cpp +6 -0
  167. package/cpp/llama.cpp/src/llama-arch.h +2 -0
  168. package/cpp/llama.cpp/src/llama-batch.cpp +3 -1
  169. package/cpp/llama.cpp/src/llama-context.cpp +340 -123
  170. package/cpp/llama.cpp/src/llama-context.h +30 -0
  171. package/cpp/llama.cpp/src/llama-cparams.cpp +4 -0
  172. package/cpp/llama.cpp/src/llama-cparams.h +2 -0
  173. package/cpp/llama.cpp/src/llama-grammar.cpp +12 -2
  174. package/cpp/llama.cpp/src/llama-graph.cpp +157 -247
  175. package/cpp/llama.cpp/src/llama-graph.h +52 -7
  176. package/cpp/llama.cpp/src/llama-hparams.cpp +17 -1
  177. package/cpp/llama.cpp/src/llama-hparams.h +37 -5
  178. package/cpp/llama.cpp/src/llama-kv-cache.cpp +742 -481
  179. package/cpp/llama.cpp/src/llama-kv-cache.h +196 -99
  180. package/cpp/llama.cpp/src/llama-kv-cells.h +379 -0
  181. package/cpp/llama.cpp/src/llama-memory.h +4 -3
  182. package/cpp/llama.cpp/src/llama-model-loader.cpp +22 -17
  183. package/cpp/llama.cpp/src/llama-model-saver.cpp +281 -0
  184. package/cpp/llama.cpp/src/llama-model-saver.h +37 -0
  185. package/cpp/llama.cpp/src/llama-model.cpp +529 -172
  186. package/cpp/llama.cpp/src/llama-model.h +6 -1
  187. package/cpp/llama.cpp/src/llama-quant.cpp +15 -13
  188. package/cpp/llama.cpp/src/llama-sampling.cpp +2 -2
  189. package/cpp/llama.cpp/src/llama-vocab.cpp +35 -8
  190. package/cpp/llama.cpp/src/llama-vocab.h +6 -0
  191. package/cpp/llama.cpp/src/llama.cpp +14 -0
  192. package/cpp/rn-completion.cpp +4 -2
  193. package/ios/include/chat.h +73 -6
  194. package/ios/include/common/minja/chat-template.hpp +9 -5
  195. package/ios/include/common/minja/minja.hpp +69 -36
  196. package/ios/include/common.h +21 -17
  197. package/ios/include/llama.h +62 -125
  198. package/ios/libs/llama.xcframework/Info.plist +19 -19
  199. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  200. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4617 -4487
  201. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  202. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +26 -7
  203. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +62 -125
  204. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  205. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  206. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
  207. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3557 -3435
  208. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  209. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  210. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  211. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  212. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  213. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
  214. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3559 -3437
  215. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +237 -0
  216. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +26 -7
  217. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +62 -125
  218. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +237 -0
  219. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +26 -7
  220. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +62 -125
  221. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  222. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +237 -0
  223. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +26 -7
  224. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +62 -125
  225. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  226. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  227. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  228. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4616 -4487
  229. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  230. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +26 -7
  231. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +62 -125
  232. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  233. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  234. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4637 -4508
  235. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3556 -3435
  236. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  237. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  238. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  239. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  240. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  241. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4653 -4523
  242. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  243. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +26 -7
  244. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +62 -125
  245. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  246. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  247. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4674 -4544
  248. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3587 -3465
  249. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  250. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  251. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  252. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  253. package/package.json +1 -1
  254. package/cpp/llama.cpp/common/stb_image.h +0 -7988
  255. package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +0 -112
  256. package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.out +0 -46
  257. package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +0 -112
  258. package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.out +0 -46
  259. package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +0 -112
  260. package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +0 -46
  261. package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.inp +0 -112
  262. package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.out +0 -46
  263. package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +0 -112
  264. package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.out +0 -46
  265. package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +0 -112
  266. package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +0 -46
package/cpp/llama.cpp/src/llama-context.cpp
@@ -25,7 +25,11 @@ llama_context::llama_context(
 
     const auto & hparams = model.hparams;
 
-    cparams.n_seq_max = std::max(1u, params.n_seq_max);
+    cparams.n_seq_max = std::max(1u, params.n_seq_max);
+    if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
+        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
+    }
+
     cparams.n_threads = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch;
     cparams.yarn_ext_factor = params.yarn_ext_factor;
@@ -93,6 +97,7 @@ llama_context::llama_context(
     }
 
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
+
     cparams.op_offload = params.op_offload;
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
@@ -176,8 +181,9 @@ llama_context::llama_context(
     // init the memory module
     if (!hparams.vocab_only) {
        llama_memory_params params_mem = {
-            /*.type_k =*/ params.type_k,
-            /*.type_v =*/ params.type_v,
+            /*.type_k   =*/ params.type_k,
+            /*.type_v   =*/ params.type_v,
+            /*.swa_full =*/ params.swa_full,
        };
 
        memory.reset(model.create_memory(params_mem, cparams));
@@ -359,7 +365,9 @@ llama_context::llama_context(
     }
 }
 
-llama_context::~llama_context() = default;
+llama_context::~llama_context() {
+    ggml_opt_free(opt_ctx);
+}
 
 void llama_context::synchronize() {
     ggml_backend_sched_synchronize(sched.get());
@@ -685,12 +693,18 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
+    // TODO: move the validation to the llama_batch_allocr
     if (batch.token) {
         for (int32_t i = 0; i < n_tokens; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
+
+            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
+                throw -1;
+            }
         }
     }
 
@@ -844,7 +858,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 
 int llama_context::decode(llama_batch & inp_batch) {
     if (!memory) {
-        LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
+        LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
         return encode(inp_batch);
     }
 
@@ -853,11 +867,17 @@ int llama_context::decode(llama_batch & inp_batch) {
         return -1;
     }
 
+    if (!inp_batch.pos) {
+        if (inp_batch.seq_id) {
+            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
+            return -1;
+        }
+    }
+
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
     // temporary allocate memory for the input batch if needed
-    // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
 
     const llama_batch & batch = batch_allocr.batch;
 
@@ -873,11 +893,17 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
+    // TODO: move the validation to the llama_batch_allocr
     if (batch.token) {
         for (int64_t i = 0; i < n_tokens_all; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
-                throw std::runtime_error("invalid token");
+                return -1;
+            }
+
+            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
+                return -1;
             }
         }
     }
@@ -945,8 +971,6 @@ int llama_context::decode(llama_batch & inp_batch) {
 
         // find KV slot
         if (!kv_self->find_slot(ubatch)) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
             return 1;
         }
 
@@ -1702,10 +1726,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
         }
     }
 
-    LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-    kv_self->state_write(io);
+    if (kv_self != nullptr) {
+        LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
+        kv_self->state_write(io);
+    }
 
     return io.n_bytes();
 }
@@ -1788,10 +1814,13 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
         }
     }
 
-    LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    if (memory) {
+        LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
 
-    kv_self->state_read(io);
+        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+        kv_self->state_read(io);
+    }
 
     return io.n_bytes();
 }
@@ -1799,9 +1828,11 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
 size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
     GGML_UNUSED(seq_id);
 
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    if (memory) {
+        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-    kv_self->state_write(io, seq_id);
+        kv_self->state_write(io, seq_id);
+    }
 
     return io.n_bytes();
 }
@@ -1809,9 +1840,11 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
 size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
     GGML_UNUSED(seq_id);
 
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    if (memory) {
+        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-    kv_self->state_read(io, seq_id);
+        kv_self->state_read(io, seq_id);
+    }
 
     return io.n_bytes();
 }
@@ -1839,6 +1872,215 @@ void llama_context::perf_reset() {
     t_p_eval_us = n_p_eval = 0;
 }
 
+//
+// training
+//
+
+static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) {
+    if (!tensor || tensor->type != GGML_TYPE_F32) {
+        return;
+    }
+    if (!param_filter(tensor, userdata)) {
+        return;
+    }
+    if (strcmp(tensor->name, "token_embd.weight") == 0) {
+        return; // FIXME
+    }
+    if (strcmp(tensor->name, "rope_freqs.weight") == 0) {
+        return; // FIXME
+    }
+    ggml_set_param(tensor);
+}
+
+void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) {
+    GGML_ASSERT(!opt_ctx);
+    model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx();
+    const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train);
+    const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
+    GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0);
+    GGML_ASSERT(n_batch % n_ubatch == 0);
+
+    ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
+    opt_params.opt_period      = n_batch / n_ubatch;
+    opt_params.get_opt_pars    = lopt_params.get_opt_pars;
+    opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
+
+    opt_ctx = ggml_opt_init(opt_params);
+
+    llama_opt_param_filter param_filter = lopt_params.param_filter;
+    void * param_filter_ud = lopt_params.param_filter_ud;
+
+    //llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME
+    llama_set_param(model->type_embd, param_filter, param_filter_ud);
+    llama_set_param(model->pos_embd, param_filter, param_filter_ud);
+    llama_set_param(model->tok_norm, param_filter, param_filter_ud);
+    llama_set_param(model->tok_norm_b, param_filter, param_filter_ud);
+    llama_set_param(model->output_norm, param_filter, param_filter_ud);
+    llama_set_param(model->output_norm_b, param_filter, param_filter_ud);
+    llama_set_param(model->output, param_filter, param_filter_ud);
+    llama_set_param(model->output_b, param_filter, param_filter_ud);
+    llama_set_param(model->output_norm_enc, param_filter, param_filter_ud);
+    llama_set_param(model->cls, param_filter, param_filter_ud);
+    llama_set_param(model->cls_b, param_filter, param_filter_ud);
+    llama_set_param(model->cls_out, param_filter, param_filter_ud);
+    llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
+
+    for (struct llama_layer & layer : model->layers) {
+        for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
+            llama_set_param(reinterpret_cast<struct ggml_tensor **>(&layer)[i], param_filter, param_filter_ud);
+        }
+    }
+}
+
+void llama_context::opt_epoch_iter(
+        ggml_opt_dataset_t dataset,
+        ggml_opt_result_t result,
+        const std::vector<llama_token> & tokens,
+        const std::vector<llama_token> & labels_sparse,
+        llama_batch & batch,
+        ggml_opt_epoch_callback callback,
+        bool train,
+        int64_t idata_in_loop,
+        int64_t ndata_in_loop,
+        int64_t t_loop_start) {
+    GGML_ASSERT(opt_ctx);
+    const uint32_t n_ctx = llama_model_n_ctx_train(&model);
+    const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
+    const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
+
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+    kv_self->clear();
+    llama_kv_cache_guard kv_guard(kv_self);
+
+    for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
+        batch.n_tokens = n_batch;
+        for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) {
+            batch.token   [pos_batch] = tokens[pos_ctx + pos_batch];
+            batch.pos     [pos_batch] = pos_ctx + pos_batch;
+            batch.n_seq_id[pos_batch] = 1;
+            batch.seq_id  [pos_batch][0] = 0;
+            batch.logits  [pos_batch] = true;
+        }
+
+        const auto n_tokens_all = batch.n_tokens;
+
+        n_queued_tokens += n_tokens_all;
+
+        // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+        embd_seq.clear();
+
+        int64_t n_outputs_all = n_tokens_all;
+
+        llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
+
+        // reserve output buffer
+        if (output_reserve(n_outputs_all) < n_outputs_all) {
+            LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
+            GGML_ABORT("TODO: handle this error");
+        };
+
+        for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
+            llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+
+            n_outputs = ubatch.n_tokens;
+
+            // TODO: not sure if this is needed
+            if (!kv_self->find_slot(ubatch)) {
+                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+
+                GGML_ABORT("TODO: handle this error");
+            }
+
+            auto * gf = graph_init();
+            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+
+            struct ggml_context * ctx_compute_opt;
+            {
+                const size_t size_gf = ggml_graph_size(gf);
+                const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true);
+                struct ggml_init_params params = {
+                    /*.mem_size   =*/ size_meta,
+                    /*.mem_buffer =*/ nullptr,
+                    /*.no_alloc   =*/ true,
+                };
+                ctx_compute_opt = ggml_init(params);
+            }
+            ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
+            ggml_opt_alloc(opt_ctx, train);
+            res->set_inputs(&ubatch);
+            {
+                struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
+                GGML_ASSERT(labels->ne[1] == n_ubatch);
+                ggml_set_zero(labels);
+                const float onef = 1.0f;
+                for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) {
+                    const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch;
+                    GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]);
+                    ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float));
+                }
+            }
+            ggml_opt_eval(opt_ctx, result);
+            if (callback) {
+                callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
+            }
+            ggml_free(ctx_compute_opt);
+        }
+    }
+
+    kv_guard.commit();
+}
+
+void llama_context::opt_epoch(
+        ggml_opt_dataset_t dataset,
+        ggml_opt_result_t result_train,
+        ggml_opt_result_t result_eval,
+        int64_t idata_split,
+        ggml_opt_epoch_callback callback_train,
+        ggml_opt_epoch_callback callback_eval) {
+    const uint32_t n_ctx = this->n_ctx();
+    const uint32_t n_batch = std::min(cparams.n_batch, n_ctx);
+    const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch);
+    const int64_t ndata = ggml_opt_dataset_ndata(dataset);
+
+    GGML_ASSERT(idata_split >= 0);
+    GGML_ASSERT(idata_split <= ndata);
+
+    const uint32_t ubatch_per_ctx = n_ctx / n_ubatch;
+
+    struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
+    std::vector<llama_token> tokens(n_ctx);
+    std::vector<llama_token> labels_sparse(n_ctx);
+
+    int64_t idata = 0;
+
+    int64_t t_loop_start = ggml_time_us();
+    int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
+    for (; idata < idata_split; ++idata) {
+        constexpr bool train = true;
+        const int64_t idata_in_loop = idata*ubatch_per_ctx;
+
+        ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
+        opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch,
+            callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
+    }
+
+    t_loop_start = ggml_time_us();
+    ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
+    for (; idata < ndata; ++idata) {
+        constexpr bool train = false;
+        const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
+
+        ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
+        opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch,
+            callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
+    }
+
+    llama_batch_free(batch);
+}
+
 //
 // interface implementation
 //
@@ -1873,6 +2115,7 @@ llama_context_params llama_context_default_params() {
         /*.flash_attn =*/ false,
         /*.no_perf    =*/ true,
         /*.op_offload =*/ true,
+        /*.swa_full   =*/ true,
     };
 
     return result;
@@ -2067,65 +2310,51 @@ int32_t llama_apply_adapter_cvec(
     return res ? 0 : -1;
 }
 
-//
-// kv cache view
-//
-
-llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
-    const auto * kv = ctx->get_kv_self();
-    if (kv == nullptr) {
-        LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
-        return {};
-    }
-
-    return llama_kv_cache_view_init(*kv, n_seq_max);
-}
-
-void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
-    const auto * kv = ctx->get_kv_self();
-    if (kv == nullptr) {
-        LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
-        return;
-    }
-
-    llama_kv_cache_view_update(view, kv);
-}
-
 //
 // kv cache
 //
 
 // deprecated
-int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
-    return llama_kv_self_n_tokens(ctx);
-}
-
 int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
         return 0;
     }
 
-    return kv->get_n_tokens();
-}
+    int32_t res = 0;
 
-// deprecated
-int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
-    return llama_kv_self_used_cells(ctx);
+    for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
+        const llama_pos p0 = kv->seq_pos_min(s);
+        const llama_pos p1 = kv->seq_pos_max(s);
+
+        if (p0 >= 0) {
+            res += (p1 - p0) + 1;
+        }
+    }
+
+    return res;
 }
 
+// deprecated
+// note: this is the same as above - will be removed anyway, so it's ok
 int32_t llama_kv_self_used_cells(const llama_context * ctx) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
         return 0;
     }
 
-    return kv->get_used_cells();
-}
+    int32_t res = 0;
 
-// deprecated
-void llama_kv_cache_clear(llama_context * ctx) {
-    llama_kv_self_clear(ctx);
+    for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
+        const llama_pos p0 = kv->seq_pos_min(s);
+        const llama_pos p1 = kv->seq_pos_max(s);
+
+        if (p0 >= 0) {
+            res += (p1 - p0) + 1;
+        }
+    }
+
+    return res;
 }
 
 void llama_kv_self_clear(llama_context * ctx) {
@@ -2137,15 +2366,6 @@ void llama_kv_self_clear(llama_context * ctx) {
     kv->clear();
 }
 
-// deprecated
-bool llama_kv_cache_seq_rm(
-        llama_context * ctx,
-        llama_seq_id seq_id,
-        llama_pos p0,
-        llama_pos p1) {
-    return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
-}
-
 bool llama_kv_self_seq_rm(
         llama_context * ctx,
         llama_seq_id seq_id,
@@ -2159,16 +2379,6 @@ bool llama_kv_self_seq_rm(
     return kv->seq_rm(seq_id, p0, p1);
 }
 
-// deprecated
-void llama_kv_cache_seq_cp(
-        llama_context * ctx,
-        llama_seq_id seq_id_src,
-        llama_seq_id seq_id_dst,
-        llama_pos p0,
-        llama_pos p1) {
-    llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
-}
-
 void llama_kv_self_seq_cp(
         llama_context * ctx,
         llama_seq_id seq_id_src,
@@ -2183,13 +2393,6 @@ void llama_kv_self_seq_cp(
     kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
 }
 
-// deprecated
-void llama_kv_cache_seq_keep(
-        llama_context * ctx,
-        llama_seq_id seq_id) {
-    llama_kv_self_seq_keep(ctx, seq_id);
-}
-
 void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
     auto * kv = ctx->get_kv_self();
     if (!kv) {
@@ -2199,16 +2402,6 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
     kv->seq_keep(seq_id);
 }
 
-// deprecated
-void llama_kv_cache_seq_add(
-        llama_context * ctx,
-        llama_seq_id seq_id,
-        llama_pos p0,
-        llama_pos p1,
-        llama_pos delta) {
-    llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
-}
-
 void llama_kv_self_seq_add(
         llama_context * ctx,
         llama_seq_id seq_id,
@@ -2223,16 +2416,6 @@ void llama_kv_self_seq_add(
     kv->seq_add(seq_id, p0, p1, delta);
 }
 
-// deprecated
-void llama_kv_cache_seq_div(
-        llama_context * ctx,
-        llama_seq_id seq_id,
-        llama_pos p0,
-        llama_pos p1,
-        int d) {
-    llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
-}
-
 void llama_kv_self_seq_div(
         llama_context * ctx,
         llama_seq_id seq_id,
@@ -2247,25 +2430,24 @@ void llama_kv_self_seq_div(
     kv->seq_div(seq_id, p0, p1, d);
 }
 
-// deprecated
-llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
-    return llama_kv_self_seq_pos_max(ctx, seq_id);
+llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
+    const auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return -1;
+    }
+
+    return kv->seq_pos_min(seq_id);
 }
 
 llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
-        return 0;
+        return -1;
     }
 
     return kv->seq_pos_max(seq_id);
 }
 
-// deprecated
-void llama_kv_cache_defrag(llama_context * ctx) {
-    llama_kv_self_defrag(ctx);
-}
-
 void llama_kv_self_defrag(llama_context * ctx) {
     auto * kv = ctx->get_kv_self();
     if (!kv) {
@@ -2276,11 +2458,6 @@ void llama_kv_self_defrag(llama_context * ctx) {
     kv->defrag_sched(-1.0f);
 }
 
-// deprecated
-bool llama_kv_cache_can_shift(const llama_context * ctx) {
-    return llama_kv_self_can_shift(ctx);
-}
-
 bool llama_kv_self_can_shift(const llama_context * ctx) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
@@ -2290,11 +2467,6 @@ bool llama_kv_self_can_shift(const llama_context * ctx) {
     return kv->get_can_shift();
 }
 
-// deprecated
-void llama_kv_cache_update(llama_context * ctx) {
-    llama_kv_self_update(ctx);
-}
-
 // llama state API
 
 // deprecated
@@ -2417,7 +2589,21 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
         llama_batch batch) {
-    const int ret = ctx->decode(batch);
+    int ret = ctx->decode(batch);
+
+    // defrag and try again
+    // TODO: distinguish return code when we are sure that even after defrag there is no space available
+    if (ret == 1) {
+        llama_kv_self_defrag(ctx);
+        ret = ctx->decode(batch);
+
+        if (ret == 1) {
+            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+            return ret;
+        }
+    }
+
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
@@ -2457,3 +2643,34 @@ void llama_perf_context_print(const llama_context * ctx) {
 void llama_perf_context_reset(llama_context * ctx) {
     ctx->perf_reset();
 }
+
+//
+// training
+//
+
+bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) {
+    GGML_UNUSED(tensor);
+    GGML_UNUSED(userdata);
+    return true;
+}
+
+void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) {
+    ctx->opt_init(model, lopt_params);
+}
+
+void llama_opt_epoch(
+        struct llama_context * ctx,
+        ggml_opt_dataset_t dataset,
+        ggml_opt_result_t result_train,
+        ggml_opt_result_t result_eval,
+        int64_t idata_split,
+        ggml_opt_epoch_callback callback_train,
+        ggml_opt_epoch_callback callback_eval) {
+    ctx->opt_epoch(
+        dataset,
+        result_train,
+        result_eval,
+        idata_split,
+        callback_train,
+        callback_eval);
+}
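
Note on the new training API: the hunks above expose llama_opt_param_filter_all, llama_opt_init and llama_opt_epoch through the bundled llama.cpp. What follows is only a minimal usage sketch, not part of the diff. The llama_opt_params field names and the llama_opt_epoch signature are taken from the hunks; the dataset/result handles are assumed to be created elsewhere with the ggml-opt API, and ggml_opt_get_default_optimizer_params plus ggml_opt_epoch_callback_progress_bar are assumed to be available from ggml-opt.h of this llama.cpp revision.

// Sketch only: drive one training + evaluation epoch with the API added above.
#include "llama.h"
#include "ggml-opt.h"

static void run_one_epoch(
        llama_context      * ctx,
        llama_model        * model,
        ggml_opt_dataset_t   dataset,       // assumed: prepared elsewhere via the ggml-opt API
        ggml_opt_result_t    result_train,  // assumed: prepared elsewhere
        ggml_opt_result_t    result_eval,   // assumed: prepared elsewhere
        int64_t              idata_split) { // first idata_split batches train, the rest evaluate
    struct llama_opt_params opt_params = {};
    opt_params.n_ctx_train     = 0;                          // 0 -> reuse the context's n_ctx (see opt_init above)
    opt_params.param_filter    = llama_opt_param_filter_all; // train every F32 parameter tensor
    opt_params.param_filter_ud = nullptr;
    opt_params.get_opt_pars    = ggml_opt_get_default_optimizer_params; // assumption: stock optimizer defaults from ggml-opt.h
    opt_params.get_opt_pars_ud = nullptr;

    llama_opt_init(ctx, model, opt_params);

    llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
                    ggml_opt_epoch_callback_progress_bar,  // assumption: stock progress callback from ggml-opt.h
                    ggml_opt_epoch_callback_progress_bar);
}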