@novastera-oss/llamarn 0.2.6 → 0.2.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (253) hide show
  1. package/android/src/main/cpp/include/llama.h +141 -38
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +58 -24
  11. package/cpp/LlamaCppModel.h +3 -3
  12. package/cpp/PureCppImpl.cpp +1 -1
  13. package/cpp/PureCppImpl.h +2 -2
  14. package/cpp/build-info.cpp +2 -2
  15. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  16. package/cpp/llama.cpp/Makefile +2 -2
  17. package/cpp/llama.cpp/README.md +32 -13
  18. package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
  19. package/cpp/llama.cpp/common/arg.cpp +37 -6
  20. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  21. package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
  22. package/cpp/llama.cpp/common/chat-parser.h +2 -0
  23. package/cpp/llama.cpp/common/chat.cpp +12 -9
  24. package/cpp/llama.cpp/common/chat.h +1 -1
  25. package/cpp/llama.cpp/common/common.cpp +53 -40
  26. package/cpp/llama.cpp/common/common.h +6 -2
  27. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  28. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  29. package/cpp/llama.cpp/convert_hf_to_gguf.py +215 -76
  30. package/cpp/llama.cpp/ggml/CMakeLists.txt +48 -2
  31. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  32. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  33. package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
  34. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +64 -13
  35. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  36. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  38. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +124 -26
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +4 -3
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +93 -104
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +194 -69
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1158 -0
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1571 -0
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +213 -37
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +59 -37
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +90 -39
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  88. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
  90. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  91. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +260 -49
  93. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +497 -282
  94. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  95. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1078 -468
  97. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  99. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  105. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
  107. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  108. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +20 -48
  109. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  110. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  111. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +117 -165
  112. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +192 -53
  113. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  114. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  115. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  116. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
  117. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  118. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +8 -105
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +209 -92
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  133. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +36 -28
  134. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +487 -247
  135. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  136. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  137. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  138. package/cpp/llama.cpp/ggml/src/ggml.c +69 -19
  139. package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
  140. package/cpp/llama.cpp/gguf-py/gguf/constants.py +133 -0
  141. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +25 -1
  142. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +78 -3
  143. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
  144. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  145. package/cpp/llama.cpp/include/llama.h +141 -38
  146. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  147. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  148. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  149. package/cpp/llama.cpp/src/llama-arch.cpp +150 -3
  150. package/cpp/llama.cpp/src/llama-arch.h +25 -1
  151. package/cpp/llama.cpp/src/llama-batch.cpp +736 -274
  152. package/cpp/llama.cpp/src/llama-batch.h +110 -57
  153. package/cpp/llama.cpp/src/llama-chat.cpp +30 -8
  154. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  155. package/cpp/llama.cpp/src/llama-context.cpp +360 -266
  156. package/cpp/llama.cpp/src/llama-context.h +27 -23
  157. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  158. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  159. package/cpp/llama.cpp/src/llama-graph.cpp +411 -344
  160. package/cpp/llama.cpp/src/llama-graph.h +126 -58
  161. package/cpp/llama.cpp/src/llama-hparams.cpp +10 -2
  162. package/cpp/llama.cpp/src/llama-hparams.h +16 -2
  163. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +103 -73
  164. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +34 -42
  165. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +345 -221
  166. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +75 -50
  167. package/cpp/llama.cpp/src/llama-kv-cells.h +51 -22
  168. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +246 -0
  169. package/cpp/llama.cpp/src/llama-memory-hybrid.h +138 -0
  170. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +302 -317
  171. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +60 -68
  172. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  173. package/cpp/llama.cpp/src/llama-memory.h +73 -36
  174. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  175. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  176. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  177. package/cpp/llama.cpp/src/llama-model.cpp +1630 -511
  178. package/cpp/llama.cpp/src/llama-model.h +26 -0
  179. package/cpp/llama.cpp/src/llama-quant.cpp +89 -6
  180. package/cpp/llama.cpp/src/llama-vocab.cpp +58 -26
  181. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  182. package/cpp/llama.cpp/src/llama.cpp +11 -7
  183. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  184. package/cpp/rn-completion.cpp +2 -2
  185. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  186. package/cpp/{rn-utils.hpp → rn-utils.h} +3 -0
  187. package/ios/include/chat.h +1 -1
  188. package/ios/include/common.h +6 -2
  189. package/ios/include/llama.h +141 -38
  190. package/ios/libs/llama.xcframework/Info.plist +15 -15
  191. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  192. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
  193. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  194. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
  195. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +141 -38
  196. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  197. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  198. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  199. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
  200. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  201. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  202. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  203. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  204. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  205. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  206. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3624
  207. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  208. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
  209. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +141 -38
  210. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  211. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
  212. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +141 -38
  213. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  214. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  215. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
  216. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +141 -38
  217. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  218. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  219. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  220. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
  221. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  222. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
  223. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +141 -38
  224. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  225. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  226. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  227. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
  228. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  229. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  230. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  231. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  232. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  233. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4725
  234. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  235. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
  236. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +141 -38
  237. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  238. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  239. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4746
  240. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3652
  241. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  242. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  243. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  244. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  245. package/package.json +1 -2
  246. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  247. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  248. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  249. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
  250. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
  251. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  252. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  253. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
