@novastera-oss/llamarn 0.2.7 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (319)
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  13. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  18. package/cpp/LlamaCppModel.cpp +56 -22
  19. package/cpp/build-info.cpp +2 -2
  20. package/cpp/llama.cpp/CMakeLists.txt +1 -2
  21. package/cpp/llama.cpp/README.md +4 -5
  22. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  23. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  24. package/cpp/llama.cpp/common/arg.cpp +24 -0
  25. package/cpp/llama.cpp/common/chat.cpp +37 -20
  26. package/cpp/llama.cpp/common/chat.h +2 -0
  27. package/cpp/llama.cpp/common/common.cpp +3 -0
  28. package/cpp/llama.cpp/common/common.h +5 -0
  29. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  30. package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
  31. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  32. package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
  33. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  34. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  35. package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
  36. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  100. package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
  101. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
  103. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
  104. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
  105. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  133. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  134. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  135. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  136. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  137. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
  138. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  141. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  142. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  143. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  144. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  168. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  169. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  170. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  172. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  173. package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
  177. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
  178. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
  179. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  180. package/cpp/llama.cpp/include/llama.h +8 -43
  181. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  182. package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
  183. package/cpp/llama.cpp/src/llama-arch.h +36 -1
  184. package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
  185. package/cpp/llama.cpp/src/llama-batch.h +105 -70
  186. package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
  187. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  188. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  189. package/cpp/llama.cpp/src/llama-context.h +13 -13
  190. package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
  191. package/cpp/llama.cpp/src/llama-graph.h +78 -79
  192. package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
  193. package/cpp/llama.cpp/src/llama-hparams.h +11 -0
  194. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
  195. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  196. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
  197. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
  198. package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
  199. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
  200. package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
  201. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
  202. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  203. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  204. package/cpp/llama.cpp/src/llama-memory.h +21 -22
  205. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  206. package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
  207. package/cpp/llama.cpp/src/llama-model.h +40 -0
  208. package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
  209. package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
  210. package/cpp/llama.cpp/src/llama-vocab.h +42 -0
  211. package/cpp/rn-utils.h +3 -0
  212. package/ios/include/chat.h +2 -0
  213. package/ios/include/common.h +5 -0
  214. package/ios/include/llama.h +8 -43
  215. package/ios/libs/llama.xcframework/Info.plist +19 -19
  216. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  217. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  218. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  219. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  220. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
  221. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  223. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  224. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  225. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  231. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  232. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  248. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  250. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
  254. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
  255. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  261. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  262. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  263. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  264. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  265. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  267. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  268. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
  269. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
  270. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
  274. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  275. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  276. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  277. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  278. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  279. package/package.json +1 -1
  280. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  315. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  316. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  317. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  318. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  319. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -24,8 +24,6 @@ public:
24
24
  // this callback is used to filter out layers that should not be included in the cache
25
25
  using layer_filter_cb = std::function<bool(int32_t il)>;
26
26
 
