@novastera-oss/llamarn 0.2.6 → 0.2.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (253) hide show
  1. package/android/src/main/cpp/include/llama.h +141 -38
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +58 -24
  11. package/cpp/LlamaCppModel.h +3 -3
  12. package/cpp/PureCppImpl.cpp +1 -1
  13. package/cpp/PureCppImpl.h +2 -2
  14. package/cpp/build-info.cpp +2 -2
  15. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  16. package/cpp/llama.cpp/Makefile +2 -2
  17. package/cpp/llama.cpp/README.md +32 -13
  18. package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
  19. package/cpp/llama.cpp/common/arg.cpp +37 -6
  20. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  21. package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
  22. package/cpp/llama.cpp/common/chat-parser.h +2 -0
  23. package/cpp/llama.cpp/common/chat.cpp +12 -9
  24. package/cpp/llama.cpp/common/chat.h +1 -1
  25. package/cpp/llama.cpp/common/common.cpp +53 -40
  26. package/cpp/llama.cpp/common/common.h +6 -2
  27. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  28. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  29. package/cpp/llama.cpp/convert_hf_to_gguf.py +215 -76
  30. package/cpp/llama.cpp/ggml/CMakeLists.txt +48 -2
  31. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  32. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  33. package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
  34. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +64 -13
  35. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  36. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  38. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +124 -26
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +4 -3
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +93 -104
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +194 -69
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1158 -0
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1571 -0
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +213 -37
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +59 -37
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +90 -39
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  88. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
  90. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  91. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +260 -49
  93. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +497 -282
  94. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  95. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1078 -468
  97. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  99. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  105. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
  107. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  108. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +20 -48
  109. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  110. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  111. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +117 -165
  112. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +192 -53
  113. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  114. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  115. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  116. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
  117. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  118. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +8 -105
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +209 -92
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  133. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +36 -28
  134. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +487 -247
  135. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  136. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  137. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  138. package/cpp/llama.cpp/ggml/src/ggml.c +69 -19
  139. package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
  140. package/cpp/llama.cpp/gguf-py/gguf/constants.py +133 -0
  141. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +25 -1
  142. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +78 -3
  143. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
  144. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  145. package/cpp/llama.cpp/include/llama.h +141 -38
  146. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  147. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  148. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  149. package/cpp/llama.cpp/src/llama-arch.cpp +150 -3
  150. package/cpp/llama.cpp/src/llama-arch.h +25 -1
  151. package/cpp/llama.cpp/src/llama-batch.cpp +736 -274
  152. package/cpp/llama.cpp/src/llama-batch.h +110 -57
  153. package/cpp/llama.cpp/src/llama-chat.cpp +30 -8
  154. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  155. package/cpp/llama.cpp/src/llama-context.cpp +360 -266
  156. package/cpp/llama.cpp/src/llama-context.h +27 -23
  157. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  158. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  159. package/cpp/llama.cpp/src/llama-graph.cpp +411 -344
  160. package/cpp/llama.cpp/src/llama-graph.h +126 -58
  161. package/cpp/llama.cpp/src/llama-hparams.cpp +10 -2
  162. package/cpp/llama.cpp/src/llama-hparams.h +16 -2
  163. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +103 -73
  164. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +34 -42
  165. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +345 -221
  166. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +75 -50
  167. package/cpp/llama.cpp/src/llama-kv-cells.h +51 -22
  168. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +246 -0
  169. package/cpp/llama.cpp/src/llama-memory-hybrid.h +138 -0
  170. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +302 -317
  171. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +60 -68
  172. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  173. package/cpp/llama.cpp/src/llama-memory.h +73 -36
  174. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  175. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  176. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  177. package/cpp/llama.cpp/src/llama-model.cpp +1630 -511
  178. package/cpp/llama.cpp/src/llama-model.h +26 -0
  179. package/cpp/llama.cpp/src/llama-quant.cpp +89 -6
  180. package/cpp/llama.cpp/src/llama-vocab.cpp +58 -26
  181. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  182. package/cpp/llama.cpp/src/llama.cpp +11 -7
  183. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  184. package/cpp/rn-completion.cpp +2 -2
  185. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  186. package/cpp/{rn-utils.hpp → rn-utils.h} +3 -0
  187. package/ios/include/chat.h +1 -1
  188. package/ios/include/common.h +6 -2
  189. package/ios/include/llama.h +141 -38
  190. package/ios/libs/llama.xcframework/Info.plist +15 -15
  191. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  192. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
  193. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  194. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
  195. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +141 -38
  196. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  197. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  198. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  199. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
  200. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  201. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  202. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  203. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  204. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  205. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  206. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3624
  207. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  208. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
  209. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +141 -38
  210. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  211. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
  212. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +141 -38
  213. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  214. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  215. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
  216. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +141 -38
  217. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  218. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  219. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  220. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
  221. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  222. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
  223. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +141 -38
  224. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  225. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  226. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  227. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
  228. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  229. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  230. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  231. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  232. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  233. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4725
  234. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  235. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
  236. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +141 -38
  237. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  238. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  239. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4746
  240. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3652
  241. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  242. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  243. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  244. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  245. package/package.json +1 -2
  246. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  247. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  248. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  249. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
  250. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
  251. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  252. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  253. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
