@novastera-oss/llamarn 0.3.1 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (347)
  1. package/README.md +86 -3
  2. package/RNLlamaCpp.podspec +1 -1
  3. package/android/CMakeLists.txt +11 -3
  4. package/android/generated/jni/react/renderer/components/RNLlamaCppSpec/RNLlamaCppSpecJSI.h +49 -4
  5. package/android/src/main/cpp/include/llama.h +53 -114
  6. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  13. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  20. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  21. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  22. package/cpp/LlamaCppModel.cpp +2 -10
  23. package/cpp/PureCppImpl.cpp +71 -4
  24. package/cpp/SystemUtils.cpp +3 -7
  25. package/cpp/build-info.cpp +2 -2
  26. package/cpp/llama.cpp/CMakeLists.txt +2 -0
  27. package/cpp/llama.cpp/CODEOWNERS +1 -1
  28. package/cpp/llama.cpp/Makefile +6 -1605
  29. package/cpp/llama.cpp/README.md +5 -1
  30. package/cpp/llama.cpp/common/arg.cpp +230 -51
  31. package/cpp/llama.cpp/common/chat-parser.cpp +9 -1
  32. package/cpp/llama.cpp/common/chat.cpp +539 -8
  33. package/cpp/llama.cpp/common/chat.h +8 -1
  34. package/cpp/llama.cpp/common/common.cpp +60 -15
  35. package/cpp/llama.cpp/common/common.h +64 -15
  36. package/cpp/llama.cpp/common/speculative.cpp +135 -54
  37. package/cpp/llama.cpp/common/speculative.h +8 -1
  38. package/cpp/llama.cpp/convert_hf_to_gguf.py +1216 -109
  39. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +19 -6
  40. package/cpp/llama.cpp/convert_lora_to_gguf.py +1 -1
  41. package/cpp/llama.cpp/flake.nix +0 -5
  42. package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -3
  43. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +71 -70
  44. package/cpp/llama.cpp/ggml/include/ggml-opt.h +25 -6
  45. package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
  46. package/cpp/llama.cpp/ggml/include/ggml.h +90 -3
  47. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +13 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +1 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +10 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +113 -17
  51. package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +4 -4
  52. package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +701 -585
  54. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +13 -3
  55. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +52 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +274 -91
  57. package/cpp/llama.cpp/ggml/src/ggml-common.h +17 -0
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +2 -2
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +90 -569
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +55 -341
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +371 -298
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +33 -2
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +26 -1
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
  73. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
  74. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
  75. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +428 -23
  76. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -8
  77. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +458 -46
  80. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +22 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
  82. package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
  83. package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
  84. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
  85. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +122 -5
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +9 -11
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cu +58 -0
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cuh +3 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +275 -170
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +103 -65
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +171 -0
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +33 -7
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +13 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +2 -10
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +3 -4
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  100. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +83 -27
  101. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +116 -57
  102. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +45 -18
  103. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +56 -29
  104. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +61 -39
  105. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +70 -49
  106. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +70 -21
  107. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +162 -50
  108. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cuh +2 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +5 -4
  110. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +208 -97
  111. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +46 -35
  112. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +56 -2
  113. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +95 -51
  114. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +427 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +5 -0
  116. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +204 -57
  117. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +252 -168
  118. package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  119. package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  120. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +10 -5
  121. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +192 -19
  122. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +5 -0
  123. package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  124. package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  125. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +82 -0
  126. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  127. package/cpp/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  128. package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cu +67 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cuh +5 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +1 -8
  131. package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cu +34 -0
  132. package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cuh +5 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +16 -10
  134. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +153 -71
  135. package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +6 -10
  136. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +21 -4
  137. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  138. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +75 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  141. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  142. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -25
  143. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -1
  144. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  145. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +31 -20
  147. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +342 -131
  148. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +464 -134
  149. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
  150. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +8 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1108 -176
  152. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  153. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  154. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +343 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +343 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +346 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +41 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  161. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  164. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  166. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +10 -2
  167. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +10 -2
  168. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +10 -2
  169. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +10 -2
  170. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +97 -41
  173. package/cpp/llama.cpp/ggml/src/ggml-quants.c +110 -16
  174. package/cpp/llama.cpp/ggml/src/ggml-quants.h +6 -0
  175. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +22 -9
  176. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  177. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +0 -212
  178. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +213 -1
  179. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +117 -238
  180. package/cpp/llama.cpp/ggml/src/ggml-sycl/quantize.hpp +133 -0
  181. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +94 -0
  182. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1666 -633
  183. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  184. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  185. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  186. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  187. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +107 -43
  188. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  189. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +18 -0
  190. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  191. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  192. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +20 -0
  193. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +21 -0
  194. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +16 -1
  195. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +44 -8
  196. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +44 -16
  197. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +26 -1
  198. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -17
  199. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
  200. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +37 -1
  201. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  202. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +109 -55
  203. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +71 -41
  204. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +6 -0
  205. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  206. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  207. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -11
  208. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  209. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +9 -3
  210. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  211. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  212. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  213. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +55 -0
  214. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  215. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +75 -20
  216. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +2 -2
  217. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +807 -412
  218. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +72 -22
  219. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +8 -8
  220. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +1794 -0
  221. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +82 -0
  222. package/cpp/llama.cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  223. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +97 -0
  224. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +846 -0
  225. package/cpp/llama.cpp/ggml/src/ggml.c +204 -50
  226. package/cpp/llama.cpp/gguf-py/gguf/constants.py +187 -2
  227. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +11 -2
  228. package/cpp/llama.cpp/gguf-py/gguf/quants.py +53 -4
  229. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +67 -63
  230. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +7 -1
  231. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +120 -16
  232. package/cpp/llama.cpp/gguf-py/gguf/utility.py +5 -1
  233. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +284 -1
  234. package/cpp/llama.cpp/gguf-py/tests/test_quants.py +14 -5
  235. package/cpp/llama.cpp/include/llama.h +53 -114
  236. package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +171 -0
  237. package/cpp/llama.cpp/models/templates/README.md +2 -1
  238. package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +59 -0
  239. package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +331 -0
  240. package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +105 -0
  241. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -1
  242. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +0 -6
  243. package/cpp/llama.cpp/requirements/requirements-pydantic.txt +1 -1
  244. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  245. package/cpp/llama.cpp/src/llama-adapter.cpp +68 -4
  246. package/cpp/llama.cpp/src/llama-adapter.h +3 -0
  247. package/cpp/llama.cpp/src/llama-arch.cpp +192 -2
  248. package/cpp/llama.cpp/src/llama-arch.h +18 -0
  249. package/cpp/llama.cpp/src/llama-batch.cpp +2 -2
  250. package/cpp/llama.cpp/src/llama-chat.cpp +47 -6
  251. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  252. package/cpp/llama.cpp/src/llama-context.cpp +61 -252
  253. package/cpp/llama.cpp/src/llama-context.h +10 -15
  254. package/cpp/llama.cpp/src/llama-cparams.h +0 -1
  255. package/cpp/llama.cpp/src/llama-graph.cpp +180 -85
  256. package/cpp/llama.cpp/src/llama-graph.h +90 -51
  257. package/cpp/llama.cpp/src/llama-hparams.cpp +34 -3
  258. package/cpp/llama.cpp/src/llama-hparams.h +21 -6
  259. package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +79 -56
  260. package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +30 -28
  261. package/cpp/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +240 -632
  262. package/cpp/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +39 -74
  263. package/cpp/llama.cpp/src/llama-kv-cells.h +21 -21
  264. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +41 -35
  265. package/cpp/llama.cpp/src/llama-memory-hybrid.h +26 -29
  266. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +13 -9
  267. package/cpp/llama.cpp/src/llama-memory-recurrent.h +10 -14
  268. package/cpp/llama.cpp/src/llama-memory.h +13 -10
  269. package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
  270. package/cpp/llama.cpp/src/llama-model-loader.h +3 -2
  271. package/cpp/llama.cpp/src/llama-model.cpp +1959 -419
  272. package/cpp/llama.cpp/src/llama-model.h +28 -4
  273. package/cpp/llama.cpp/src/llama-quant.cpp +40 -4
  274. package/cpp/llama.cpp/src/llama-vocab.cpp +51 -2
  275. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  276. package/cpp/llama.cpp/vendor/minja/chat-template.hpp +16 -7
  277. package/cpp/llama.cpp/vendor/minja/minja.hpp +47 -12
  278. package/cpp/rn-completion.cpp +3 -27
  279. package/ios/generated/RNLlamaCppSpec/RNLlamaCppSpec.h +30 -0
  280. package/ios/generated/RNLlamaCppSpecJSI.h +49 -4
  281. package/ios/include/chat.h +8 -1
  282. package/ios/include/common/minja/chat-template.hpp +16 -7
  283. package/ios/include/common/minja/minja.hpp +47 -12
  284. package/ios/include/common.h +64 -15
  285. package/ios/include/llama.h +53 -114
  286. package/ios/include/speculative.h +8 -1
  287. package/ios/libs/llama.xcframework/Info.plist +18 -18
  288. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  289. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5557 -5267
  290. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  291. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +90 -3
  292. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +53 -114
  293. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  294. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  295. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5520 -5238
  296. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
  297. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  298. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  299. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  300. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  301. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  302. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
  303. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4242 -4016
  304. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +25 -6
  305. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +90 -3
  306. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +53 -114
  307. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +25 -6
  308. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +90 -3
  309. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +53 -114
  310. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  311. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +25 -6
  312. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +90 -3
  313. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +53 -114
  314. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  315. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  316. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  317. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5556 -5267
  318. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  319. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +90 -3
  320. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +53 -114
  321. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  322. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  323. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
  324. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
  325. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  326. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  327. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  328. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  329. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  330. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5553 -5303
  331. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  332. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +90 -3
  333. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +53 -114
  334. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  335. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  336. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5515 -5274
  337. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4238 -4044
  338. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  339. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  340. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  341. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  342. package/lib/module/NativeRNLlamaCpp.js.map +1 -1
  343. package/lib/typescript/src/NativeRNLlamaCpp.d.ts +5 -0
  344. package/lib/typescript/src/NativeRNLlamaCpp.d.ts.map +1 -1
  345. package/package.json +1 -2
  346. package/src/NativeRNLlamaCpp.ts +7 -0
  347. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -56
