@novastera-oss/llamarn 0.2.7 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (319)
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  13. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  18. package/cpp/LlamaCppModel.cpp +56 -22
  19. package/cpp/build-info.cpp +2 -2
  20. package/cpp/llama.cpp/CMakeLists.txt +1 -2
  21. package/cpp/llama.cpp/README.md +4 -5
  22. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  23. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  24. package/cpp/llama.cpp/common/arg.cpp +24 -0
  25. package/cpp/llama.cpp/common/chat.cpp +37 -20
  26. package/cpp/llama.cpp/common/chat.h +2 -0
  27. package/cpp/llama.cpp/common/common.cpp +3 -0
  28. package/cpp/llama.cpp/common/common.h +5 -0
  29. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  30. package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
  31. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  32. package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
  33. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  34. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  35. package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
  36. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  100. package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
  101. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
  103. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
  104. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
  105. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  133. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  134. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  135. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  136. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  137. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
  138. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  141. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  142. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  143. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  144. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  168. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  169. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  170. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  172. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  173. package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
  177. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
  178. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
  179. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  180. package/cpp/llama.cpp/include/llama.h +8 -43
  181. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  182. package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
  183. package/cpp/llama.cpp/src/llama-arch.h +36 -1
  184. package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
  185. package/cpp/llama.cpp/src/llama-batch.h +105 -70
  186. package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
  187. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  188. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  189. package/cpp/llama.cpp/src/llama-context.h +13 -13
  190. package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
  191. package/cpp/llama.cpp/src/llama-graph.h +78 -79
  192. package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
  193. package/cpp/llama.cpp/src/llama-hparams.h +11 -0
  194. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
  195. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  196. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
  197. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
  198. package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
  199. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
  200. package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
  201. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
  202. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  203. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  204. package/cpp/llama.cpp/src/llama-memory.h +21 -22
  205. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  206. package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
  207. package/cpp/llama.cpp/src/llama-model.h +40 -0
  208. package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
  209. package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
  210. package/cpp/llama.cpp/src/llama-vocab.h +42 -0
  211. package/cpp/rn-utils.h +3 -0
  212. package/ios/include/chat.h +2 -0
  213. package/ios/include/common.h +5 -0
  214. package/ios/include/llama.h +8 -43
  215. package/ios/libs/llama.xcframework/Info.plist +19 -19
  216. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  217. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  218. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  219. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  220. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
  221. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  223. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  224. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  225. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  231. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  232. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  248. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  250. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
  254. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
  255. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  261. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  262. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  263. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  264. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  265. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  267. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  268. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
  269. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
  270. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
  274. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  275. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  276. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  277. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  278. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  279. package/package.json +1 -1
  280. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  315. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  316. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  317. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  318. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  319. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -33,13 +33,19 @@ llama_kv_cache_unified::llama_kv_cache_unified(
33
33
 
34
34
  GGML_ASSERT(kv_size % n_pad == 0);
35
35
 
36
+ // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
37
+ auto n_layer_cache = hparams.n_layer;
38
+ if (model.arch == LLM_ARCH_GEMMA3N) {
39
+ n_layer_cache = 20;
40
+ }
41
+
36
42
  // create a context for each buffer type
37
43
  std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
38
44
  auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
39
45
  auto it = ctx_map.find(buft);
40
46
  if (it == ctx_map.end()) {
41
47
  ggml_init_params params = {
42
- /*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()),
48
+ /*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
43
49
  /*.mem_buffer =*/ NULL,
44
50
  /*.no_alloc =*/ true,
45
51
  };
@@ -62,7 +68,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
62
68
 
63
69
  cells.resize(kv_size);
64
70
 
65
- for (uint32_t il = 0; il < hparams.n_layer; il++) {
71
+ for (uint32_t il = 0; il < n_layer_cache; il++) {
66
72
  if (filter && !filter(il)) {
67
73
  LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
68
74
  continue;
@@ -102,6 +108,26 @@ llama_kv_cache_unified::llama_kv_cache_unified(
102
108
  layers.push_back({ il, k, v });
103
109
  }
104
110
 
111
+ // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
112
+ if (model.arch == LLM_ARCH_GEMMA3N) {
113
+ LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
114
+
115
+ for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
116
+ if (filter && !filter(il)) {
117
+ LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
118
+ continue;
119
+ }
120
+
121
+ const bool is_swa = hparams.is_swa(il);
122
+ const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
123
+
124
+ GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
125
+ map_layer_ids[il] = map_layer_ids[il_reuse];
126
+
127
+ LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
128
+ }
129
+ }
130
+
105
131
  // allocate tensors and initialize the buffers to avoid NaNs in the padding
106
132
  for (auto it : ctx_map) {
107
133
  auto * buft = it.first;
@@ -130,6 +156,13 @@ llama_kv_cache_unified::llama_kv_cache_unified(
130
156
 
131
157
  const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
132
158
  debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
159
+
160
+ const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
161
+ supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
162
+
163
+ if (!supports_set_rows) {
164
+ LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
165
+ }
133
166
  }
134
167
 
135
168
  void llama_kv_cache_unified::clear(bool data) {
@@ -307,37 +340,48 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
307
340
  return cells.seq_pos_max(seq_id);
308
341
  }
309
342
 
310
- llama_memory_state_ptr llama_kv_cache_unified::init_batch(
311
- const llama_batch & batch,
343
+ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
344
+ llama_batch_allocr & balloc,
312
345
  uint32_t n_ubatch,
313
346
  bool embd_all) {
314
347
  GGML_UNUSED(embd_all);
315
348
 
316
349
  do {
317
- auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
350
+ balloc.split_reset();
318
351
 
319
352
  std::vector<llama_ubatch> ubatches;
320
- while (sbatch.n_tokens > 0) {
321
- ubatches.push_back(sbatch.split_simple(n_ubatch));
353
+ while (true) {
354
+ auto ubatch = balloc.split_simple(n_ubatch);
355
+
356
+ if (ubatch.n_tokens == 0) {
357
+ break;
358
+ }
359
+
360
+ ubatches.push_back(std::move(ubatch)); // NOLINT
361
+ }
362
+
363
+ if (balloc.get_n_used() < balloc.get_n_tokens()) {
364
+ // failed to find a suitable split
365
+ break;
322
366
  }
323
367
 
324
- auto heads = prepare(ubatches);
325
- if (heads.empty()) {
368
+ auto sinfos = prepare(ubatches);
369
+ if (sinfos.empty()) {
326
370
  break;
327
371
  }
328
372
 
329
- return std::make_unique<llama_kv_cache_unified_state>(
330
- this, std::move(sbatch), std::move(heads), std::move(ubatches));
373
+ return std::make_unique<llama_kv_cache_unified_context>(
374
+ this, std::move(sinfos), std::move(ubatches));
331
375
  } while (false);
332
376
 
333
- return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
377
+ return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
334
378
  }
335
379
 
336
- llama_memory_state_ptr llama_kv_cache_unified::init_full() {
337
- return std::make_unique<llama_kv_cache_unified_state>(this);
380
+ llama_memory_context_ptr llama_kv_cache_unified::init_full() {
381
+ return std::make_unique<llama_kv_cache_unified_context>(this);
338
382
  }
339
383
 
340
- llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
384
+ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
341
385
  bool do_shift = get_has_shift();
342
386
 
343
387
  defrag_info dinfo;
@@ -367,15 +411,16 @@ llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx,
367
411
  }
368
412
  }
369
413
 
370
- return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
414
+ return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
371
415
  }
372
416
 
373
- llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
374
- llama_kv_cache_unified::ubatch_heads res;
417
+ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
418
+ llama_kv_cache_unified::slot_info_vec_t res;
375
419
 
376
420
  struct state {
377
421
  uint32_t head_old; // old position of the head, before placing the ubatch
378
- uint32_t head_new; // new position of the head, after placing the ubatch
422
+
423
+ slot_info sinfo; // slot info for the ubatch
379
424
 
380
425
  llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
381
426
  };
@@ -386,26 +431,29 @@ llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::
386
431
  bool success = true;
387
432
 
388
433
  for (const auto & ubatch : ubatches) {
434
+ // non-continuous slots require support for ggml_set_rows()
435
+ const bool cont = supports_set_rows ? false : true;
436
+
389
437
  // only find a suitable slot for the ubatch. don't modify the cells yet
390
- const int32_t head_new = find_slot(ubatch);
391
- if (head_new < 0) {
438
+ const auto sinfo_new = find_slot(ubatch, cont);
439
+ if (sinfo_new.empty()) {
392
440
  success = false;
393
441
  break;
394
442
  }
395
443
 
396
444
  // remeber the position that we found
397
- res.push_back(head_new);
445
+ res.push_back(sinfo_new);
398
446
 
399
447
  // store the old state of the cells in the recovery stack
400
- states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)});
448
+ states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});
401
449
 
402
450
  // now emplace the ubatch
403
- apply_ubatch(head_new, ubatch);
451
+ apply_ubatch(sinfo_new, ubatch);
404
452
  }
405
453
 
406
454
  // iterate backwards and restore the cells to their original state
407
455
  for (auto it = states.rbegin(); it != states.rend(); ++it) {
408
- cells.set(it->head_new, it->cells);
456
+ cells.set(it->sinfo.idxs, it->cells);
409
457
  head = it->head_old;
410
458
  }
411
459
 
@@ -507,7 +555,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
507
555
  return updated;
508
556
  }