@@ -17,11 +17,12 @@ struct ggml_tensor;
17
17
  struct llama_ubatch;
18
18
  struct llama_cparams;
19
19
 
20
- class llama_memory_state_i;
20
+ struct llama_memory_context_i;
21
21
 
22
- class llama_kv_cache_unified_state;
23
- class llama_kv_cache_unified_iswa_state;
24
- class llama_kv_cache_recurrent_state;
22
+ class llama_kv_cache_unified_context;
23
+ class llama_kv_cache_unified_iswa_context;
24
+ class llama_memory_recurrent_context;
25
+ class llama_memory_hybrid_context;
25
26
 
26
27
  // certain models (typically multi-modal) can produce different types of graphs
27
28
  enum llm_graph_type {
@@ -36,6 +37,7 @@ enum llm_ffn_op_type {
36
37
  LLM_FFN_RELU,
37
38
  LLM_FFN_RELU_SQR,
38
39
  LLM_FFN_SWIGLU,
40
+ LLM_FFN_GEGLU,
39
41
  };
40
42
 
41
43
  enum llm_ffn_gate_type {
@@ -93,14 +95,14 @@ public:
93
95
 
94
96
  class llm_graph_input_pos : public llm_graph_input_i {
95
97
  public:
96
- llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
98
+ llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
97
99
  virtual ~llm_graph_input_pos() = default;
98
100
 
99
101
  void set_input(const llama_ubatch * ubatch) override;
100
102
 
101
103
  ggml_tensor * pos = nullptr; // I32 [n_batch]
102
104
 
103
- const int64_t n_pos_per_embd = 1;
105
+ const uint32_t n_pos_per_embd = 1;
104
106
  };
105
107
 
106
108
  // temperature tuning, used by llama4
@@ -134,7 +136,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
134
136
  public:
135
137
  llm_graph_input_pos_bucket_kv(
136
138
  const llama_hparams & hparams,
137
- const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
139
+ const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
138
140
  virtual ~llm_graph_input_pos_bucket_kv() = default;
139
141
 
140
142
  void set_input(const llama_ubatch * ubatch) override;
@@ -142,7 +144,8 @@ public:
142
144
  ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
143
145
 
144
146
  const llama_hparams & hparams;
145
- const llama_kv_cache_unified_state * kv_state;
147
+
148
+ const llama_kv_cache_unified_context * mctx;
146
149
  };
147
150
 
148
151
  class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -187,28 +190,16 @@ public:
187
190
  const llama_cparams & cparams;
188
191
  };
189
192
 
190
- class llm_graph_input_s_copy : public llm_graph_input_i {
193
+ class llm_graph_input_rs : public llm_graph_input_i {
191
194
  public:
192
- llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
193
- virtual ~llm_graph_input_s_copy() = default;
195
+ llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
196
+ virtual ~llm_graph_input_rs() = default;
194
197
 
195
198
  void set_input(const llama_ubatch * ubatch) override;
196
199
 
197
200
  ggml_tensor * s_copy; // I32 [kv_size]
198
201
 
199
- const llama_kv_cache_recurrent_state * kv_state;
200
- };
201
-
202
- class llm_graph_input_s_mask : public llm_graph_input_i {
203
- public:
204
- llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
205
- virtual ~llm_graph_input_s_mask() = default;
206
-
207
- void set_input(const llama_ubatch * ubatch) override;
208
-
209
- ggml_tensor * s_mask; // F32 [1, n_kv]
210
-
211
- const llama_kv_cache_recurrent_state * kv_state;
202
+ const llama_memory_recurrent_context * mctx;
212
203
  };
213
204
 
214
205
  class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -248,10 +239,10 @@ public:
248
239
  llm_graph_input_attn_kv_unified(
249
240
  const llama_hparams & hparams,
250
241
  const llama_cparams & cparams,
251
- const llama_kv_cache_unified_state * kv_state) :
242
+ const llama_kv_cache_unified_context * mctx) :
252
243
  hparams(hparams),
253
244
  cparams(cparams),
254
- kv_state(kv_state) {
245
+ mctx(mctx) {
255
246
  }
256
247
  ~llm_graph_input_attn_kv_unified() = default;
257
248
 
@@ -265,7 +256,7 @@ public:
265
256
  const llama_hparams & hparams;
266
257
  const llama_cparams & cparams;
267
258
 
268
- const llama_kv_cache_unified_state * kv_state;
259
+ const llama_kv_cache_unified_context * mctx;
269
260
  };
270
261
 
271
262
  class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
@@ -273,10 +264,10 @@ public:
273
264
  llm_graph_input_attn_kv_unified_iswa(
274
265
  const llama_hparams & hparams,
275
266
  const llama_cparams & cparams,
276
- const llama_kv_cache_unified_iswa_state * kv_state) :
267
+ const llama_kv_cache_unified_iswa_context * mctx) :
277
268
  hparams(hparams),