@@ -2,8 +2,8 @@
2
2
 
3
3
  #include "llama-batch.h"
4
4
  #include "llama-graph.h"
5
- #include "llama-kv-cache.h"
6
5
  #include "llama-kv-cells.h"
6
+ #include "llama-memory.h"
7
7
 
8
8
  #include <unordered_map>
9
9
  #include <vector>
@@ -17,13 +17,26 @@ struct llama_context;
17
17
  // llama_kv_cache_unified
18
18
  //
19
19
 
20
- class llama_kv_cache_unified : public llama_kv_cache {
20
+ class llama_kv_cache_unified : public llama_memory_i {
21
21
  public:
22
22
  static uint32_t get_padding(const llama_cparams & cparams);
23
23
 
24
24
  // this callback is used to filter out layers that should not be included in the cache
25
25
  using layer_filter_cb = std::function<bool(int32_t il)>;
26
26
 
27
+ using ubatch_heads = std::vector<uint32_t>;
28
+
29
+ struct defrag_info {
30
+ bool empty() const {
31
+ return ids.empty();
32
+ }
33
+
34
+ // contains information about which cell moves where:
35
+ // - cell i moves to ids[i]
36
+ // - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
37
+ std::vector<uint32_t> ids;
38
+ };
39
+
27
40
  llama_kv_cache_unified(
28
41
  const llama_model & model,
29
42
  layer_filter_cb && filter,
@@ -43,7 +56,18 @@ public:
43
56
  // llama_memory_i
44
57
  //
45
58
 
46
- void clear() override;
59
+ llama_memory_context_ptr init_batch(
60
+ llama_batch_allocr & balloc,
61
+ uint32_t n_ubatch,
62
+ bool embd_all) override;
63
+
64
+ llama_memory_context_ptr init_full() override;
65
+
66
+ llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
67
+
68
+ bool get_can_shift() const override;
69
+
70
+ void clear(bool data) override;
47
71
 
48
72
  bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
49
73
  void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
@@ -54,24 +78,6 @@ public:
54
78
  llama_pos seq_pos_min(llama_seq_id seq_id) const override;
55
79
  llama_pos seq_pos_max(llama_seq_id seq_id) const override;
56
80
 
57
- //
58
- // llama_kv_cache
59
- //
60
-
61
- llama_memory_state_ptr init_batch(
62
- const llama_batch & batch,
63
- uint32_t n_ubatch,
64
- bool embd_pooled,
65
- bool logits_all) override;
66
-
67
- llama_memory_state_ptr init_full() override;
68
-
69
- bool update(llama_context & lctx) override;
70
-
71
- void defrag_sched(float thold) override;
72
-
73
- bool get_can_shift() const override;
74
-
75
81
  // state write/load
76
82
 
77
83
  void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
@@ -83,6 +89,8 @@ public:
83
89
 
84
90
  uint32_t get_size() const;
85
91
 
92
+ bool get_has_shift() const;
93
+
86
94
  //
87
95
  // graph_build API
88
96
  //
@@ -103,7 +111,9 @@ public:
103
111
 
104
112
  // find places for the provided ubatches in the cache, returns the head locations
105
113
  // return empty vector on failure
106
- std::vector<uint32_t> prepare(const std::vector<llama_ubatch> & ubatches);
114
+ ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
115
+
116
+ bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
107
117
 
108
118
  // return the cell position where we can insert the ubatch
109
119
  // return -1 on failure to find a contiguous slot of kv cells
@@ -133,8 +143,7 @@ private:
133
143
  ggml_tensor * v;
134
144
  };
135
145
 
136
- bool do_defrag = false;
137
- bool v_trans = true; // the value tensor is transposed
146
+ bool v_trans = true; // the value tensor is transposed
138
147
 
139
148
  // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
140
149
  // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
@@ -148,6 +157,8 @@ private:
148
157
  // SWA
149
158
  const uint32_t n_swa = 0;
150
159
 
160
+ int debug = 0;
161
+
151
162
  const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
152
163
 
153
164
  std::vector<ggml_context_ptr> ctxs;
@@ -160,13 +171,8 @@ private:
160
171
  // model layer id -> KV cache layer id
161
172
  std::unordered_map<int32_t, int32_t> map_layer_ids;
162
173
 
163
- // defrag
164
- struct {
165
- std::vector<uint32_t> ids;
166
- } defrag_info;
167
-
168
- // return true if cells have been moved
169
- bool defrag_prepare(int32_t n_max_nodes);
174
+ // return non-empty vector if cells have been moved
175
+ defrag_info defrag_prepare(int32_t n_max_nodes) const;
170
176
 
171
177
  size_t total_size() const;
172
178
 
@@ -192,7 +198,8 @@ private:
192
198
  llm_graph_result_ptr build_graph_defrag(
193
199
  const llama_cparams & cparams,
194
200
  ggml_context * ctx,
195
- ggml_cgraph * gf) const;
201
+ ggml_cgraph * gf,
202
+ const defrag_info & dinfo) const;
196
203
 
197
204
  void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
198
205
  void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
@@ -201,40 +208,46 @@ private:
201
208
  bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
202
209
  };