509
557
 
510
- int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
558
+ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
511
559
  const uint32_t n_tokens = ubatch.n_tokens;
512
560
 
513
561
  uint32_t head_cur = this->head;
@@ -520,7 +568,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
520
568
 
521
569
  if (n_tokens > cells.size()) {
522
570
  LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
523
- return -1;
571
+ return { };
524
572
  }
525
573
 
526
574
  if (debug > 0) {
@@ -583,15 +631,26 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
583
631
 
584
632
  uint32_t n_tested = 0;
585
633
 
634
+ // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
635
+ // for non-continuous slots, we test the tokens one by one
636
+ const uint32_t n_test = cont ? n_tokens : 1;
637
+
638
+ slot_info res;
639
+
640
+ auto & idxs = res.idxs;
641
+
642
+ idxs.reserve(n_tokens);
643
+
586
644
  while (true) {
587
- if (head_cur + n_tokens > cells.size()) {
645
+ if (head_cur + n_test > cells.size()) {
588
646
  n_tested += cells.size() - head_cur;
589
647
  head_cur = 0;
590
648
  continue;
591
649
  }
592
650
 
593
- bool found = true;
594
- for (uint32_t i = 0; i < n_tokens; i++) {
651
+ for (uint32_t i = 0; i < n_test; i++) {
652
+ const auto idx = head_cur;
653
+
595
654
  //const llama_pos pos = ubatch.pos[i];
596
655
  //const llama_seq_id seq_id = ubatch.seq_id[i][0];
597
656
 
@@ -601,19 +660,19 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
601
660
  // - (disabled) mask causally, if the sequence is the same as the one we are inserting
602
661
  // - mask SWA, using current max pos for that sequence in the cache
603
662
  // always insert in the cell with minimum pos
604
- bool can_use = cells.is_empty(head_cur + i);
663
+ bool can_use = cells.is_empty(idx);
605
664
 
606
- if (!can_use && cells.seq_count(head_cur + i) == 1) {
607
- const llama_pos pos_cell = cells.pos_get(head_cur + i);
665
+ if (!can_use && cells.seq_count(idx) == 1) {
666
+ const llama_pos pos_cell = cells.pos_get(idx);
608
667
 
609
668
  // (disabled) causal mask
610
669
  // note: it's better to purge any "future" tokens beforehand
611
- //if (cells.seq_has(head_cur + i, seq_id)) {
670
+ //if (cells.seq_has(idx, seq_id)) {
612
671
  // can_use = pos_cell >= pos;
613
672
  //}
614
673
 
615
674
  if (!can_use) {
616
- const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
675
+ const llama_seq_id seq_id_cell = cells.seq_get(idx);
617
676
 
618
677
  // SWA mask
619
678
  if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
@@ -622,34 +681,39 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
622
681
  }
623
682
  }
624
683
 
625
- if (!can_use) {
626
- found = false;
627
- head_cur += i + 1;
628
- n_tested += i + 1;
684
+ head_cur++;
685
+ n_tested++;
686
+
687
+ if (can_use) {
688
+ idxs.push_back(idx);
689
+ } else {
629
690
  break;
630
691
  }
631
692
  }
632
693
 
633
- if (found) {
694
+ if (idxs.size() == n_tokens) {
634
695
  break;
635
696
  }
636
697
 
698
+ if (cont) {
699
+ idxs.clear();
700
+ }
701
+
637
702
  if (n_tested >= cells.size()) {
638
703
  //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
639
- return -1;
704
+ return { };
640
705
  }
641
706
  }
642
707
 
643
- return head_cur;
644
- }
645
-
646
- void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
647
- if (debug > 0) {
648
- LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
649
- LLAMA_LOG_DEBUG("%s: n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
650
- LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
708
+ // we didn't find a suitable slot - return empty result
709
+ if (idxs.size() < n_tokens) {
710
+ res.clear();
651
711
  }
652
712
 
713
+ return res;
714
+ }
715
+
716
+ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
653
717
  // keep track of the max sequence position that we would overwrite with this ubatch
654
718
  // for non-SWA cache, this would be always empty
655
719
  llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
@@ -657,27 +721,26 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
657
721
  seq_pos_max_rm[s] = -1;
658
722
  }
659
723
 
660
- for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
661
- for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
662
- const uint32_t idx = s*ubatch.n_seq_tokens + j;
724
+ assert(ubatch.n_tokens == sinfo.idxs.size());
663
725
 
664
- if (!cells.is_empty(head_cur + idx)) {
665
- assert(cells.seq_count(head_cur + idx) == 1);
726
+ for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
727
+ const auto idx = sinfo.idxs.at(i);
666
728
 
667
- const llama_seq_id seq_id = cells.seq_get(head_cur + idx);
668
- const llama_pos pos = cells.pos_get(head_cur + idx);
729
+ if (!cells.is_empty(idx)) {
730
+ assert(cells.seq_count(idx) == 1);
669
731
 
670
- seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
732
+ const llama_seq_id seq_id = cells.seq_get(idx);
733
+ const llama_pos pos = cells.pos_get(idx);
671
734
 
672
- cells.rm(head_cur + idx);
673
- }
735
+ seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
674
736
 
675
- cells.pos_set(head_cur + idx, ubatch.pos[idx]);
737
+ cells.rm(idx);
738
+ }
676
739
 
677
- // TODO: fix indexing [UBATCH_IDX]
678
- for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) {
679
- cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
680
- }
740
+ cells.pos_set(idx, ubatch.pos[i]);
741
+
742
+ for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
743
+ cells.seq_add(idx, ubatch.seq_id[i][s]);
681
744
  }
682
745
  }
683
746
 
@@ -696,8 +759,9 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
696
759
  seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
697
760
  }
698
761
  }