278
269
  cparams(cparams),
279
- kv_state(kv_state) {
270
+ mctx(mctx) {
280
271
  }
281
272
  ~llm_graph_input_attn_kv_unified_iswa() = default;
282
273
 
@@ -293,7 +284,7 @@ public:
293
284
  const llama_hparams & hparams;
294
285
  const llama_cparams & cparams;
295
286
 
296
- const llama_kv_cache_unified_iswa_state * kv_state;
287
+ const llama_kv_cache_unified_iswa_context * mctx;
297
288
  };
298
289
 
299
290
  class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -311,6 +302,44 @@ public:
311
302
  const llama_cross * cross = nullptr;
312
303
  };
313
304
 
305
+ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
306
+ public:
307
+ llm_graph_input_mem_hybrid(
308
+ const llama_hparams & hparams,
309
+ const llama_cparams & cparams,
310
+ const llama_memory_hybrid_context * mctx) :
311
+ hparams(hparams),
312
+ cparams(cparams),
313
+ mctx(mctx) {
314
+ }
315
+ virtual ~llm_graph_input_mem_hybrid() = default;
316
+
317
+ void set_input(const llama_ubatch * ubatch) override;
318
+
319
+ ggml_tensor * s_copy; // I32 [kv_size]
320
+
321
+ ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
322
+
323
+ ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
324
+ ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
325
+
326
+ const llama_hparams & hparams;
327
+ const llama_cparams & cparams;
328
+
329
+ const llama_memory_hybrid_context * mctx;
330
+ };
331
+
332
+ // TODO: remove this when ggml_scale_add is implemented
333
+ class llm_graph_input_one : public llm_graph_input_i {
334
+ public:
335
+ llm_graph_input_one() {}
336
+ virtual ~llm_graph_input_one() = default;
337
+
338
+ void set_input(const llama_ubatch *) override;
339
+
340
+ ggml_tensor * one = nullptr; // F32
341
+ };
342
+
314
343
  //
315
344
  // llm_graph_result
316
345
  //
@@ -384,12 +413,12 @@ struct llm_graph_params {
384
413
  ggml_backend_sched_t sched;
385
414
  ggml_backend_t backend_cpu;
386
415
 
387
- const llama_adapter_cvec * cvec;
388
- const llama_adapter_loras * loras;
389
- const llama_memory_state_i * mstate;
390
- const llama_cross * cross;
416
+ const llama_adapter_cvec * cvec;
417
+ const llama_adapter_loras * loras;
418
+ const llama_memory_context_i * mctx;
419
+ const llama_cross * cross;
391
420
 
392
- int32_t n_outputs;
421
+ uint32_t n_outputs;
393
422
 
394
423
  const llm_graph_cb & cb;
395
424
  };