203
210
 
204
- class llama_kv_cache_unified_state : public llama_memory_state_i {
211
+ class llama_kv_cache_unified_context : public llama_memory_context_i {
205
212
  public:
213
+ // some shorthands
214
+ using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
215
+ using defrag_info = llama_kv_cache_unified::defrag_info;
216
+
206
217
  // used for errors
207
- llama_kv_cache_unified_state(llama_memory_status status);
218
+ llama_kv_cache_unified_context(llama_memory_status status);
208
219
 
209
- // used to create a full-cache state
210
- llama_kv_cache_unified_state(
211
- llama_memory_status status,
220
+ // used to create a full-cache context
221
+ llama_kv_cache_unified_context(
212
222
  llama_kv_cache_unified * kv);
213
223
 
214
- // used to create a state from a batch
215
- llama_kv_cache_unified_state(
216
- llama_memory_status status,
224
+ // used to create an update context
225
+ llama_kv_cache_unified_context(
226
+ llama_kv_cache_unified * kv,
227
+ llama_context * lctx,
228
+ bool do_shift,
229
+ defrag_info dinfo);
230
+
231
+ // used to create a batch procesing context from a batch
232
+ llama_kv_cache_unified_context(
217
233
  llama_kv_cache_unified * kv,
218
- llama_sbatch sbatch,
219
- std::vector<uint32_t> heads,
234
+ ubatch_heads heads,
220
235
  std::vector<llama_ubatch> ubatches);
221
236
 
222
- virtual ~llama_kv_cache_unified_state();
237
+ virtual ~llama_kv_cache_unified_context();
223
238
 
224
239
  //
225
- // llama_memory_state_i
240
+ // llama_memory_context_i
226
241
  //
227
242
 
228
243
  bool next() override;
229
244
  bool apply() override;
230
245
 
231
- std::vector<int64_t> & out_ids() override;
232
-
233
246
  llama_memory_status get_status() const override;
234
247
  const llama_ubatch & get_ubatch() const override;
235
248
 
236
249
  //
237
- // llama_kv_cache_unified_state specific API
250
+ // llama_kv_cache_unified_context specific API
238
251
  //
239
252
 
240
253
  uint32_t get_n_kv() const;
@@ -253,16 +266,28 @@ public:
253
266
  void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
254
267
 
255
268
  private:
256
- const llama_memory_status status;
269
+ llama_memory_status status;
257
270
 
258
271
  llama_kv_cache_unified * kv;
272
+ llama_context * lctx;
273
+
274
+ //
275
+ // update context
276
+ //
277
+
278
+ bool do_shift = false;
259
279
 
260
- llama_sbatch sbatch;
280
+ defrag_info dinfo;
281
+
282
+ //
283
+ // batch processing context
284
+ //
261
285
 
262
286
  // the index of the next ubatch to process
263
287
  size_t i_next = 0;
264
288
 
265
- std::vector<uint32_t> heads;
289
+ ubatch_heads heads;
290
+
266
291
  std::vector<llama_ubatch> ubatches;
267
292
 
268
293
  //
@@ -7,6 +7,7 @@
7
7
  #include <cassert>
8
8
  #include <vector>
9
9
  #include <set>
10
+ #include <map>
10
11
 
11
12
  // meta information about KV cells that can be part of multiple sequences at the same time
12
13
  // TODO: add unit tests
@@ -23,7 +24,7 @@ public:
23
24
 
24
25
  used.clear();
25
26
 
26
- for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
27
+ for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
27
28
  seq_pos[s].clear();
28
29
  }
29
30
  }
@@ -80,6 +81,9 @@ public:
80
81
  assert(isrc < pos.size());
81
82
  assert(idst < pos.size());
82
83
 
84
+ assert(pos[idst] == -1);
85
+ assert(pos[isrc] != -1);
86
+
83
87
  pos [idst] = pos [isrc];
84
88
  shift[idst] = shift[isrc];
85
89
  seq [idst] = seq [isrc];
@@ -144,9 +148,10 @@ public:
144
148
  assert(pos[i] != -1);
145
149
 
146
150
  seq_pos_rm(i);
151
+ seq[i].reset();
147
152
 
148
153
  pos[i] = -1;
149
- seq[i].reset();
154
+ shift[i] = 0;
150
155
 
151
156
  used.erase(i);
152
157
  }
@@ -160,10 +165,11 @@ public:
160
165
  assert(seq_id >= 0);
161
166
 
162
167
  seq[i].reset(seq_id);
163
- seq_pos[seq_id].erase(pos[i]);
168
+ seq_pos_dec(seq_id, pos[i]);
164
169
 
165
170
  if (seq[i].none()) {
166
171
  pos[i] = -1;
172
+ shift[i] = 0;
167
173
 
168
174
  used.erase(i);
169
175
 
@@ -182,7 +188,7 @@ public:
182
188
  seq[i].reset();
183
189
 
184
190
  seq[i].set(seq_id);
185
- seq_pos[seq_id].insert(pos[i]);
191
+ seq_pos_inc(seq_id, pos[i]);
186
192
 
187
193
  return false;
188
194
  }
@@ -192,6 +198,7 @@ public:
192
198
  seq[i].reset();
193
199
 
194
200
  pos[i] = -1;
201
+ shift[i] = 0;
195
202
 
196
203
  used.erase(i);
197
204
 
@@ -226,7 +233,7 @@ public:
226
233
  assert(!seq[i].test(seq_id));
227
234
 
228
235
  seq[i].set(seq_id);
229
- seq_pos[seq_id].insert(pos[i]);
236
+ seq_pos_inc(seq_id, pos[i]);
230
237
  }