@@ -46,10 +46,8 @@ struct llama_context {
46
46
 
47
47
  llama_memory_t get_memory() const;
48
48
 
49
- // return true of the KV cache was updated
50
- // TODO: remove
51
- bool kv_self_update(bool optimize);
52
- void kv_self_defrag_sched();
49
+ // return true if the memory was updated
50
+ bool memory_update(bool optimize);
53
51
 
54
52
  enum llama_pooling_type pooling_type() const;
55
53
 
@@ -111,9 +109,9 @@ struct llama_context {
111
109
  size_t state_get_data( uint8_t * dst, size_t size);
112
110
  size_t state_set_data(const uint8_t * src, size_t size);
113
111
 
114
- size_t state_seq_get_size(llama_seq_id seq_id);
115
- size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size);
116
- size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
112
+ size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags);
113
+ size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags);
114
+ size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags);
117
115
 
118
116
  bool state_load_file(
119
117
  const char * filepath,
@@ -152,6 +150,7 @@ struct llama_context {
152
150
 
153
151
  void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
154
152
 
153
+ // TODO: more flexible combinations of logical/physical batch size and context size
155
154
  void opt_epoch(
156
155
  ggml_opt_dataset_t dataset,
157
156
  ggml_opt_result_t result_train,
@@ -212,8 +211,8 @@ private:
212
211
  size_t state_write_data(llama_io_write_i & io);
213
212
  size_t state_read_data (llama_io_read_i & io);
214
213
 
215
- size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
216
- size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id);
214
+ size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
215
+ size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
217
216
 
218
217
  //
219
218
  // members
@@ -229,9 +228,6 @@ private:
229
228
 
230
229
  std::unique_ptr<llama_memory_i> memory;
231
230
 
232
- // TODO: temporary, until the llama_kv_self_defrag() API is removed
233
- bool memory_force_optimize = false;
234
-
235
231
  // decode output (2-dimensional array: [n_outputs][n_vocab])
236
232
  size_t logits_size = 0; // capacity (of floats) for logits
237
233
  float * logits = nullptr;
@@ -287,9 +283,8 @@ private:
287
283
 
288
284
  bool has_evaluated_once = false;
289
285
 
290
- // env: LLAMA_SET_ROWS (temporary)
291
- // ref: https://github.com/ggml-org/llama.cpp/pull/14285
292
- bool supports_set_rows = false;
286
+ // env: LLAMA_GRAPH_REUSE_DISABLE
287
+ bool graph_reuse_disable = false;
293
288
 
294
289
  // perf
295
290
  mutable int64_t t_start_us = 0;
@@ -24,7 +24,6 @@ struct llama_cparams {
24
24
  float yarn_attn_factor;
25
25
  float yarn_beta_fast;
26
26
  float yarn_beta_slow;
27
- float defrag_thold;
28
27
 
29
28
  bool embeddings;
30
29
  bool causal_attn;
@@ -4,8 +4,8 @@
4
4
  #include "llama-batch.h"
5
5
  #include "llama-cparams.h"
6
6
 
7
- #include "llama-kv-cache-unified.h"
8
- #include "llama-kv-cache-unified-iswa.h"
7
+ #include "llama-kv-cache.h"
8
+ #include "llama-kv-cache-iswa.h"
9
9
  #include "llama-memory-hybrid.h"
10
10
  #include "llama-memory-recurrent.h"
11
11
 
@@ -188,38 +188,23 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
188
188
 
189
189
  void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
190
190
  const int64_t n_tokens = ubatch->n_tokens;
191
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
192
191
  const int64_t n_seqs_unq = ubatch->n_seqs_unq;
193
192
 
194
193
  if (cparams.embeddings && (
195
- cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
196
- cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
197
- )) {
194
+ cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
195
+ cparams.pooling_type == LLAMA_POOLING_TYPE_RANK ||
196
+ cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
197
+ )) {
198
198
  GGML_ASSERT(cls);
199
199
  GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
200
200
 
201
201
  uint32_t * data = (uint32_t *) cls->data;
202
202
  memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
203
203
 
204
- for (int i = 0; i < n_tokens; i += n_seq_tokens) {
205
- for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
206
- const llama_seq_id seq_id = ubatch->seq_id[i][s];
207
- const int32_t seq_idx = ubatch->seq_idx[seq_id];
208
-
209
- data[seq_idx] = i;
210
- }
211
- }
212
- }
204
+ std::vector<int> target_pos(n_seqs_unq, -1);
205
+ std::vector<int> target_row(n_seqs_unq, -1);
213
206
 
214
- if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
215
- GGML_ASSERT(cls);
216
- GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
217
-
218
- uint32_t * data = (uint32_t *) cls->data;
219
- memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
220
-
221
- std::vector<int> last_pos(n_seqs_unq, -1);
222
- std::vector<int> last_row(n_seqs_unq, -1);
207
+ bool last = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST;
223
208
 
224
209
  for (int i = 0; i < n_tokens; ++i) {
225
210
  const llama_pos pos = ubatch->pos[i];
@@ -228,16 +213,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
228
213
  const llama_seq_id seq_id = ubatch->seq_id[i][s];
229
214
  const int32_t seq_idx = ubatch->seq_idx[seq_id];
230
215
 
231
- if (pos >= last_pos[seq_idx]) {
232
- last_pos[seq_idx] = pos;
233
- last_row[seq_idx] = i;
216
+ if (
217
+ (target_pos[seq_idx] == -1) ||
218
+ ( last && pos >= target_pos[seq_idx]) ||
219
+ (!last && pos < target_pos[seq_idx])
220
+ ) {
221
+ target_pos[seq_idx] = pos;
222
+ target_row[seq_idx] = i;
234
223
  }
235
224
  }
236
225
  }
237
226
 
238
227
  for (int s = 0; s < n_seqs_unq; ++s) {
239
- if (last_row[s] >= 0) {
240
- data[s] = last_row[s];
228
+ if (target_row[s] >= 0) {
229
+ data[s] = target_row[s];
241
230
  }
242
231
  }
243
232
  }
@@ -288,7 +277,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
288
277
  for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
289
278
  const llama_seq_id s0 = ubatch->seq_id[i0][0];
290
279
 
291
- // TODO: reimplement this like in llama_kv_cache_unified
280
+ // TODO: reimplement this like in llama_kv_cache
292
281
  if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
293
282
  if (hparams.use_alibi) {
294
283
  f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
@@ -305,15 +294,15 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
305
294
  }
306
295
  }