762
+
699
763
  // move the head at the end of the slot
700
- head = head_cur + ubatch.n_tokens;
764
+ head = sinfo.idxs.back() + 1;
701
765
  }
702
766
 
703
767
  bool llama_kv_cache_unified::get_can_shift() const {
@@ -750,51 +814,135 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
750
814
  0);
751
815
  }
752
816
 
753
- ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const {
817
+ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
754
818
  const int32_t ikv = map_layer_ids.at(il);
755
819
 
756
820
  auto * k = layers[ikv].k;
757
821
 
822
+ const int64_t n_embd_k_gqa = k->ne[0];
758
823
  const int64_t n_tokens = k_cur->ne[2];
759
824
 
825
+ k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
826
+
827
+ if (k_idxs && supports_set_rows) {
828
+ return ggml_set_rows(ctx, k, k_cur, k_idxs);
829
+ }
830
+
831
+ // TODO: fallback to old ggml_cpy() method for backwards compatibility
832
+ // will be removed when ggml_set_rows() is adopted by all backends
833
+
760
834
  ggml_tensor * k_view = ggml_view_1d(ctx, k,
761
- n_tokens*hparams.n_embd_k_gqa(il),
762
- ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);
835
+ n_tokens*n_embd_k_gqa,
836
+ ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
763
837
 
764
838
  return ggml_cpy(ctx, k_cur, k_view);
765
839
  }