231
238
 
232
239
  // return the sequence id of this cell
@@ -234,7 +241,7 @@ public:
234
241
  llama_seq_id seq_get(uint32_t i) const {
235
242
  assert(seq[i].count() == 1);
236
243
 
237
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
244
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
238
245
  if (seq[i].test(s)) {
239
246
  return s;
240
247
  }
@@ -247,26 +254,30 @@ public:
247
254
  // return -1 if the sequence is not present
248
255
  llama_pos seq_pos_min(llama_seq_id seq_id) const {
249
256
  assert(seq_id >= 0);
250
- assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
257
+ assert(seq_id < LLAMA_MAX_SEQ);
251
258
 
252
259
  if (seq_pos[seq_id].empty()) {
253
260
  return -1;
254
261
  }
255
262
 
256
- return *seq_pos[seq_id].begin();
263
+ assert(seq_pos[seq_id].begin()->second > 0);
264
+
265
+ return seq_pos[seq_id].begin()->first;
257
266
  }
258
267
 
259
268
  // the maximum position of sequence seq_id currently present in any of the cells
260
269
  // return -1 if the sequence is not present
261
270
  llama_pos seq_pos_max(llama_seq_id seq_id) const {
262
271
  assert(seq_id >= 0);
263
- assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
272
+ assert(seq_id < LLAMA_MAX_SEQ);
264
273
 
265
274
  if (seq_pos[seq_id].empty()) {
266
275
  return -1;
267
276
  }
268
277
 
269
- return *seq_pos[seq_id].rbegin();
278
+ assert(seq_pos[seq_id].rbegin()->second > 0);
279
+
280
+ return seq_pos[seq_id].rbegin()->first;
270
281
  }
271
282
 
272
283
  // note: call only if the cell is not empty
@@ -317,21 +328,20 @@ public:
317
328
  pos[i] += d;
318
329
  shift[i] += d;
319
330
 
320
- seq_pos_add(i);
321
-
322
331
  has_shift = true;
323
332
 
324
333
  if (pos[i] < 0) {
325
- seq_pos_rm(i);
326
-
327
334
  seq[i].reset();
328
335
  pos[i] = -1;
336
+ shift[i] = 0;
329
337
 
330
338
  used.erase(i);
331
339
 
332
340
  return true;
333
341
  }
334
342
 
343
+ seq_pos_add(i);
344
+
335
345
  return false;
336
346
  }