27
- using ubatch_heads = std::vector<uint32_t>;
28
-
29
27
  struct defrag_info {
30
28
  bool empty() const {
31
29
  return ids.empty();
@@ -37,6 +35,32 @@ public:
37
35
  std::vector<uint32_t> ids;
38
36
  };
39
37
 
38
+ // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
39
+ // KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
40
+ struct slot_info {
41
+ // data for ggml_set_rows
42
+ using idx_vec_t = std::vector<uint32_t>;
43
+
44
+ idx_vec_t idxs;
45
+
46
+ uint32_t head() const {
47
+ return idxs.at(0);
48
+ }
49
+
50
+ bool empty() const {
51
+ return idxs.empty();
52
+ }
53
+
54
+ void clear() {
55
+ idxs.clear();
56
+ }
57
+
58
+ // TODO: implement
59
+ //std::vector<idx_vec_t> seq_idxs;
60
+ };
61
+
62
+ using slot_info_vec_t = std::vector<slot_info>;
63
+
40
64
  llama_kv_cache_unified(
41
65
  const llama_model & model,
42
66
  layer_filter_cb && filter,
@@ -56,14 +80,14 @@ public:
56
80
  // llama_memory_i
57
81
  //
58
82
 
59
- llama_memory_state_ptr init_batch(
60
- const llama_batch & batch,
83
+ llama_memory_context_ptr init_batch(
84
+ llama_batch_allocr & balloc,
61
85
  uint32_t n_ubatch,
62
86
  bool embd_all) override;
63
87
 
64
- llama_memory_state_ptr init_full() override;
88
+ llama_memory_context_ptr init_full() override;
65
89
 
66
- llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
90
+ llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
67
91
 
68
92
  bool get_can_shift() const override;
69
93
 
@@ -102,30 +126,37 @@ public:
102
126
  ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
103
127
 
104
128
  // store k_cur and v_cur in the cache based on the provided head location
105
- ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
106
- ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
129
+ ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
130
+ ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
107
131
 
108
132
  //
109
133
  // preparation API
110
134
  //
111
135
 
112
- // find places for the provided ubatches in the cache, returns the head locations
136
+ // find places for the provided ubatches in the cache, returns the slot infos
113
137
  // return empty vector on failure
114
- ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
138
+ slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
115
139
 
116
140
  bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
117
141
 
118
- // return the cell position where we can insert the ubatch
119
- // return -1 on failure to find a contiguous slot of kv cells
120
- int32_t find_slot(const llama_ubatch & ubatch) const;
142
+ // find a slot of kv cells that can hold the ubatch
143
+ // if cont == true, then the slot must be continuous
144
+ // return empty slot_info on failure
145
+ slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;
121
146
 
122
- // emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
123
- void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);
147
+ // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
148
+ void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
124
149
 
125
150
  //
126
- // set_input API
151
+ // input API
127
152
  //
128
153
 
154
+ ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
155
+ ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
156
+
157
+ void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
158
+ void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
159
+
129
160
  void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
130
161
  void set_input_k_shift (ggml_tensor * dst) const;
131
162
  void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
@@ -157,8 +188,13 @@ private:
157
188
  // SWA
158
189
  const uint32_t n_swa = 0;
159
190
 
191
+ // env: LLAMA_KV_CACHE_DEBUG
160
192
  int debug = 0;
161
193
 
194
+ // env: LLAMA_SET_ROWS (temporary)
195
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14285
196
+ int supports_set_rows = false;
197
+
162
198
  const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
163
199
 
164
200
  std::vector<ggml_context_ptr> ctxs;
@@ -208,49 +244,46 @@ private:
208
244
  bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
209
245
  };
210
246
 
211
- class llama_kv_cache_unified_state : public llama_memory_state_i {
247
+ class llama_kv_cache_unified_context : public llama_memory_context_i {
212
248
  public:
213
249
  // some shorthands
214
- using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
215
- using defrag_info = llama_kv_cache_unified::defrag_info;
250
+ using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
251
+ using defrag_info = llama_kv_cache_unified::defrag_info;
216
252
 
217
253
  // used for errors
218
- llama_kv_cache_unified_state(llama_memory_status status);
254
+ llama_kv_cache_unified_context(llama_memory_status status);
219
255
 
220
- // used to create a full-cache state
221
- llama_kv_cache_unified_state(
256
+ // used to create a full-cache context
257
+ llama_kv_cache_unified_context(
222
258
  llama_kv_cache_unified * kv);
223
259
 
224
- // used to create an update state
225
- llama_kv_cache_unified_state(
260
+ // used to create an update context
261
+ llama_kv_cache_unified_context(
226
262
  llama_kv_cache_unified * kv,
227
263
  llama_context * lctx,
228
264
  bool do_shift,
229
265
  defrag_info dinfo);
230
266
 
231
- // used to create a decode state from a batch
232
- llama_kv_cache_unified_state(
267
+ // used to create a batch procesing context from a batch
268
+ llama_kv_cache_unified_context(
233
269
  llama_kv_cache_unified * kv,
234
- llama_sbatch sbatch,
235
- ubatch_heads heads,
270
+ slot_info_vec_t sinfos,
236
271
  std::vector<llama_ubatch> ubatches);
237
272
 
238
- virtual ~llama_kv_cache_unified_state();
273
+ virtual ~llama_kv_cache_unified_context();
239
274
 
240
275
  //
241
- // llama_memory_state_i
276
+ // llama_memory_context_i
242
277
  //
243
278
 
244
279
  bool next() override;
245
280
  bool apply() override;
246
281
 
247
- std::vector<int64_t> & out_ids() override;
248
-
249
282
  llama_memory_status get_status() const override;
250
283
  const llama_ubatch & get_ubatch() const override;
251
284
 
252
285
  //
253
- // llama_kv_cache_unified_state specific API
286
+ // llama_kv_cache_unified_context specific API
254
287
  //
255
288
 
256
289
  uint32_t get_n_kv() const;
@@ -260,11 +293,16 @@ public:
260
293
  ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
261
294
 
262
295
  // store k_cur and v_cur in the cache based on the provided head location
263
- ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
264
- ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
296
+ ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
297
+ ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
298
+
299
+ ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
300
+ ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
265
301
 
266
- void set_input_k_shift(ggml_tensor * dst) const;
302
+ void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
303
+ void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
267
304
 
305
+ void set_input_k_shift (ggml_tensor * dst) const;
268
306
  void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
269
307
  void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
270
308
 
@@ -275,7 +313,7 @@ private:
275
313
  llama_context * lctx;
276
314
 
277
315
  //
278
- // update state
316
+ // update context
279
317
  //
280
318
 
281
319
  bool do_shift = false;
@@ -283,15 +321,13 @@ private:
283
321
  defrag_info dinfo;
284
322
 
285
323
  //
286
- // batch processing state
324
+ // batch processing context
287
325
  //
288
326
 
289
- llama_sbatch sbatch;
327
+ // the index of the cur ubatch to process
328
+ size_t i_cur = 0;
290
329
 
291
- // the index of the next ubatch to process
292
- size_t i_next = 0;
293
-
294
- ubatch_heads heads;
330
+ slot_info_vec_t sinfos;
295
331
 
296
332
  std::vector<llama_ubatch> ubatches;
297
333
 
@@ -302,7 +338,4 @@ private:
302
338
  // a heuristic, to avoid attending the full cache if it is not yet utilized
303
339
  // as the cache gets filled, the benefit from this heuristic disappears
304
340
  int32_t n_kv;
305
-
306
- // the beginning of the current slot in which the ubatch will be inserted
307
- int32_t head;
308
341
  };
@@ -7,6 +7,7 @@
7
7
  #include <cassert>
8
8
  #include <vector>
9
9
  #include <set>
10
+ #include <map>
10
11
 
11
12
  // meta information about KV cells that can be part of multiple sequences at the same time
12
13
  // TODO: add unit tests
@@ -104,10 +105,30 @@ public:
104
105
  res.resize(n);
105
106
 
106
107
  for (uint32_t j = 0; j < n; ++j) {
107
- res.pos[j] = pos[i + j];
108
- res.seq[j] = seq[i + j];
108
+ const auto idx = i + j;
109
109
 
110
- assert(shift[i + j] == 0);
110
+ res.pos[j] = pos[idx];
111
+ res.seq[j] = seq[idx];
112
+
113
+ assert(shift[idx] == 0);
114
+ }
115
+
116
+ return res;
117
+ }
118
+
119
+ // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
120
+ llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
121
+ llama_kv_cells_unified res;
122
+
123
+ res.resize(idxs.size());
124
+
125
+ for (uint32_t j = 0; j < idxs.size(); ++j) {
126
+ const auto idx = idxs[j];
127
+
128
+ res.pos[j] = pos[idx];
129
+ res.seq[j] = seq[idx];
130
+
131
+ assert(shift[idx] == 0);
111
132
  }
112
133
 
113
134
  return res;
@@ -118,26 +139,58 @@ public:
118
139
  assert(i + other.pos.size() <= pos.size());
119
140
 
120
141
  for (uint32_t j = 0; j < other.pos.size(); ++j) {
121
- if (pos[i + j] == -1 && other.pos[j] != -1) {
142
+ const auto idx = i + j;
143
+
144
+ if (pos[idx] == -1 && other.pos[j] != -1) {
122
145
  used.insert(i + j);
123
146
  }
124
147
 
125
- if (pos[i + j] != -1 && other.pos[j] == -1) {
148
+ if (pos[idx] != -1 && other.pos[j] == -1) {
126
149
  used.erase(i + j);
127
150
  }
128
151
 
129
- if (pos[i + j] != -1) {
152
+ if (pos[idx] != -1) {
130
153
  seq_pos_rm(i + j);
131
154
  }
132
155
 
133
- pos[i + j] = other.pos[j];
134
- seq[i + j] = other.seq[j];
156
+ pos[idx] = other.pos[j];
157
+ seq[idx] = other.seq[j];
135
158
 
136
- if (pos[i + j] != -1) {
159
+ if (pos[idx] != -1) {
137
160
  seq_pos_add(i + j);
138
161
  }
139
162
 
140
- assert(shift[i + j] == 0);
163
+ assert(shift[idx] == 0);
164
+ }
165
+ }
166
+
167
+ // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
168
+ void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
169
+ assert(idxs.size() == other.pos.size());
170
+
171
+ for (uint32_t j = 0; j < other.pos.size(); ++j) {
172
+ const auto idx = idxs[j];
173
+
174
+ if (pos[idx] == -1 && other.pos[j] != -1) {
175
+ used.insert(idx);
176
+ }
177
+
178
+ if (pos[idx] != -1 && other.pos[j] == -1) {
179
+ used.erase(idx);
180
+ }
181
+
182
+ if (pos[idx] != -1) {
183
+ seq_pos_rm(idx);
184
+ }
185
+
186
+ pos[idx] = other.pos[j];
187
+ seq[idx] = other.seq[j];
188
+
189
+ if (pos[idx] != -1) {
190
+ seq_pos_add(idx);
191
+ }
192
+
193
+ assert(shift[idx] == 0);
141
194
  }
142
195
  }
143
196
 
@@ -164,7 +217,7 @@ public:
164
217
  assert(seq_id >= 0);
165
218
 
166
219
  seq[i].reset(seq_id);
167
- seq_pos[seq_id].erase(pos[i]);
220
+ seq_pos_dec(seq_id, pos[i]);
168
221
 
169
222
  if (seq[i].none()) {
170
223
  pos[i] = -1;
@@ -187,7 +240,7 @@ public:
187
240
  seq[i].reset();
188
241
 
189
242
  seq[i].set(seq_id);
190
- seq_pos[seq_id].insert(pos[i]);
243
+ seq_pos_inc(seq_id, pos[i]);
191
244
 
192
245
  return false;
193
246
  }
@@ -232,7 +285,7 @@ public:
232
285
  assert(!seq[i].test(seq_id));
233
286
 
234
287
  seq[i].set(seq_id);
235
- seq_pos[seq_id].insert(pos[i]);
288
+ seq_pos_inc(seq_id, pos[i]);
236
289
  }
237
290
 
238
291
  // return the sequence id of this cell
@@ -259,7 +312,9 @@ public:
259
312
  return -1;
260
313
  }
261
314
 
262
- return *seq_pos[seq_id].begin();
315
+ assert(seq_pos[seq_id].begin()->second > 0);
316
+
317
+ return seq_pos[seq_id].begin()->first;
263
318
  }
264
319
 
265
320
  // the maximum position of sequence seq_id currently present in any of the cells
@@ -272,7 +327,9 @@ public:
272
327
  return -1;
273
328
  }
274
329
 
275
- return *seq_pos[seq_id].rbegin();
330
+ assert(seq_pos[seq_id].rbegin()->second > 0);
331
+
332
+ return seq_pos[seq_id].rbegin()->first;
276
333
  }
277
334
 
278
335
  // note: call only if the cell is not empty
@@ -384,22 +441,41 @@ private:
384
441
  //
385
442
  std::vector<llama_pos> shift;
386
443
 
387
- using bits_t = std::bitset<LLAMA_MAX_SEQ>;
444
+ using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
388
445
 
389
446
  // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
390
- std::vector<bits_t> seq;
447
+ std::vector<seq_set_t> seq;
391
448
 
392
- // the set seq_pos[s] tells us which positions are currently present for sequence s
449
+ // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
450
+ // if the position p is not present, seq_pos[s][p] is not set
393
451
  // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
394
- std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
452
+ //
453
+ // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
454
+ // - during performing a cache reuse via (rm + add)
455
+ // - some vision models have input embeddings with repeating positions
456
+ //
457
+ std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
395
458
 
396
459
  // helper functions for updating `seq_pos`, once cell at a time:
397
460
 
461
+ void seq_pos_dec(llama_seq_id s, llama_pos p) {
462
+ auto it = seq_pos[s].find(p);
463
+ assert(it != seq_pos[s].end());
464
+
465
+ if (--it->second == 0) {
466
+ seq_pos[s].erase(it);
467
+ }
468
+ }
469
+
470
+ void seq_pos_inc(llama_seq_id s, llama_pos p) {
471
+ seq_pos[s][p]++;
472
+ }
473
+
398
474
  // remove cell i
399
475
  void seq_pos_rm(uint32_t i) {
400
476
  for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
401
477
  if (seq[i].test(s)) {
402
- seq_pos[s].erase(pos[i]);
478
+ seq_pos_dec(s, pos[i]);
403
479
  }
404
480
  }
405
481
  }