307
296
 
308
- void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
297
+ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
309
298
  mctx->set_input_k_idxs(self_k_idxs, ubatch);
310
299
  mctx->set_input_v_idxs(self_v_idxs, ubatch);
311
300
 
312
301
  mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
313
302
  }
314
303
 
315
- bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) {
316
- const auto * mctx = static_cast<const llama_kv_cache_unified_context *>(params.mctx);
304
+ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
305
+ const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
317
306
 
318
307
  this->mctx = mctx;
319
308
 
@@ -325,12 +314,10 @@ bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params)
325
314
  res &= self_kq_mask->ne[0] == mctx->get_n_kv();
326
315
  res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
327
316
 
328
- res &= mctx->get_supports_set_rows(); // TODO: tmp
329
-
330
317
  return res;
331
318
  }
332
319
 
333
- void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
320
+ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
334
321
  mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
335
322
  mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
336
323
 
@@ -342,8 +329,8 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
342
329
  mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
343
330
  }
344
331
 
345
- bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) {
346
- const auto * mctx = static_cast<const llama_kv_cache_unified_iswa_context *>(params.mctx);
332
+ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
333
+ const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);
347
334
 
348
335
  this->mctx = mctx;
349
336
 
@@ -361,8 +348,6 @@ bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & pa
361
348
  res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