@@ -423,8 +452,8 @@ struct llm_graph_context {
423
452
  const float norm_eps;
424
453
  const float norm_rms_eps;
425
454
 
426
- const int32_t n_tokens;
427
- const int32_t n_outputs;
455
+ const int64_t n_tokens;
456
+ const int64_t n_outputs;
428
457
  const int32_t n_ctx_orig; // yarn
429
458
 
430
459
  const enum llama_pooling_type pooling_type;
@@ -436,18 +465,17 @@ struct llm_graph_context {
436
465
 
437
466
  ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
438
467
 
439
- const llama_adapter_cvec * cvec;
440
- const llama_adapter_loras * loras;
441
- const llama_memory_state_i * mstate;
442
- const llama_cross * cross;
468
+ const llama_adapter_cvec * cvec;
469
+ const llama_adapter_loras * loras;
470
+ const llama_memory_context_i * mctx;
471
+ const llama_cross * cross;
443
472
 
444
473
  const llm_graph_cb & cb_func;
445
474
 
446
475
  std::unique_ptr<llm_graph_result> res;
447
476
 
448
477
  llm_graph_context(const llm_graph_params & params);
449
-
450
- int64_t n_pos_per_embd() const;
478
+ virtual ~llm_graph_context() = default;
451
479
 
452
480
  void cb(ggml_tensor * cur, const char * name, int il) const;
453
481
 
@@ -519,14 +547,14 @@ struct llm_graph_context {
519
547
  ggml_tensor * build_inp_out_ids() const;
520
548
  ggml_tensor * build_inp_mean() const;
521
549
  ggml_tensor * build_inp_cls() const;
522
- ggml_tensor * build_inp_s_copy() const;
523
- ggml_tensor * build_inp_s_mask() const;
524
550
 
525
551
  ggml_tensor * build_inp_cross_embd() const;
526
552
  ggml_tensor * build_inp_pos_bucket_enc() const;
527
553
  ggml_tensor * build_inp_pos_bucket_dec() const;
528
554
  ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
529
555
 
556
+ llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
557
+
530
558
  //
531
559
  // attention
532
560
  //
@@ -573,14 +601,15 @@ struct llm_graph_context {
573
601
 
574
602
  llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
575
603
 
604
+ // note: if k_cur or v_cur are not provided, they will not be stored in the memory
576
605
  ggml_tensor * build_attn(
577
606
  llm_graph_input_attn_kv_unified_iswa * inp,
578
607
  ggml_cgraph * gf,
579
608
  ggml_tensor * wo,
580
609
  ggml_tensor * wo_b,
581
610
  ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
582
- ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
583
- ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
611
+ ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
612
+ ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
584
613
  ggml_tensor * kq_b,
585
614
  ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
586
615
  float kq_scale,
@@ -601,23 +630,62 @@ struct llm_graph_context {
601
630
  float kq_scale,
602
631
  int il) const;
603
632
 
633
+ ggml_tensor * build_attn(
634
+ llm_graph_input_mem_hybrid * inp,
635
+ ggml_cgraph * gf,
636
+ ggml_tensor * wo,
637
+ ggml_tensor * wo_b,
638
+ ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
639
+ ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
640
+ ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
641
+ ggml_tensor * kq_b,
642
+ ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
643
+ float kq_scale,
644
+ int il) const;
604
645
  //
605
646
  // recurrent
606
647
  //
607
648
 
608
- ggml_tensor * build_copy_mask_state(
609
- ggml_cgraph * gf,
610
- ggml_tensor * s,
611
- ggml_tensor * state_copy,
612
- ggml_tensor * state_mask,
613
- int32_t n_state,
614
- int32_t n_seqs) const;
649
+ // TODO: avoid notion of "kv"
650
+ // TODO: move this implementation to llama_memory_recurrent.
651
+ // this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
652
+ // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
653
+ // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
654
+ // `llama_memory_recurrent`
655
+ ggml_tensor * build_rs(
656
+ ggml_cgraph * gf,
657
+ ggml_tensor * s,
658
+ ggml_tensor * state_copy,
659
+ int32_t state_size,
660
+ int32_t n_seqs,
661
+ uint32_t n_kv,
662
+ uint32_t kv_head,
663
+ uint32_t kv_size,
664
+ int32_t rs_zero,
665
+ bool avoid_copies = false) const;
666
+
667
+ llm_graph_input_rs * build_rs_inp() const;
668
+
669
+ ggml_tensor * build_rs(
670
+ llm_graph_input_rs * inp,
671
+ ggml_cgraph * gf,
672
+ ggml_tensor * s,
673
+ int32_t state_size,
674
+ int32_t n_seqs,
675
+ bool avoid_copies = false) const;
676
+
677
+ ggml_tensor * build_rs(
678
+ llm_graph_input_mem_hybrid * inp,
679
+ ggml_cgraph * gf,
680
+ ggml_tensor * s,
681
+ int32_t state_size,
682
+ int32_t n_seqs,
683
+ bool avoid_copies = false) const;
615
684
 
616
685
  ggml_tensor * build_rwkv_token_shift_load(
617
- ggml_cgraph * gf,
618
- ggml_tensor * state_copy,
619
- ggml_tensor * state_mask,
620
- const llama_ubatch & ubatch,
686
+ llm_graph_input_rs * inp,
687
+ ggml_cgraph * gf,
688
+ const llama_ubatch & ubatch,
621
689
  int il) const;
622
690
 
623
691
  ggml_tensor * build_rwkv_token_shift_store(
@@ -65,7 +65,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
65
65
  return n_embd_head_v * n_head_kv;
66
66
  }
67
67
 
68
- uint32_t llama_hparams::n_embd_k_s() const {
68
+ uint32_t llama_hparams::n_embd_r() const {
69
69
  if (wkv_head_size != 0) {
70
70
  // for RWKV models
71
71
  return token_shift_count * n_embd;
@@ -76,7 +76,7 @@ uint32_t llama_hparams::n_embd_k_s() const {
76
76
  return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
77
77
  }
78
78
 
79
- uint32_t llama_hparams::n_embd_v_s() const {
79
+ uint32_t llama_hparams::n_embd_s() const {
80
80
  if (wkv_head_size != 0) {
81
81
  // corresponds to RWKV's wkv_states size
82
82
  return n_embd * wkv_head_size;
@@ -86,6 +86,14 @@ uint32_t llama_hparams::n_embd_v_s() const {
86
86
  return ssm_d_state * ssm_d_inner;
87
87
  }
88
88
 
89
+ bool llama_hparams::is_recurrent(uint32_t il) const {
90
+ return recurrent_layer_arr[il];
91
+ }
92
+
93
+ uint32_t llama_hparams::n_pos_per_embd() const {
94
+ return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
95
+ }
96
+
89
97
  bool llama_hparams::is_swa(uint32_t il) const {
90
98
  if (il < n_layer) {
91
99
  return swa_layers[il];
@@ -115,6 +115,9 @@ struct llama_hparams {
115
115
  uint32_t ssm_d_state = 0;
116
116
  uint32_t ssm_dt_rank = 0;
117
117
 
118
+ // for hybrid state space models
119
+ std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
120
+
118
121
  bool ssm_dt_b_c_rms = false;
119
122
 
120
123
  float f_clamp_kqv = 0.0f;
@@ -140,6 +143,12 @@ struct llama_hparams {
140
143
  uint32_t n_attn_temp_floor_scale = 8192;
141
144
  float f_attn_temp_scale = 0.1;
142
145
 
146
+ // gemma3n altup
147
+ uint32_t n_altup = 4; // altup_num_inputs
148
+ uint32_t i_altup_act = 0; // altup_active_idx
149
+ uint32_t laurel_rank = 64;
150
+ uint32_t n_embd_altup = 256;
151
+
143
152
  // needed by encoder-decoder models (e.g. T5, FLAN-T5)
144
153
  // ref: https://github.com/ggerganov/llama.cpp/pull/8141
145
154
  llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
@@ -181,10 +190,15 @@ struct llama_hparams {
181
190
 
182
191
  // dimension of the rolling state embeddings
183
192
  // corresponds to Mamba's conv_states size or RWKV's token_shift states size
184
- uint32_t n_embd_k_s() const;
193
+ uint32_t n_embd_r() const;
185
194
 
186
195
  // dimension of the recurrent state embeddings
187
- uint32_t n_embd_v_s() const;
196
+ uint32_t n_embd_s() const;
197
+
198
+ // whether or not the given layer is recurrent (for hybrid models)
199
+ bool is_recurrent(uint32_t il) const;
200
+
201
+ uint32_t n_pos_per_embd() const;
188
202
 
189
203
  bool is_swa(uint32_t il) const;
190
204
  };
@@ -52,9 +52,9 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
52
52
  hparams.n_swa, hparams.swa_type);
53
53
  }
54
54
 
55
- void llama_kv_cache_unified_iswa::clear() {
56
- kv_base->clear();
57
- kv_swa ->clear();
55
+ void llama_kv_cache_unified_iswa::clear(bool data) {
56
+ kv_base->clear(data);
57
+ kv_swa ->clear(data);
58
58
  }
59
59
 
60
60
  bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
@@ -95,54 +95,83 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
95
95
  return kv_swa->seq_pos_max(seq_id);
96
96
  }
97
97
 
98
- llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
99
- GGML_UNUSED(embd_pooled);
98
+ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
99
+ GGML_UNUSED(embd_all);
100
100
 
101
- // TODO: if we fail with split_simple, we should attempt different splitting strategies
102
- // but to do that properly, we first have to refactor the batches to be more flexible
101
+ // first try simple split
102
+ do {
103
+ balloc.split_reset();
103
104
 
104
- auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
105
+ std::vector<llama_ubatch> ubatches;
106
+ while (true) {
107
+ auto ubatch = balloc.split_simple(n_ubatch);
105
108
 
106
- std::vector<llama_ubatch> ubatches;
109
+ if (ubatch.n_tokens == 0) {
110
+ break;
111
+ }
107
112
 
108
- while (sbatch.n_tokens > 0) {
109
- auto ubatch = sbatch.split_simple(n_ubatch);
113
+ ubatches.push_back(std::move(ubatch)); // NOLINT
114
+ }
110
115
 
111
- ubatches.push_back(ubatch);
112
- }
116
+ auto heads_base = kv_base->prepare(ubatches);
117
+ if (heads_base.empty()) {
118
+ break;
119
+ }
113
120
 
114
- auto heads_base = kv_base->prepare(ubatches);
115
- if (heads_base.empty()) {
116
- return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
117
- }
121
+ auto heads_swa = kv_swa->prepare(ubatches);
122
+ if (heads_swa.empty()) {
123
+ break;
124
+ }
118
125
 
119
- auto heads_swa = kv_swa->prepare(ubatches);
120
- if (heads_swa.empty()) {
121
- return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
122
- }
126
+ assert(heads_base.size() == heads_swa.size());
123
127
 
124
- assert(heads_base.size() == heads_swa.size());
128
+ return std::make_unique<llama_kv_cache_unified_iswa_context>(
129
+ this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
130
+ } while (false);
125
131
 
126
- return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
127
- this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
128
- }
132
+ // if it fails, try equal split
133
+ do {
134
+ balloc.split_reset();
129
135
 
130
- llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
131
- return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
132
- }
136
+ std::vector<llama_ubatch> ubatches;
137
+ while (true) {
138
+ auto ubatch = balloc.split_equal(n_ubatch);
133
139
 
134
- bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
135
- bool res = false;
140
+ if (ubatch.n_tokens == 0) {
141
+ break;
142
+ }
136
143
 
137
- res = res | kv_base->update(lctx);
138
- res = res | kv_swa ->update(lctx);
144
+ ubatches.push_back(std::move(ubatch)); // NOLINT
145
+ }
139
146
 
140
- return res;
147
+ auto heads_base = kv_base->prepare(ubatches);
148
+ if (heads_base.empty()) {
149
+ break;
150
+ }
151
+
152
+ auto heads_swa = kv_swa->prepare(ubatches);
153
+ if (heads_swa.empty()) {
154
+ break;
155
+ }
156
+
157
+ assert(heads_base.size() == heads_swa.size());
158
+
159
+ return std::make_unique<llama_kv_cache_unified_iswa_context>(
160
+ this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
161
+ } while (false);
162
+
163
+ // TODO: if we fail again, we should attempt different splitting strategies
164
+ // but to do that properly, we first have to refactor the batches to be more flexible
165
+
166
+ return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
141
167
  }
142
168
 
143
- void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
144
- kv_base->defrag_sched(thold);
145
- kv_swa ->defrag_sched(thold);
169
+ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
170
+ return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
171
+ }
172
+
173
+ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
174
+ return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
146
175
  }
147
176
 
148
177
  bool llama_kv_cache_unified_iswa::get_can_shift() const {
@@ -168,40 +197,46 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
168
197
  }
169
198
 
170
199
  //
171
- // llama_kv_cache_unified_iswa_state
200
+ // llama_kv_cache_unified_iswa_context
172
201
  //
173
202
 
174
- llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
203
+ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
204
+
205
+ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
206
+ llama_kv_cache_unified_iswa * kv) :
207
+ ctx_base(kv->get_base()->init_full()),
208
+ ctx_swa (kv->get_swa ()->init_full()),
209
+ status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
210
+ }
175
211
 
176
- llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
177
- llama_memory_status status,
178
- llama_kv_cache_unified_iswa * kv) : status(status) {
179
- state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
180
- state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
212
+ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
213
+ llama_kv_cache_unified_iswa * kv,
214
+ llama_context * lctx,
215
+ bool optimize) :
216
+ ctx_base(kv->get_base()->init_update(lctx, optimize)),
217
+ ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
218
+ status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
181
219
  }
182
220
 
183
- llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
184
- llama_memory_status status,
221
+ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
185
222
  llama_kv_cache_unified_iswa * kv,
186
- llama_sbatch sbatch,
187
223
  std::vector<uint32_t> heads_base,
188
224
  std::vector<uint32_t> heads_swa,
189
- std::vector<llama_ubatch> ubatches)
190
- : status(status),
191
- sbatch(std::move(sbatch)),
192
- ubatches(std::move(ubatches)) {
193
- // note: here we copy the ubatches. not sure if this is ideal
194
- state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches));
195
- state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
196
- }
225
+ std::vector<llama_ubatch> ubatches) :
226
+ ubatches(std::move(ubatches)),
227
+ // note: here we copy the ubatches. not sure if this is ideal
228
+ ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
229
+ ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
230
+ status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
231
+ }
197
232
 
198
- llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
233
+ llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
199
234
 
200
- bool llama_kv_cache_unified_iswa_state::next() {
235
+ bool llama_kv_cache_unified_iswa_context::next() {
201
236
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
202
237
 
203
- state_base->next();
204
- state_swa ->next();
238
+ ctx_base->next();
239
+ ctx_swa ->next();
205
240
 
206
241
  if (++i_next >= ubatches.size()) {
207
242
  return false;
@@ -210,40 +245,35 @@ bool llama_kv_cache_unified_iswa_state::next() {
210
245
  return true;
211
246
  }
212
247
 
213
- bool llama_kv_cache_unified_iswa_state::apply() {
248
+ bool llama_kv_cache_unified_iswa_context::apply() {
214
249
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
215
250
 
216
251
  bool res = true;
217
252
 
218
- res = res & state_base->apply();
219
- res = res & state_swa ->apply();
253
+ res = res & ctx_base->apply();
254
+ res = res & ctx_swa ->apply();
220
255
 
221
256
  return res;
222
257
  }
223
258
 
224
- std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
225
- assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
226
-
227
- return sbatch.out_ids;
228
- }
229
-
230
- llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
259
+ llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
231
260
  return status;
232
261
  }
233
262
 
234
- const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
263
+ const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
235
264
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
265
+
236
266
  return ubatches[i_next];
237
267
  }
238
268
 
239
- const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
269
+ const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
240
270
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
241
271
 
242
- return state_base.get();
272
+ return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
243
273
  }
244
274
 
245
- const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
275
+ const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
246
276
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
247
277
 
248
- return state_swa.get();
278
+ return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
249
279
  }