@@ -408,7 +484,7 @@ private:
408
484
  void seq_pos_add(uint32_t i) {
409
485
  for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
410
486
  if (seq[i].test(s)) {
411
- seq_pos[s].insert(pos[i]);
487
+ seq_pos_inc(s, pos[i]);
412
488
  }
413
489
  }
414
490
  }
@@ -32,7 +32,7 @@ llama_memory_hybrid::llama_memory_hybrid(
32
32
  mem_attn(new llama_kv_cache_unified(
33
33
  model,
34
34
  filter_attn == nullptr ?
35
- [&](int32_t il) { return !model.hparams.is_recurrent(il); }
35
+ [&](int32_t il) { return !hparams.is_recurrent(il); }
36
36
  : filter_attn,
37
37
  type_k,
38
38
  type_v,
@@ -47,7 +47,7 @@ llama_memory_hybrid::llama_memory_hybrid(
47
47
  mem_recr(new llama_memory_recurrent(
48
48
  model,
49
49
  filter_recr == nullptr ?
50
- [&](int32_t il) { return model.hparams.is_recurrent(il); }
50
+ [&](int32_t il) { return hparams.is_recurrent(il); }
51
51
  : filter_recr,
52
52
  type_r,
53
53
  type_s,
@@ -56,50 +56,62 @@ llama_memory_hybrid::llama_memory_hybrid(
56
56
  n_seq_max
57
57
  )) {}
58
58
 
59
- llama_memory_state_ptr llama_memory_hybrid::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
59
+ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
60
+ do {
61
+ balloc.split_reset();
60
62
 
61
- // since this includes a recurrent cache, we cannot use split_simple
62
- auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
63
+ // follow the recurrent pattern for creating the ubatch splits
64
+ std::vector<llama_ubatch> ubatches;
63
65
 
64
- // follow the recurrent pattern for creating the ubatch splits
65
- std::vector<llama_ubatch> ubatches;
66
- while (sbatch.n_tokens > 0) {
67
- llama_ubatch ubatch;
66
+ while (true) {
67
+ llama_ubatch ubatch;
68
68
 
69
- if (embd_pooled) {
70
- // Pooled embeddings cannot be split across ubatches (yet)
71
- ubatch = sbatch.split_seq(n_ubatch);
72
- } else {
73
- ubatch = sbatch.split_equal(n_ubatch);
69
+ if (embd_all) {
70
+ // if all tokens are output, split by sequence
71
+ ubatch = balloc.split_seq(n_ubatch);
72
+ } else {
73
+ ubatch = balloc.split_equal(n_ubatch, false);
74
+ }
75
+
76
+ if (ubatch.n_tokens == 0) {
77
+ break;
78
+ }
79
+
80
+ ubatches.push_back(std::move(ubatch)); // NOLINT
74
81
  }
75
82
 
76
- ubatches.push_back(ubatch);
77
- }
83
+ if (balloc.get_n_used() < balloc.get_n_tokens()) {
84
+ // failed to find a suitable split
85
+ break;
86
+ }
78
87
 
79
- // prepare the recurrent batches first
80
- if (!mem_recr->prepare(ubatches)) {
81
- // TODO: will the recurrent cache be in an undefined state at this point?
82
- LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
83
- return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
84
- }
88
+ // prepare the recurrent batches first
89
+ if (!mem_recr->prepare(ubatches)) {
90
+ // TODO: will the recurrent cache be in an undefined context at this point?
91
+ LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
92
+ return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
93
+ }
85
94
 
86
- // prepare the attention cache
87
- auto heads_attn = mem_attn->prepare(ubatches);
88
- if (heads_attn.empty()) {
89
- LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
90
- return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
91
- }
95
+ // prepare the attention cache
96
+ auto heads_attn = mem_attn->prepare(ubatches);
97
+ if (heads_attn.empty()) {
98
+ LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
99
+ return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
100
+ }
101
+
102
+ return std::make_unique<llama_memory_hybrid_context>(
103
+ this, std::move(heads_attn), std::move(ubatches));
104
+ } while(false);
92
105
 
93
- return std::make_unique<llama_memory_hybrid_state>(
94
- this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
106
+ return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
95
107
  }
96
108
 
97
- llama_memory_state_ptr llama_memory_hybrid::init_full() {
98
- return std::make_unique<llama_memory_hybrid_state>(this);
109
+ llama_memory_context_ptr llama_memory_hybrid::init_full() {
110
+ return std::make_unique<llama_memory_hybrid_context>(this);
99
111
  }
100
112
 
101
- llama_memory_state_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
102
- return std::make_unique<llama_memory_hybrid_state>(this, lctx, optimize);
113
+ llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
114
+ return std::make_unique<llama_memory_hybrid_context>(this, lctx, optimize);
103
115
  }
104
116
 
105
117
  bool llama_memory_hybrid::get_can_shift() const {
@@ -169,41 +181,39 @@ llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
169
181
  return mem_recr.get();
170
182
  }
171
183
 
172
- llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {}
184
+ llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status) : status(status) {}
173
185
 
174
- llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) :
175
- state_attn(mem->get_mem_attn()->init_full()),
176
- state_recr(mem->get_mem_recr()->init_full()),
177
- status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
186
+ llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem) :
187
+ ctx_attn(mem->get_mem_attn()->init_full()),
188
+ ctx_recr(mem->get_mem_recr()->init_full()),
189
+ status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
178
190
  }
179
191
 
180
- llama_memory_hybrid_state::llama_memory_hybrid_state(
192
+ llama_memory_hybrid_context::llama_memory_hybrid_context(
181
193
  llama_memory_hybrid * mem,
182
194
  llama_context * lctx,
183
195
  bool optimize) :
184
- state_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
185
- state_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
186
- status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
196
+ ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
197
+ ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
198
+ status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
187
199
  }
188
200
 
189
- llama_memory_hybrid_state::llama_memory_hybrid_state(
201
+ llama_memory_hybrid_context::llama_memory_hybrid_context(
190
202
  llama_memory_hybrid * mem,
191
- llama_sbatch sbatch,
192
- std::vector<uint32_t> heads_attn,
203
+ slot_info_vec_t sinfos_attn,
193
204
  std::vector<llama_ubatch> ubatches) :
194
- sbatch(std::move(sbatch)),
195
205
  ubatches(std::move(ubatches)),
196
206
  // note: here we copy the ubatches. not sure if this is ideal
197
- state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), {}, std::move(heads_attn), this->ubatches)),
198
- state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(), {}, this->ubatches)),
199
- status(LLAMA_MEMORY_STATUS_SUCCESS) {
207
+ ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
208
+ ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
209
+ status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
200
210
  }