362
349
  res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
363
350
 
364
- res &= mctx->get_base()->get_supports_set_rows(); // TODO: tmp
365
-
366
351
  return res;
367
352
  }
368
353
 
@@ -751,6 +736,8 @@ ggml_tensor * llm_graph_context::build_ffn(
751
736
  cur = ggml_reglu(ctx0, cur);
752
737
  cb(cur, "ffn_reglu", il);
753
738
  } break;
739
+ default:
740
+ GGML_ABORT("fatal error");
754
741
  }
755
742
 
756
743
  if (gate && type_gate == LLM_FFN_PAR) {
@@ -760,8 +747,8 @@ ggml_tensor * llm_graph_context::build_ffn(
760
747
 
761
748
  if (down) {
762
749
  cur = build_lora_mm(down, cur);
763
- if (arch == LLM_ARCH_GLM4) {
764
- // GLM4 seems to have numerical issues with half-precision accumulators
750
+ if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
751
+ // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
765
752
  ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
766
753
  }
767
754
  }
@@ -796,13 +783,64 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
796
783
  bool scale_w,
797
784
  float w_scale,
798
785
  llama_expert_gating_func_type gating_op,
799
- int il) const {
786
+ int il,
787
+ ggml_tensor * probs_in) const {
788
+ return build_moe_ffn(
789
+ cur,
790
+ gate_inp, /* gate_inp_b */ nullptr,
791
+ up_exps, /* up_exps_b */ nullptr,
792
+ gate_exps, /* gate_exps_b */ nullptr,
793
+ down_exps, /* down_exps_b */ nullptr,
794
+ exp_probs_b,
795
+ n_expert,
796
+ n_expert_used,
797
+ type_op,
798
+ norm_w,
799
+ scale_w,
800
+ w_scale,
801
+ gating_op,
802
+ il,
803
+ probs_in
804
+ );
805
+ }
806
+
807
+ ggml_tensor * llm_graph_context::build_moe_ffn(
808
+ ggml_tensor * cur,
809
+ ggml_tensor * gate_inp,
810
+ ggml_tensor * gate_inp_b,
811
+ ggml_tensor * up_exps,
812
+ ggml_tensor * up_exps_b,
813
+ ggml_tensor * gate_exps,
814
+ ggml_tensor * gate_exps_b,
815
+ ggml_tensor * down_exps,
816
+ ggml_tensor * down_exps_b,
817
+ ggml_tensor * exp_probs_b,
818
+ int64_t n_expert,
819
+ int64_t n_expert_used,
820
+ llm_ffn_op_type type_op,
821
+ bool norm_w,
822
+ bool scale_w,
823
+ float w_scale,
824
+ llama_expert_gating_func_type gating_op,
825
+ int il,
826
+ ggml_tensor * probs_in) const {
800
827
  const int64_t n_embd = cur->ne[0];
801
828
  const int64_t n_tokens = cur->ne[1];
802
829
  const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
803
830
 
804
- ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
805
- cb(logits, "ffn_moe_logits", il);
831
+ ggml_tensor * logits = nullptr;
832
+
833
+ if (probs_in == nullptr) {
834
+ logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
835
+ cb(logits, "ffn_moe_logits", il);
836
+ } else {
837
+ logits = probs_in;
838
+ }
839
+
840
+ if (gate_inp_b) {
841
+ logits = ggml_add(ctx0, logits, gate_inp_b);
842
+ cb(logits, "ffn_moe_logits_biased", il);
843
+ }
806
844
 
807
845
  ggml_tensor * probs = nullptr;
808
846
  switch (gating_op) {
@@ -814,6 +852,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
814
852
  {
815
853
  probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
816
854
  } break;
855
+ case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
856
+ {
857
+ probs = logits; // [n_expert, n_tokens]
858
+ } break;
817
859
  default:
818
860
  GGML_ABORT("fatal error");
819
861
  }
@@ -842,6 +884,13 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
842
884
  ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
843
885
  cb(weights, "ffn_moe_weights", il);
844
886
 
887
+ if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
888
+ weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
889
+ weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
890
+ weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
891
+ cb(weights, "ffn_moe_weights_softmax", il);
892
+ }
893
+
845
894
  if (norm_w) {
846
895
  weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
847
896
 
@@ -870,6 +919,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
870
919
  ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
871
920
  cb(up, "ffn_moe_up", il);
872
921
 
922
+ if (up_exps_b) {
923
+ up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
924
+ cb(up, "ffn_moe_up_biased", il);
925
+ }
926
+
873
927
  ggml_tensor * experts = nullptr;
874
928
  if (gate_exps) {
875
929
  cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
@@ -878,6 +932,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
878
932
  cur = up;
879
933
  }
880
934
 
935
+ if (gate_exps_b) {
936
+ cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
937
+ cb(cur, "ffn_moe_gate_biased", il);
938
+ }
939
+
881
940
  switch (type_op) {
882
941
  case LLM_FFN_SILU:
883
942
  if (gate_exps) {
@@ -895,6 +954,22 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
895
954
  cur = ggml_gelu(ctx0, cur);
896
955
  cb(cur, "ffn_moe_gelu", il);
897
956
  } break;
957
+ case LLM_FFN_SWIGLU_OAI_MOE:
958
+ {
959
+ // TODO: move to hparams?
960
+ constexpr float alpha = 1.702f;
961
+ constexpr float limit = 7.0f;
962
+ cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit);
963
+ cb(cur, "ffn_moe_swiglu_oai", il);
964
+ } break;
965
+ case LLM_FFN_RELU:
966
+ if (gate_exps) {
967
+ cur = ggml_reglu_split(ctx0, cur, up);
968
+ cb(cur, "ffn_moe_reglu", il);
969
+ } else {
970
+ cur = ggml_relu(ctx0, cur);
971
+ cb(cur, "ffn_moe_relu", il);
972
+ } break;
898
973
  default:
899
974
  GGML_ABORT("fatal error");
900
975
  }
@@ -902,6 +977,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
902
977
  experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
903
978
  cb(experts, "ffn_moe_down", il);
904
979
 
980
+ if (down_exps_b) {
981
+ experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts);
982
+ cb(experts, "ffn_moe_down_biased", il);
983
+ }
984
+
905
985
  if (!weight_before_ffn) {
906
986
  experts = ggml_mul(ctx0, experts, weights);
907
987
  cb(cur, "ffn_moe_weighted", il);
@@ -1102,7 +1182,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
1102
1182
  }