766
840
 
767
- ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const {
841
+ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
768
842
  const int32_t ikv = map_layer_ids.at(il);
769
843
 
770
844
  auto * v = layers[ikv].v;
771
845
 
846
+ const int64_t n_embd_v_gqa = v->ne[0];
772
847
  const int64_t n_tokens = v_cur->ne[2];
773
848
 
774
- v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
849
+ v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
850
+
851
+ if (v_idxs && supports_set_rows) {
852
+ if (!v_trans) {
853
+ return ggml_set_rows(ctx, v, v_cur, v_idxs);
854
+ }
855
+
856
+ // the row becomes a single element
857
+ ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
858
+
859
+ // note: the V cache is transposed when not using flash attention
860
+ v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
861
+
862
+ // note: we can be more explicit here at the cost of extra cont
863
+ // however, above we take advantage that a row of single element is always continuous regardless of the row stride
864
+ //v_cur = ggml_transpose(ctx, v_cur);
865
+ //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
866
+
867
+ // we broadcast the KV indices n_embd_v_gqa times
868
+ // v [1, n_kv, n_embd_v_gqa]
869
+ // v_cur [1, n_tokens, n_embd_v_gqa]
870
+ // v_idxs [n_tokens, 1, 1]
871
+ return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
872
+ }
873
+
874
+ // TODO: fallback to old ggml_cpy() method for backwards compatibility
875
+ // will be removed when ggml_set_rows() is adopted by all backends
775
876
 
776
877
  ggml_tensor * v_view = nullptr;
777
878
 
778
879
  if (!v_trans) {
779
880
  v_view = ggml_view_1d(ctx, v,
780
- n_tokens*hparams.n_embd_v_gqa(il),
781
- ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
881
+ n_tokens*n_embd_v_gqa,
882
+ ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
782
883
  } else {
783
- // note: the V cache is transposed when not using flash attention
784
- v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
785
- (v->ne[1])*ggml_element_size(v),
786
- (head_cur)*ggml_element_size(v));
787
-
788
884
  v_cur = ggml_transpose(ctx, v_cur);
885
+
886
+ v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
887
+ (v->ne[1] )*ggml_element_size(v),
888
+ (sinfo.head())*ggml_element_size(v));
789
889
  }