337
347
 
@@ -379,31 +389,50 @@ private:
379
389
  //
380
390
  std::vector<llama_pos> shift;
381
391
 
382
- using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
392
+ using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
383
393
 
384
394
  // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
385
- std::vector<bits_t> seq;
395
+ std::vector<seq_set_t> seq;
386
396
 
387
- // the set seq_pos[s] tells us which positions are currently present for sequence s
397
+ // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
398
+ // if the position p is not present, seq_pos[s][p] is not set
388
399
  // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
389
- std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
400
+ //
401
+ // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
402
+ // - during performing a cache reuse via (rm + add)
403
+ // - some vision models have input embeddings with repeating positions
404
+ //
405
+ std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
390
406
 
391
407
  // helper functions for updating `seq_pos`, once cell at a time:
392
408
 
409
+ void seq_pos_dec(llama_seq_id s, llama_pos p) {
410
+ auto it = seq_pos[s].find(p);
411
+ assert(it != seq_pos[s].end());
412
+
413
+ if (--it->second == 0) {
414
+ seq_pos[s].erase(it);
415
+ }
416
+ }
417
+
418
+ void seq_pos_inc(llama_seq_id s, llama_pos p) {
419
+ seq_pos[s][p]++;
420
+ }
421
+
393
422
  // remove cell i
394
423
  void seq_pos_rm(uint32_t i) {
395
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
424
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
396
425
  if (seq[i].test(s)) {
397
- seq_pos[s].erase(pos[i]);
426
+ seq_pos_dec(s, pos[i]);
398
427
  }
399
428
  }
400
429
  }
401
430
 
402
431
  // add cell i