1103
1183
 
1104
1184
  ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
1105
- const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
1185
+ const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
1106
1186
 
1107
1187
  auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
1108
1188
 
@@ -1139,6 +1219,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1139
1219
  ggml_tensor * v,
1140
1220
  ggml_tensor * kq_b,
1141
1221
  ggml_tensor * kq_mask,
1222
+ ggml_tensor * sinks,
1142
1223
  ggml_tensor * v_mla,
1143
1224
  float kq_scale) const {
1144
1225
  const bool v_trans = v->nb[1] > v->nb[2];
@@ -1176,7 +1257,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1176
1257
  cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
1177
1258
  hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
1178
1259
 
1179
- ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
1260
+ ggml_flash_attn_ext_add_sinks(cur, sinks);
1261
+ ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
1180
1262
 
1181
1263
  if (v_mla) {
1182
1264
  #if 0
@@ -1224,6 +1306,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1224
1306
  }
1225
1307
 
1226
1308
  kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
1309
+ ggml_soft_max_add_sinks(kq, sinks);
1227
1310
 
1228
1311
  if (!v_trans) {
1229
1312
  // note: avoid this branch
@@ -1273,6 +1356,7 @@ ggml_tensor * llm_graph_context::build_attn(
1273
1356
  ggml_tensor * k_cur,
1274
1357
  ggml_tensor * v_cur,
1275
1358
  ggml_tensor * kq_b,
1359
+ ggml_tensor * sinks,
1276
1360
  ggml_tensor * v_mla,
1277
1361
  float kq_scale,
1278
1362
  int il) const {
@@ -1288,13 +1372,13 @@ ggml_tensor * llm_graph_context::build_attn(
1288
1372
 
1289
1373
  // [TAG_NO_CACHE_PAD]
1290
1374
  // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
1291
- assert(!ubatch.equal_seqs());
1375
+ assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq));
1292
1376
 
1293
1377
  ggml_tensor * q = q_cur;
1294
1378
  ggml_tensor * k = k_cur;
1295
1379
  ggml_tensor * v = v_cur;
1296
1380
 
1297
- ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1381
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
1298
1382
  cb(cur, "kqv_out", il);
1299
1383
 
1300
1384
  if (wo) {
@@ -1312,17 +1396,17 @@ ggml_tensor * llm_graph_context::build_attn(
1312
1396
  return cur;
1313
1397
  }
1314
1398
 
1315
- static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
1399
+ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
1316
1400
  ggml_context * ctx0,
1317
1401
  const llama_ubatch & ubatch,
1318
1402
  const llama_hparams & hparams,
1319
1403
  const llama_cparams & cparams,
1320
- const llama_kv_cache_unified_context * mctx_cur) {
1404
+ const llama_kv_cache_context * mctx_cur) {
1321
1405
 
1322
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
1406
+ auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
1323
1407
 
1324
1408
  {
1325
- GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
1409
+ GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
1326
1410
 
1327
1411
  const auto n_kv = mctx_cur->get_n_kv();
1328
1412
  const auto n_tokens = ubatch.n_tokens;
@@ -1340,22 +1424,23 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
1340
1424
  return inp;
1341
1425
  }
1342
1426
 
1343
- llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
1344
- const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
1427
+ llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
1428
+ const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
1345
1429
 
1346
- auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
1430
+ auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
1347
1431
 
1348
- return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
1432
+ return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
1349
1433
  }
1350
1434
 
1351
1435
  ggml_tensor * llm_graph_context::build_attn(
1352
- llm_graph_input_attn_kv_unified * inp,
1436
+ llm_graph_input_attn_kv * inp,
1353
1437
  ggml_tensor * wo,
1354
1438
  ggml_tensor * wo_b,
1355
1439
  ggml_tensor * q_cur,
1356
1440
  ggml_tensor * k_cur,
1357
1441
  ggml_tensor * v_cur,
1358
1442
  ggml_tensor * kq_b,
1443
+ ggml_tensor * sinks,
1359
1444
  ggml_tensor * v_mla,
1360
1445
  float kq_scale,
1361
1446
  int il) const {
@@ -1382,13 +1467,13 @@ ggml_tensor * llm_graph_context::build_attn(
1382
1467
  ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1383
1468
  ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1384
1469
 
1385
- ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1470
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
1386
1471
  cb(cur, "kqv_out", il);
1387
1472
 
1388
1473
  if (wo) {
1389
1474
  cur = build_lora_mm(wo, cur);
1390
- if (arch == LLM_ARCH_GLM4) {
1391
- // GLM4 seems to have numerical issues with half-precision accumulators
1475
+ if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
1476
+ // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
1392
1477
  ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
1393
1478
  }
1394
1479
  }
@@ -1401,13 +1486,14 @@ ggml_tensor * llm_graph_context::build_attn(
1401
1486
  }
1402
1487
 
1403
1488
  ggml_tensor * llm_graph_context::build_attn(
1404
- llm_graph_input_attn_kv_unified_iswa * inp,
1489
+ llm_graph_input_attn_kv_iswa * inp,
1405
1490
  ggml_tensor * wo,
1406
1491
  ggml_tensor * wo_b,
1407
1492
  ggml_tensor * q_cur,
1408
1493
  ggml_tensor * k_cur,
1409
1494
  ggml_tensor * v_cur,
1410
1495
  ggml_tensor * kq_b,
1496
+ ggml_tensor * sinks,
1411
1497
  ggml_tensor * v_mla,
1412
1498
  float kq_scale,
1413
1499
  int il) const {
@@ -1448,7 +1534,7 @@ ggml_tensor * llm_graph_context::build_attn(
1448
1534
  ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1449
1535
  ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1450
1536
 
1451
- ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1537
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
1452
1538
  cb(cur, "kqv_out", il);
1453
1539
 
1454
1540
  if (wo) {
@@ -1487,6 +1573,7 @@ ggml_tensor * llm_graph_context::build_attn(
1487
1573
  ggml_tensor * k_cur,
1488
1574
  ggml_tensor * v_cur,
1489
1575
  ggml_tensor * kq_b,
1576
+ ggml_tensor * sinks,
1490
1577
  ggml_tensor * v_mla,
1491
1578
  float kq_scale,
1492
1579
  int il) const {
@@ -1502,7 +1589,7 @@ ggml_tensor * llm_graph_context::build_attn(
1502
1589
  ggml_tensor * k = k_cur;
1503
1590
  ggml_tensor * v = v_cur;
1504
1591
 
1505
- ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1592
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
1506
1593
  cb(cur, "kqv_out", il);
1507
1594
 
1508
1595
  if (wo) {
@@ -1523,10 +1610,10 @@ ggml_tensor * llm_graph_context::build_attn(
1523
1610
  // TODO: maybe separate the inner implementation into a separate function
1524
1611
  // like with the non-sliding window equivalent
1525
1612
  // once sliding-window hybrid caches are a thing.
1526
- llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
1527
- const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
1613
+ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
1614
+ const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
1528
1615
 
1529
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
1616
+ auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
1530
1617
 
1531
1618
  const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
1532
1619
 
@@ -1543,7 +1630,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
1543
1630
  }
1544
1631
 
1545
1632
  {
1546
- GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
1633
+ GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
1547
1634
 
1548
1635
  const auto n_kv = mctx_cur->get_swa()->get_n_kv();
1549
1636
 
@@ -1556,21 +1643,22 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
1556
1643
  inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1557
1644
  }
1558
1645
 
1559
- return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
1646
+ return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
1560
1647
  }
1561
1648
 
1562
1649
  ggml_tensor * llm_graph_context::build_rs(
1563
1650
  ggml_tensor * s,
1564
- ggml_tensor * state_copy,
1651
+ ggml_tensor * state_copy_main,
1652
+ ggml_tensor * state_copy_extra,
1565
1653
  int32_t state_size,
1566
1654
  int32_t n_seqs,
1567
- uint32_t n_kv,
1568
- uint32_t kv_head,
1569
- uint32_t kv_size,
1655
+ uint32_t n_rs,
1656
+ uint32_t rs_head,
1657
+ uint32_t rs_size,
1570
1658
  int32_t rs_zero,
1571
1659
  const llm_graph_get_rows_fn & get_state_rows) const {
1572
1660
 
1573
- ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
1661
+ ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
1574
1662
 
1575
1663
  // Clear a single state which will then be copied to the other cleared states.
1576
1664
  // Note that this is a no-op when the view is zero-sized.
@@ -1578,39 +1666,44 @@ ggml_tensor * llm_graph_context::build_rs(
1578
1666
  ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
1579
1667
 
1580
1668
  // copy states
1581
- // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
1582
- // {state_size, kv_size} -> {state_size, n_seqs}
1583
- ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
1669
+ // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
1670
+ // {state_size, rs_size} -> {state_size, n_seqs}
1671
+ ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
1584
1672
  ggml_build_forward_expand(gf, output_states);
1585
1673
 
1586
- // copy extra states which won't be changed further (between n_seqs and n_kv)
1587
- ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
1674
+ // copy extra states which won't be changed further (between n_seqs and n_rs)
1675
+ ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
1588
1676
  ggml_build_forward_expand(gf,
1589
1677
  ggml_cpy(ctx0,
1590
1678
  states_extra,
1591
- ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
1679
+ ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s))));
1592
1680
 
1593
1681
  return output_states;
1594
1682
  }
1595
1683
 
1596
1684
  static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
1597
1685
  ggml_context * ctx0,
1686
+ const llama_ubatch & ubatch,
1598
1687
  const llama_memory_recurrent_context * mctx_cur) {
1599
1688
 
1600
1689
  auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
1601
1690
 
1602
- const auto n_rs = mctx_cur->get_n_rs();
1691
+ const int64_t n_rs = mctx_cur->get_n_rs();
1692
+ const int64_t n_seqs = ubatch.n_seqs;
1603
1693
 
1604
1694
  inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
1605
1695
  ggml_set_input(inp->s_copy);
1606
1696
 
1697
+ inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
1698
+ inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
1699
+
1607
1700
  return inp;
1608
1701
  }
1609
1702
 
1610
1703
  llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
1611
1704
  const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1612
1705
 
1613
- auto inp = build_rs_inp_impl(ctx0, mctx_cur);
1706
+ auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);
1614
1707
 
1615
1708
  return (llm_graph_input_rs *) res->add_input(std::move(inp));
1616
1709
  }
@@ -1623,7 +1716,9 @@ ggml_tensor * llm_graph_context::build_rs(
1623
1716
  const llm_graph_get_rows_fn & get_state_rows) const {
1624
1717
  const auto * kv_state = inp->mctx;
1625
1718
 
1626
- return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
1719
+ return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
1720
+ kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
1721
+ get_state_rows);
1627
1722
  }
1628
1723
 
1629
1724
  ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
@@ -1670,8 +1765,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
1670
1765
  llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
1671
1766
  const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
1672
1767
 
1673
- auto inp_rs = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
1674
- auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
1768
+ auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
1769
+ auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
1675
1770
 
1676
1771
  auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
1677
1772