790
890
 
791
891
  return ggml_cpy(ctx, v_cur, v_view);
792
892
  }
793
893
 
894
+ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
895
+ const uint32_t n_tokens = ubatch.n_tokens;
896
+
897
+ ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
898
+
899
+ ggml_set_input(k_idxs);
900
+
901
+ return k_idxs;
902
+ }
903
+
904
+ ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
905
+ const uint32_t n_tokens = ubatch.n_tokens;
906
+
907
+ ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
908
+
909
+ ggml_set_input(v_idxs);
910
+
911
+ return v_idxs;
912
+ }
913
+
914
+ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
915
+ if (!supports_set_rows) {
916
+ return;
917
+ }
918
+
919
+ const uint32_t n_tokens = ubatch->n_tokens;
920
+
921
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
922
+ int64_t * data = (int64_t *) dst->data;
923
+
924
+ for (int64_t i = 0; i < n_tokens; ++i) {
925
+ data[i] = sinfo.idxs.at(i);
926
+ }
927
+ }
928
+
929
+ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
930
+ if (!supports_set_rows) {
931
+ return;
932
+ }
933
+
934
+ const uint32_t n_tokens = ubatch->n_tokens;
935
+
936
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
937
+ int64_t * data = (int64_t *) dst->data;
938
+
939
+ for (int64_t i = 0; i < n_tokens; ++i) {
940
+ data[i] = sinfo.idxs.at(i);
941
+ }
942
+ }
943
+
794
944
  void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
795
- const uint32_t n_tokens = ubatch->n_tokens;
796
- const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
797
- const uint32_t n_seqs = ubatch->n_seqs;
945
+ const uint32_t n_tokens = ubatch->n_tokens;
798
946
 
799
947
  GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
800
948
  float * data = (float *) dst->data;
@@ -814,52 +962,48 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
814
962
  // xxxxx-----
815
963
  // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