403
432
  void seq_pos_add(uint32_t i) {
404
- for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
433
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
405
434
  if (seq[i].test(s)) {
406
- seq_pos[s].insert(pos[i]);
435
+ seq_pos_inc(s, pos[i]);
407
436
  }
408
437
  }
409
438
  }
@@ -0,0 +1,246 @@
1
+ #include "llama-memory-hybrid.h"
2
+
3
+ #include "llama-impl.h"
4
+ #include "llama-model.h"
5
+ #include "llama-context.h"
6
+
7
+ //
8
+ // llama_memory_hybrid
9
+ //
10
+
11
+ llama_memory_hybrid::llama_memory_hybrid(
12
+ const llama_model & model,
13
+ /* attn */
14
+ ggml_type type_k,
15
+ ggml_type type_v,
16
+ bool v_trans,
17
+ uint32_t kv_size,
18
+ uint32_t n_pad,
19
+ uint32_t n_swa,
20
+ llama_swa_type swa_type,
21
+ /* recurrent */
22
+ ggml_type type_r,
23
+ ggml_type type_s,
24
+ uint32_t rs_size,
25
+ /* common */
26
+ uint32_t n_seq_max,
27
+ bool offload,
28
+ /* layer filters */
29
+ layer_filter_cb && filter_attn,
30
+ layer_filter_cb && filter_recr) :
31
+ hparams(model.hparams),
32
+ mem_attn(new llama_kv_cache_unified(
33
+ model,
34
+ filter_attn == nullptr ?
35
+ [&](int32_t il) { return !hparams.is_recurrent(il); }
36
+ : filter_attn,
37
+ type_k,
38
+ type_v,
39
+ v_trans,
40
+ offload,
41
+ kv_size,
42
+ n_seq_max,
43
+ n_pad,
44
+ n_swa,
45
+ swa_type
46
+ )),
47
+ mem_recr(new llama_memory_recurrent(
48
+ model,
49
+ filter_recr == nullptr ?
50
+ [&](int32_t il) { return hparams.is_recurrent(il); }
51
+ : filter_recr,
52
+ type_r,
53
+ type_s,
54
+ offload,
55
+ rs_size,
56
+ n_seq_max
57
+ )) {}
58
+
59
+ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
60
+ do {
61
+ balloc.split_reset();
62
+
63
+ // follow the recurrent pattern for creating the ubatch splits
64
+ std::vector<llama_ubatch> ubatches;
65
+
66
+ while (true) {
67
+ llama_ubatch ubatch;
68
+
69
+ if (embd_all) {
70
+ // if all tokens are output, split by sequence
71
+ ubatch = balloc.split_seq(n_ubatch);
72
+ } else {
73
+ ubatch = balloc.split_equal(n_ubatch);
74
+ }
75
+
76
+ if (ubatch.n_tokens == 0) {
77
+ break;
78
+ }
79
+
80
+ ubatches.push_back(std::move(ubatch)); // NOLINT
81
+ }
82
+
83
+ // prepare the recurrent batches first
84
+ if (!mem_recr->prepare(ubatches)) {
85
+ // TODO: will the recurrent cache be in an undefined context at this point?
86
+ LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
87
+ return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
88
+ }
89
+
90
+ // prepare the attention cache
91
+ auto heads_attn = mem_attn->prepare(ubatches);
92
+ if (heads_attn.empty()) {
93
+ LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
94
+ return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
95
+ }
96
+
97
+ return std::make_unique<llama_memory_hybrid_context>(
98
+ this, std::move(heads_attn), std::move(ubatches));
99
+ } while(false);
100
+
101
+ return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
102
+ }
103
+
104
+ llama_memory_context_ptr llama_memory_hybrid::init_full() {
105
+ return std::make_unique<llama_memory_hybrid_context>(this);
106
+ }
107
+
108
+ llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
109
+ return std::make_unique<llama_memory_hybrid_context>(this, lctx, optimize);
110
+ }
111
+
112
+ bool llama_memory_hybrid::get_can_shift() const {
113
+ // Shifting is trivially supported for recurrent
114
+ return mem_attn->get_can_shift();
115
+ }
116
+
117
+ void llama_memory_hybrid::clear(bool data) {
118
+ mem_attn->clear(data);
119
+ mem_recr->clear(data);
120
+ }
121
+
122
+ bool llama_memory_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
123
+ // Try removing from the recurrent cache first since it may fail. If it does
124
+ // fail, the cache will not have been mutated.
125
+ if (!mem_recr->seq_rm(seq_id, p0, p1)) {
126
+ return false;
127
+ }
128
+ return mem_attn->seq_rm(seq_id, p0, p1);
129
+ }
130
+
131
+ void llama_memory_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
132
+ mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
133
+ mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
134
+ }
135
+
136
+ void llama_memory_hybrid::seq_keep(llama_seq_id seq_id) {
137
+ mem_attn->seq_keep(seq_id);
138
+ mem_recr->seq_keep(seq_id);
139
+ }
140
+
141
+ void llama_memory_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
142
+ mem_attn->seq_add(seq_id, p0, p1, shift);
143
+ mem_recr->seq_add(seq_id, p0, p1, shift);
144
+ }
145
+
146
+ void llama_memory_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
147
+ mem_attn->seq_div(seq_id, p0, p1, d);
148
+ mem_recr->seq_div(seq_id, p0, p1, d);
149
+ }
150
+
151
+ llama_pos llama_memory_hybrid::seq_pos_min(llama_seq_id seq_id) const {
152
+ // the min of the total cache is the max of the two caches' min values
153
+ return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id));
154
+ }
155
+
156
+ llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
157
+ // the max of the total cache is the min of the two caches' max values
158
+ return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
159
+ }
160
+
161
+ void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
162
+ mem_attn->state_write(io, seq_id);
163
+ mem_recr->state_write(io, seq_id);
164
+ }
165
+
166
+ void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
167
+ mem_attn->state_read(io, seq_id);
168
+ mem_recr->state_read(io, seq_id);
169
+ }
170
+
171
+ llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
172
+ return mem_attn.get();
173
+ }
174
+
175
+ llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
176
+ return mem_recr.get();
177
+ }
178
+
179
+ llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status) : status(status) {}
180
+
181
+ llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem) :
182
+ ctx_attn(mem->get_mem_attn()->init_full()),
183
+ ctx_recr(mem->get_mem_recr()->init_full()),
184
+ status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
185
+ }
186
+
187
+ llama_memory_hybrid_context::llama_memory_hybrid_context(
188
+ llama_memory_hybrid * mem,
189
+ llama_context * lctx,
190
+ bool optimize) :
191
+ ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
192
+ ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
193
+ status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
194
+ }
195
+
196
+ llama_memory_hybrid_context::llama_memory_hybrid_context(
197
+ llama_memory_hybrid * mem,
198
+ std::vector<uint32_t> heads_attn,
199
+ std::vector<llama_ubatch> ubatches) :
200
+ ubatches(std::move(ubatches)),
201
+ // note: here we copy the ubatches. not sure if this is ideal
202
+ ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
203
+ ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
204
+ status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
205
+ }
206
+
207
+ bool llama_memory_hybrid_context::next() {
208
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
209
+
210
+ ctx_attn->next();
211
+ ctx_recr->next();
212
+
213
+ if (++i_next >= ubatches.size()) {
214
+ return false;
215
+ }
216
+
217
+ return true;
218
+ }
219
+
220
+ bool llama_memory_hybrid_context::apply() {
221
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
222
+
223
+ bool res = true;
224
+
225
+ res = res & ctx_attn->apply();
226
+ res = res & ctx_recr->apply();
227
+
228
+ return res;
229
+ }
230
+
231
+ llama_memory_status llama_memory_hybrid_context::get_status() const {
232
+ return status;
233
+ }
234
+
235
+ const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
236
+ assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
237
+ return ubatches[i_next];
238
+ }
239
+
240
+ const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
241
+ return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
242
+ }
243
+
244
+ const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
245
+ return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
246
+ }