201
211
 
202
- bool llama_memory_hybrid_state::next() {
212
+ bool llama_memory_hybrid_context::next() {
203
213
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
204
214
 
205
- state_attn->next();
206
- state_recr->next();
215
+ ctx_attn->next();
216
+ ctx_recr->next();
207
217
 
208
218
  if (++i_next >= ubatches.size()) {
209
219
  return false;
@@ -212,36 +222,30 @@ bool llama_memory_hybrid_state::next() {
212
222
  return true;
213
223
  }
214
224
 
215
- bool llama_memory_hybrid_state::apply() {
216
- assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
225
+ bool llama_memory_hybrid_context::apply() {
226
+ assert(!llama_memory_status_is_fail(status));
217
227
 
218
228
  bool res = true;
219
229
 
220
- res = res & state_attn->apply();
221
- res = res & state_recr->apply();
230
+ res = res & ctx_attn->apply();
231
+ res = res & ctx_recr->apply();
222
232
 
223
233
  return res;
224
234
  }
225
235
 
226
- std::vector<int64_t> & llama_memory_hybrid_state::out_ids() {
227
- assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
228
-
229
- return sbatch.out_ids;
230
- }
231
-
232
- llama_memory_status llama_memory_hybrid_state::get_status() const {
236
+ llama_memory_status llama_memory_hybrid_context::get_status() const {
233
237
  return status;
234
238
  }
235
239
 
236
- const llama_ubatch & llama_memory_hybrid_state::get_ubatch() const {
240
+ const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
237
241
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
238
242
  return ubatches[i_next];
239
243
  }
240
244
 
241
- const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn() const {
242
- return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
245
+ const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
246
+ return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
243
247
  }
244
248
 
245
- const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const {
246
- return static_cast<const llama_memory_recurrent_state *>(state_recr.get());
249
+ const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
250
+ return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
247
251
  }