816
964
  for (uint32_t h = 0; h < 1; ++h) {
817
- for (uint32_t s = 0; s < n_seqs; ++s) {
818
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
819
-
820
- for (uint32_t j = 0; j < n_seq_tokens; ++j) {
821
- const uint32_t idx = s*n_seq_tokens + j;
822
-
823
- const llama_pos p1 = ubatch->pos[idx];
965
+ for (uint32_t i = 0; i < n_tokens; ++i) {
966
+ const llama_seq_id seq_id = ubatch->seq_id[i][0];
824
967
 
825
- for (uint32_t i = 0; i < n_kv; ++i) {
826
- float f = 0.0f;
968
+ const llama_pos p1 = ubatch->pos[i];
827
969
 
828
- bool masked = false;
970
+ for (uint32_t j = 0; j < n_kv; ++j) {
971
+ float f = 0.0f;
829
972
 
830
- if (cells.is_empty(i)) {
831
- masked = true;
832
- } else {
833
- const llama_pos p0 = cells.pos_get(i);
973
+ bool masked = false;
834
974
 
835
- // mask the token if not the same sequence
836
- masked = masked || (!cells.seq_has(i, seq_id));
975
+ if (cells.is_empty(j)) {
976
+ masked = true;
977
+ } else {
978
+ const llama_pos p0 = cells.pos_get(j);
837
979
 
838
- // mask future tokens
839
- masked = masked || (causal_attn && p0 > p1);
980
+ // mask the token if not the same sequence
981
+ masked = masked || (!cells.seq_has(j, seq_id));
840
982
 
841
- // apply SWA if any
842
- masked = masked || (is_masked_swa(p0, p1));
983
+ // mask future tokens
984
+ masked = masked || (causal_attn && p0 > p1);
843
985
 
844
- if (!masked && hparams.use_alibi) {
845
- f = -std::abs(p0 - p1);
846
- }
847
- }
986
+ // apply SWA if any
987
+ masked = masked || (is_masked_swa(p0, p1));
848
988
 
849
- if (masked) {
850
- f = -INFINITY;
989
+ if (!masked && hparams.use_alibi) {
990
+ f = -std::abs(p0 - p1);
851
991
  }
992
+ }
852
993
 
853
- data[h*(n_kv*n_tokens) + idx*n_kv + i] = f;
994
+ if (masked) {
995
+ f = -INFINITY;
854
996
  }
997
+
998
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
855
999
  }
856
1000
  }
857
1001
 
858
1002
  // mask padded tokens
859
1003
  if (data) {
860
- for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
861
- for (uint32_t i = 0; i < n_kv; ++i) {
862
- data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
1004
+ for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
1005
+ for (uint32_t j = 0; j < n_kv; ++j) {
1006
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
863
1007
  }
864
1008
  }
865
1009
  }
@@ -887,12 +1031,12 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
887
1031
  const int32_t n_kv = dst->ne[0];
888
1032
 
889
1033
  for (int h = 0; h < 1; ++h) {
890
- for (int j = 0; j < n_tokens; ++j) {
891
- for (int i = 0; i < n_kv; ++i) {
1034
+ for (int i = 0; i < n_tokens; ++i) {
1035
+ for (int j = 0; j < n_kv; ++j) {
892
1036
  // the position when the cells is empty is irrelevant - it will be masked out later in the attention
893
- const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
1037
+ const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
894
1038
 
895
- data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
1039
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
896
1040
  }
897
1041
  }
898
1042
  }
@@ -1509,12 +1653,9 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1509
1653
 
1510
1654
  seq_rm(dest_seq_id, -1, -1);
1511
1655
 
1512
- llama_sbatch sbatch;
1513
- llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
1656
+ llama_batch_allocr balloc(hparams.n_pos_per_embd());
1514
1657
 
1515
- ubatch.n_tokens = cell_count;
1516
- ubatch.n_seq_tokens = cell_count;
1517
- ubatch.n_seqs = 1;
1658
+ llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
1518
1659
 
1519
1660
  for (uint32_t i = 0; i < cell_count; ++i) {
1520
1661
  llama_pos pos;
@@ -1539,13 +1680,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1539
1680
  ubatch.seq_id[i] = &dest_seq_id;
1540
1681
  }
1541
1682
 
1542
- const auto head_cur = find_slot(ubatch);
1543
- if (head_cur < 0) {
1683
+ const auto sinfo = find_slot(ubatch, true);
1684
+ if (sinfo.empty()) {
1544
1685
  LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
1545
1686
  return false;
1546
1687
  }
1547
1688
 
1548
- apply_ubatch(head_cur, ubatch);
1689
+ apply_ubatch(sinfo, ubatch);
1690
+
1691
+ const auto head_cur = sinfo.head();
1549
1692
 
1550
1693
  // keep the head at the old position because we will read the KV data into it in state_read_data()
1551
1694
  head = head_cur;
@@ -1723,18 +1866,22 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
1723
1866
  }
1724
1867
 
1725
1868
  //
1726
- // llama_kv_cache_unified_state
1869
+ // llama_kv_cache_unified_context
1727
1870
  //
1728
1871
 
1729
- llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
1872
+ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}
1730
1873
 
1731
- llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1874
+ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
1732
1875
  llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
1733
1876
  n_kv = kv->get_size();
1734
- head = 0;
1877
+
1878
+ // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
1879
+ sinfos.resize(1);
1880
+ sinfos[0].idxs.resize(1);
1881
+ sinfos[0].idxs[0] = 0;
1735
1882
  }
1736
1883
 
1737
- llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1884
+ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
1738
1885
  llama_kv_cache_unified * kv,
1739
1886
  llama_context * lctx,
1740
1887
  bool do_shift,
@@ -1744,27 +1891,26 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1744
1891
  }
1745
1892
  }
1746
1893
 
1747
- llama_kv_cache_unified_state::llama_kv_cache_unified_state(
1894
+ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
1748
1895
  llama_kv_cache_unified * kv,
1749
- llama_sbatch sbatch,
1750
- llama_kv_cache_unified::ubatch_heads heads,
1751
- std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) {
1896
+ llama_kv_cache_unified::slot_info_vec_t sinfos,
1897
+ std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
1752
1898
  }
1753
1899
 
1754
- llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
1900
+ llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
1755
1901
 
1756
- bool llama_kv_cache_unified_state::next() {
1902
+ bool llama_kv_cache_unified_context::next() {
1757
1903
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1758
1904
 
1759
- if (++i_next >= ubatches.size()) {
1905
+ if (++i_cur >= ubatches.size()) {
1760
1906
  return false;
1761
1907
  }
1762
1908
 
1763
1909
  return true;
1764
1910
  }
1765
1911
 
1766
- bool llama_kv_cache_unified_state::apply() {
1767
- assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1912
+ bool llama_kv_cache_unified_context::apply() {
1913
+ assert(!llama_memory_status_is_fail(status));
1768
1914
 
1769
1915
  // no ubatches -> this is a KV cache update
1770
1916
  if (ubatches.empty()) {
@@ -1773,59 +1919,68 @@ bool llama_kv_cache_unified_state::apply() {
1773
1919
  return true;
1774
1920
  }
1775
1921
 
1776
- kv->apply_ubatch(heads[i_next], ubatches[i_next]);
1922
+ kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
1777
1923
 
1778
1924
  n_kv = kv->get_n_kv();
1779
- head = heads[i_next];
1780
1925
 
1781
1926
  return true;
1782
1927
  }
1783
1928
 
1784
- std::vector<int64_t> & llama_kv_cache_unified_state::out_ids() {
1785
- assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1786
-
1787
- return sbatch.out_ids;
1788
- }
1789
-
1790
- llama_memory_status llama_kv_cache_unified_state::get_status() const {
1929
+ llama_memory_status llama_kv_cache_unified_context::get_status() const {
1791
1930
  return status;
1792
1931
  }
1793
1932
 
1794
- const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const {
1933
+ const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
1795
1934
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1796
1935
 
1797
- return ubatches[i_next];
1936
+ return ubatches[i_cur];
1798
1937
  }
1799
1938
 
1800
- uint32_t llama_kv_cache_unified_state::get_n_kv() const {
1939
+ uint32_t llama_kv_cache_unified_context::get_n_kv() const {
1801
1940
  return n_kv;
1802
1941
  }
1803
1942
 
1804
- ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const {
1943
+ ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
1805
1944
  return kv->get_k(ctx, il, n_kv);
1806
1945
  }
1807
1946
 
1808
- ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const {
1947
+ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
1809
1948
  return kv->get_v(ctx, il, n_kv);
1810
1949
  }
1811
1950
 
1812
- ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
1813
- return kv->cpy_k(ctx, k_cur, il, head);
1951
+ ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
1952
+ return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
1953
+ }
1954
+
1955
+ ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
1956
+ return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
1957
+ }
1958
+
1959
+ ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
1960
+ return kv->build_input_k_idxs(ctx, ubatch);
1814
1961
  }
1815
1962
 
1816
- ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
1817
- return kv->cpy_v(ctx, v_cur, il, head);
1963
+ ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
1964
+ return kv->build_input_v_idxs(ctx, ubatch);
1818
1965
  }
1819
1966
 
1820
- void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const {
1967
+ void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
1821
1968
  kv->set_input_k_shift(dst);
1822
1969
  }
1823
1970
 
1824
- void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
1971
+ void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
1972
+ kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
1973
+ }
1974
+
1975
+ void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
1976
+ kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
1977
+ }
1978
+
1979
+ void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
1825
1980
  kv->set_input_kq_mask(dst, ubatch, causal_attn);
1826
1981
  }
1827
1982
 
1828
- void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
1983
+ void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
1829
1984
  kv->set_input_pos_bucket(dst, ubatch);
1830
1985
  }
1831
1986