@novastera-oss/llamarn 0.2.6 → 0.2.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (253)
  1. package/android/src/main/cpp/include/llama.h +141 -38
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +58 -24
  11. package/cpp/LlamaCppModel.h +3 -3
  12. package/cpp/PureCppImpl.cpp +1 -1
  13. package/cpp/PureCppImpl.h +2 -2
  14. package/cpp/build-info.cpp +2 -2
  15. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  16. package/cpp/llama.cpp/Makefile +2 -2
  17. package/cpp/llama.cpp/README.md +32 -13
  18. package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
  19. package/cpp/llama.cpp/common/arg.cpp +37 -6
  20. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  21. package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
  22. package/cpp/llama.cpp/common/chat-parser.h +2 -0
  23. package/cpp/llama.cpp/common/chat.cpp +12 -9
  24. package/cpp/llama.cpp/common/chat.h +1 -1
  25. package/cpp/llama.cpp/common/common.cpp +53 -40
  26. package/cpp/llama.cpp/common/common.h +6 -2
  27. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  28. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  29. package/cpp/llama.cpp/convert_hf_to_gguf.py +215 -76
  30. package/cpp/llama.cpp/ggml/CMakeLists.txt +48 -2
  31. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  32. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  33. package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
  34. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +64 -13
  35. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  36. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  38. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +124 -26
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +4 -3
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +93 -104
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +194 -69
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1158 -0
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1571 -0
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +213 -37
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +59 -37
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +90 -39
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  88. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
  90. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  91. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +260 -49
  93. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +497 -282
  94. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  95. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1078 -468
  97. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  99. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  105. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
  107. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  108. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +20 -48
  109. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  110. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  111. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +117 -165
  112. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +192 -53
  113. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  114. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  115. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  116. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
  117. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  118. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +8 -105
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +209 -92
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  133. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +36 -28
  134. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +487 -247
  135. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  136. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  137. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  138. package/cpp/llama.cpp/ggml/src/ggml.c +69 -19
  139. package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
  140. package/cpp/llama.cpp/gguf-py/gguf/constants.py +133 -0
  141. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +25 -1
  142. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +78 -3
  143. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
  144. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  145. package/cpp/llama.cpp/include/llama.h +141 -38
  146. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  147. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  148. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  149. package/cpp/llama.cpp/src/llama-arch.cpp +150 -3
  150. package/cpp/llama.cpp/src/llama-arch.h +25 -1
  151. package/cpp/llama.cpp/src/llama-batch.cpp +736 -274
  152. package/cpp/llama.cpp/src/llama-batch.h +110 -57
  153. package/cpp/llama.cpp/src/llama-chat.cpp +30 -8
  154. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  155. package/cpp/llama.cpp/src/llama-context.cpp +360 -266
  156. package/cpp/llama.cpp/src/llama-context.h +27 -23
  157. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  158. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  159. package/cpp/llama.cpp/src/llama-graph.cpp +411 -344
  160. package/cpp/llama.cpp/src/llama-graph.h +126 -58
  161. package/cpp/llama.cpp/src/llama-hparams.cpp +10 -2
  162. package/cpp/llama.cpp/src/llama-hparams.h +16 -2
  163. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +103 -73
  164. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +34 -42
  165. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +345 -221
  166. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +75 -50
  167. package/cpp/llama.cpp/src/llama-kv-cells.h +51 -22
  168. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +246 -0
  169. package/cpp/llama.cpp/src/llama-memory-hybrid.h +138 -0
  170. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +302 -317
  171. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +60 -68
  172. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  173. package/cpp/llama.cpp/src/llama-memory.h +73 -36
  174. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  175. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  176. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  177. package/cpp/llama.cpp/src/llama-model.cpp +1630 -511
  178. package/cpp/llama.cpp/src/llama-model.h +26 -0
  179. package/cpp/llama.cpp/src/llama-quant.cpp +89 -6
  180. package/cpp/llama.cpp/src/llama-vocab.cpp +58 -26
  181. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  182. package/cpp/llama.cpp/src/llama.cpp +11 -7
  183. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  184. package/cpp/rn-completion.cpp +2 -2
  185. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  186. package/cpp/{rn-utils.hpp → rn-utils.h} +3 -0
  187. package/ios/include/chat.h +1 -1
  188. package/ios/include/common.h +6 -2
  189. package/ios/include/llama.h +141 -38
  190. package/ios/libs/llama.xcframework/Info.plist +15 -15
  191. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  192. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
  193. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  194. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
  195. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +141 -38
  196. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  197. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  198. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  199. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
  200. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  201. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  202. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  203. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  204. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  205. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  206. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3624
  207. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  208. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
  209. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +141 -38
  210. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  211. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
  212. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +141 -38
  213. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  214. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  215. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
  216. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +141 -38
  217. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  218. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  219. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  220. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
  221. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  222. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
  223. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +141 -38
  224. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  225. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  226. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  227. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
  228. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  229. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  230. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  231. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  232. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  233. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4725
  234. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  235. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
  236. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +141 -38
  237. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  238. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  239. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4746
  240. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3652
  241. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  242. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  243. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  244. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  245. package/package.json +1 -2
  246. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  247. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  248. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  249. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
  250. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
  251. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  252. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  253. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
@@ -6,7 +6,8 @@
6
6
 
7
7
  #include "llama-kv-cache-unified.h"
8
8
  #include "llama-kv-cache-unified-iswa.h"
9
- #include "llama-kv-cache-recurrent.h"
9
+ #include "llama-memory-hybrid.h"
10
+ #include "llama-memory-recurrent.h"
10
11
 
11
12
  #include <cassert>
12
13
  #include <cmath>
@@ -86,41 +87,33 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
86
87
 
87
88
  void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
88
89
  if (pos_bucket) {
89
- kv_state->set_input_pos_bucket(pos_bucket, ubatch);
90
+ mctx->set_input_pos_bucket(pos_bucket, ubatch);
90
91
  }
91
92
  }
92
93
 
93
94
  void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
94
- if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
95
- //GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
95
+ GGML_ASSERT(out_ids);
96
96
 
97
- if (!out_ids) {
98
- LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
99
- } else {
100
- const int64_t n_tokens = ubatch->n_tokens;
97
+ const int64_t n_tokens = ubatch->n_tokens;
101
98
 
102
- GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
103
- int32_t * data = (int32_t *) out_ids->data;
99
+ GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
100
+ int32_t * data = (int32_t *) out_ids->data;
104
101
 
105
- if (n_outputs == n_tokens) {
106
- for (int i = 0; i < n_tokens; ++i) {
107
- data[i] = i;
108
- }
109
- } else if (ubatch->output) {
110
- int32_t n_outputs = 0;
111
- for (int i = 0; i < n_tokens; ++i) {
112
- if (ubatch->output[i]) {
113
- data[n_outputs++] = i;
114
- }
115
- }
116
- // the graph needs to have been passed the correct number of outputs
117
- GGML_ASSERT(n_outputs == n_outputs);
118
- } else if (n_outputs == 1) {
119
- // only keep last output
120
- data[0] = n_tokens - 1;
121
- } else {
122
- GGML_ASSERT(n_outputs == 0);
123
- }
102
+ if (n_outputs == n_tokens) {
103
+ for (int i = 0; i < n_tokens; ++i) {
104
+ data[i] = i;
105
+ }
106
+
107
+ return;
108
+ }
109
+
110
+ GGML_ASSERT(ubatch->output);
111
+
112
+ int n_outputs = 0;
113
+
114
+ for (int i = 0; i < n_tokens; ++i) {
115
+ if (ubatch->output[i]) {
116
+ data[n_outputs++] = i;
124
117
  }
125
118
  }
126
119
  }
@@ -129,139 +122,114 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
129
122
  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
130
123
  const int64_t n_tokens = ubatch->n_tokens;
131
124
  const int64_t n_seq_tokens = ubatch->n_seq_tokens;
132
- const int64_t n_seqs = ubatch->n_seqs;
125
+ const int64_t n_seqs_unq = ubatch->n_seqs_unq;
133
126
 
134
127
  GGML_ASSERT(mean);
135
128
  GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
136
129
 
137
130
  float * data = (float *) mean->data;
138
- memset(mean->data, 0, n_tokens * n_tokens * ggml_element_size(mean));
139
-
140
- std::vector<uint64_t> sum(n_tokens, 0);
141
-
142
- for (int s = 0; s < n_seqs; ++s) {
143
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
131
+ memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));
144
132
 
145
- // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
146
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
133
+ std::vector<uint64_t> sums(n_seqs_unq, 0);
134
+ for (int i = 0; i < n_tokens; i += n_seq_tokens) {
135
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
136
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
137
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
147
138
 
148
- sum[seq_id] += ubatch->n_seq_tokens;
139
+ sums[seq_idx] += ubatch->n_seq_tokens;
140
+ }
149
141
  }
150
142
 
151
- std::vector<float> div(n_tokens, 0.0f);
152
- for (int i = 0; i < n_tokens; ++i) {
153
- const uint64_t s = sum[i];
154
- if (s > 0) {
155
- div[i] = 1.0f/float(s);
143
+ std::vector<float> div(n_seqs_unq, 0.0f);
144
+ for (int s = 0; s < n_seqs_unq; ++s) {
145
+ const uint64_t sum = sums[s];
146
+ if (sum > 0) {
147
+ div[s] = 1.0f/float(sum);
156
148
  }
157
149
  }
158
150
 
159
- for (int s = 0; s < n_seqs; ++s) {
160
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
151
+ for (int i = 0; i < n_tokens; i += n_seq_tokens) {
152
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
153
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
154
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
161
155
 
162
- for (int i = 0; i < n_seq_tokens; ++i) {
163
- data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
156
+ for (int j = 0; j < n_seq_tokens; ++j) {
157
+ data[seq_idx*n_tokens + i + j] = div[seq_idx];
158
+ }
164
159
  }
165
160
  }
166
161
  }
167
162
  }
168
163
 
169
164
  void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
170
- if (cparams.embeddings && (
171
- cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
172
- cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
173
- const int64_t n_tokens = ubatch->n_tokens;
174
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
175
- const int64_t n_seqs = ubatch->n_seqs;
165
+ const int64_t n_tokens = ubatch->n_tokens;
166
+ const int64_t n_seq_tokens = ubatch->n_seq_tokens;
167
+ const int64_t n_seqs_unq = ubatch->n_seqs_unq;
176
168
 
169
+ if (cparams.embeddings && (
170
+ cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
171
+ cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
172
+ )) {
177
173
  GGML_ASSERT(cls);
178
174
  GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
179
175
 
180
176
  uint32_t * data = (uint32_t *) cls->data;
181
- memset(cls->data, 0, n_tokens * ggml_element_size(cls));
182
-
183
- for (int s = 0; s < n_seqs; ++s) {
184
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
177
+ memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
185
178
 
186
- // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
187
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
179
+ for (int i = 0; i < n_tokens; i += n_seq_tokens) {
180
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
181
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
182
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
188
183
 
189
- for (int i = 0; i < n_seq_tokens; ++i) {
190
- const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
191
-
192
- if (pos == 0) {
193
- data[seq_id] = s*n_seq_tokens + i;
194
- }
184
+ data[seq_idx] = i;
195
185
  }
196
186
  }
197
187
  }
198
188
 
199
189
  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
200
- const int64_t n_tokens = ubatch->n_tokens;
201
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
202
- const int64_t n_seqs = ubatch->n_seqs;
203
-
204
190
  GGML_ASSERT(cls);
205
191
  GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
206
192
 
207
193
  uint32_t * data = (uint32_t *) cls->data;
208
- memset(cls->data, 0, n_tokens * ggml_element_size(cls));
209
-
210
- std::vector<int> last_pos(n_tokens, -1);
211
- std::vector<int> last_row(n_tokens, -1);
194
+ memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
212
195
 
213
- for (int s = 0; s < n_seqs; ++s) {
214
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
196
+ std::vector<int> last_pos(n_seqs_unq, -1);
197
+ std::vector<int> last_row(n_seqs_unq, -1);
215
198
 
216
- // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
217
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
199
+ for (int i = 0; i < n_tokens; ++i) {
200
+ const llama_pos pos = ubatch->pos[i];
218
201
 
219
- for (int i = 0; i < n_seq_tokens; ++i) {
220
- const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
202
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
203
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
204
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
221
205
 
222
- if (pos >= last_pos[seq_id]) {
223
- last_pos[seq_id] = pos;
224
- last_row[seq_id] = s*n_seq_tokens + i;
206
+ if (pos >= last_pos[seq_idx]) {
207
+ last_pos[seq_idx] = pos;
208
+ last_row[seq_idx] = i;
225
209
  }
226
210
  }
227
211
  }
228
212
 
229
- for (int i = 0; i < n_tokens; ++i) {
230
- if (last_row[i] >= 0) {
231
- data[i] = last_row[i];
213
+ for (int s = 0; s < n_seqs_unq; ++s) {
214
+ if (last_row[s] >= 0) {
215
+ data[s] = last_row[s];
232
216
  }
233
217
  }
234
218
  }
235
219
  }
236
220
 
237
- void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
221
+ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
238
222
  GGML_UNUSED(ubatch);
239
223
 
240
- const int64_t n_kv = kv_state->get_n_kv();
224
+ const int64_t n_rs = mctx->get_n_rs();
241
225
 
242
226
  if (s_copy) {
243
227
  GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
244
228
  int32_t * data = (int32_t *) s_copy->data;
245
229
 
246
230
  // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
247
- for (uint32_t i = 0; i < n_kv; ++i) {
248
- data[i] = kv_state->s_copy(i);
249
- }
250
- }
251
- }
252
-
253
- void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
254
- GGML_UNUSED(ubatch);
255
-
256
- const int64_t n_kv = kv_state->get_n_kv();
257
-
258
- if (s_mask) {
259
- GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
260
- float * data = (float *) s_mask->data;
261
-
262
- // clear unused states
263
- for (int i = 0; i < n_kv; ++i) {
264
- data[i] = kv_state->s_mask(i);
231
+ for (uint32_t i = 0; i < n_rs; ++i) {
232
+ data[i] = mctx->s_copy(i);
265
233
  }
266
234
  }
267
235
  }
@@ -277,87 +245,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
277
245
  }
278
246
 
279
247
  void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
280
- if (kq_mask) {
281
- if (cparams.causal_attn) {
282
- const int64_t n_kv = ubatch->n_tokens;
283
- const int64_t n_tokens = ubatch->n_tokens;
284
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
285
- const int64_t n_seqs = ubatch->n_seqs;
286
-
287
- GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
288
- float * data = (float *) kq_mask->data;
289
-
290
- for (int h = 0; h < 1; ++h) {
291
- for (int s1 = 0; s1 < n_seqs; ++s1) {
292
- const llama_seq_id seq_id = ubatch->seq_id[s1][0];
293
-
294
- for (int j = 0; j < n_seq_tokens; ++j) {
295
- const int32_t tj = s1*n_seq_tokens + j;
296
-
297
- for (int s0 = 0; s0 < n_seqs; ++s0) {
298
- for (int i = 0; i < n_seq_tokens; ++i) {
299
- const int32_t ti = s0*n_seq_tokens + i;
300
- float f = -INFINITY;
301
-
302
- for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
303
- if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
304
- if (hparams.use_alibi) {
305
- f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
306
- } else {
307
- f = 0.0f;
308
- }
309
- break;
310
- }
311
- }
312
-
313
- data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
314
- }
315
- }
316
- }
317
- }
318
- }
319
- } else {
320
- const int64_t n_tokens = ubatch->n_tokens;
321
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
322
- const int64_t n_seqs = ubatch->n_seqs;
323
- const int64_t n_stride = ubatch->n_tokens;
324
-
325
- GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
326
-
327
- float * data = (float *) kq_mask->data;
328
-
329
- for (int h = 0; h < 1; ++h) {
330
- for (int s1 = 0; s1 < n_seqs; ++s1) {
331
- const llama_seq_id seq_id = ubatch->seq_id[s1][0];
332
-
333
- for (int j = 0; j < n_seq_tokens; ++j) {
334
- const int32_t tj = s1*n_seq_tokens + j;
335
-
336
- for (int s0 = 0; s0 < n_seqs; ++s0) {
337
- for (int i = 0; i < n_seq_tokens; ++i) {
338
- const int32_t ti = s0*n_seq_tokens + i;
339
- float f = -INFINITY;
340
-
341
- for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
342
- if (ubatch->seq_id[s0][s] == seq_id) {
343
- if (hparams.use_alibi) {
344
- f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
345
- } else {
346
- f = 0.0f;
347
- }
348
- break;
349
- }
350
- }
351
-
352
- data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
353
- }
354
- }
248
+ const int64_t n_kv = ubatch->n_tokens;
249
+ const int64_t n_tokens = ubatch->n_tokens;
250
+
251
+ GGML_ASSERT(kq_mask);
252
+ GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
253
+
254
+ float * data = (float *) kq_mask->data;
355
255
 
356
- for (int i = n_tokens; i < n_stride; ++i) {
357
- data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
256
+ for (int h = 0; h < 1; ++h) {
257
+ for (int i1 = 0; i1 < n_tokens; ++i1) {
258
+ const llama_seq_id s1 = ubatch->seq_id[i1][0];
259
+
260
+ for (int i0 = 0; i0 < n_tokens; ++i0) {
261
+ float f = -INFINITY;
262
+
263
+ for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
264
+ const llama_seq_id s0 = ubatch->seq_id[i0][0];
265
+
266
+ // TODO: reimplement this like in llama_kv_cache_unified
267
+ if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
268
+ if (hparams.use_alibi) {
269
+ f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
270
+ } else {
271
+ f = 0.0f;
358
272
  }
273
+ break;
359
274
  }
360
275
  }
276
+
277
+ data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
361
278
  }
362
279
  }
363
280
  }
@@ -365,53 +282,80 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
365
282
 
366
283
  void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
367
284
  if (self_kq_mask) {
368
- kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
285
+ mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
369
286
  }
370
287
  }
371
288
 
372
289
  void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
373
290
  if (self_kq_mask) {
374
- kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
291
+ mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
375
292
  }
376
293
 
377
294
  if (self_kq_mask_swa) {
378
- kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
295
+ mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
379
296
  }
380
297
  }
381
298
 
382
299
  void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
383
- if (cross_kq_mask) {
384
- const int64_t n_enc = cross_kq_mask->ne[0];
385
- const int64_t n_tokens = ubatch->n_tokens;
300
+ GGML_ASSERT(cross_kq_mask);
386
301
 
387
- GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
388
- GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
302
+ const int64_t n_enc = cross_kq_mask->ne[0];
303
+ const int64_t n_tokens = ubatch->n_tokens;
389
304
 
390
- float * data = (float *) cross_kq_mask->data;
305
+ GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
306
+ GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
391
307
 
392
- for (int h = 0; h < 1; ++h) {
393
- for (int j = 0; j < n_tokens; ++j) {
394
- for (int i = 0; i < n_enc; ++i) {
395
- float f = -INFINITY;
396
- for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
397
- const llama_seq_id seq_id = ubatch->seq_id[j][s];
398
- if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
399
- f = 0.0f;
400
- }
308
+ float * data = (float *) cross_kq_mask->data;
309
+
310
+ for (int h = 0; h < 1; ++h) {
311
+ for (int i = 0; i < n_tokens; ++i) {
312
+ for (int j = 0; j < n_enc; ++j) {
313
+ float f = -INFINITY;
314
+
315
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
316
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
317
+
318
+ if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
319
+ f = 0.0f;
401
320
  }
402
- data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
403
321
  }
322
+
323
+ data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
404
324
  }
325
+ }
405
326
 
406
- for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
407
- for (int j = 0; j < n_enc; ++j) {
408
- data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
409
- }
327
+ for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
328
+ for (int j = 0; j < n_enc; ++j) {
329
+ data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
410
330
  }
411
331
  }
412
332
  }
413
333
  }
414
334
 
335
+ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
336
+ if (self_kq_mask) {
337
+ mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
338
+ }
339
+
340
+ const int64_t n_rs = mctx->get_recr()->get_n_rs();
341
+
342
+ if (s_copy) {
343
+ GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
344
+ int32_t * data = (int32_t *) s_copy->data;
345
+
346
+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
347
+ for (uint32_t i = 0; i < n_rs; ++i) {
348
+ data[i] = mctx->get_recr()->s_copy(i);
349
+ }
350
+ }
351
+ }
352
+
353
+ void llm_graph_input_one::set_input(const llama_ubatch *) {
354
+ GGML_ASSERT(one && ggml_nelements(one) == 1);
355
+ float f_one = 1.0f;
356
+ ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
357
+ }
358
+
415
359
  //
416
360
  // llm_graph_context
417
361
  //
@@ -451,16 +395,12 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
451
395
  backend_cpu (params.backend_cpu),
452
396
  cvec (params.cvec),
453
397
  loras (params.loras),
454
- mstate (params.mstate),
398
+ mctx (params.mctx),
455
399
  cross (params.cross),
456
400
  cb_func (params.cb),
457
401
  res (std::make_unique<llm_graph_result>()) {
458
402
  }
459
403
 
460
- int64_t llm_graph_context::n_pos_per_embd() const {
461
- return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
462
- }
463
-
464
404
  void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
465
405
  if (cb_func) {
466
406
  cb_func(ubatch, cur, name, il);
@@ -650,6 +590,7 @@ ggml_tensor * llm_graph_context::build_ffn(
650
590
  {
651
591
  // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
652
592
  int64_t split_point = cur->ne[0] / 2;
593
+ // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
653
594
  ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
654
595
  ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
655
596
 
@@ -659,6 +600,20 @@ ggml_tensor * llm_graph_context::build_ffn(
659
600
  cur = ggml_mul(ctx0, x0, x1);
660
601
  cb(cur, "ffn_mul", il);
661
602
  } break;
603
+ case LLM_FFN_GEGLU:
604
+ {
605
+ // Split into two equal parts
606
+ int64_t split_point = cur->ne[0] / 2;
607
+ // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
608
+ ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
609
+ ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
610
+
611
+ x0 = ggml_gelu(ctx0, x0);
612
+ cb(x0, "ffn_gelu", il);
613
+
614
+ cur = ggml_mul(ctx0, x0, x1);
615
+ cb(cur, "ffn_geglu", il);
616
+ } break;
662
617
  }
663
618
 
664
619
  if (gate && type_gate == LLM_FFN_PAR) {
@@ -769,9 +724,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
769
724
  cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
770
725
 
771
726
  if (weight_before_ffn) {
772
- // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
773
- ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
774
- repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
727
+ // repeat cur to [n_embd, n_expert_used, n_tokens]
728
+ ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
775
729
  cur = ggml_mul(ctx0, repeated, weights);
776
730
  cb(cur, "ffn_moe_weighted", il);
777
731
  }
@@ -891,11 +845,11 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
891
845
  }
892
846
 
893
847
  ggml_tensor * llm_graph_context::build_inp_pos() const {
894
- auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
848
+ auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
895
849
 
896
850
  auto & cur = inp->pos;
897
851
 
898
- cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
852
+ cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
899
853
  ggml_set_input(cur);
900
854
 
901
855
  res->add_input(std::move(inp));
@@ -918,6 +872,14 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
918
872
  }
919
873
 
920
874
  ggml_tensor * llm_graph_context::build_inp_out_ids() const {
875
+ // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
876
+ // but this would make the graph topology depend on the number of output tokens, which can interere with
877
+ // features that require constant topology such as pipline parallelism
878
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
879
+ //if (n_outputs < n_tokens) {
880
+ // return nullptr;
881
+ //}
882
+
921
883
  auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
922
884
 
923
885
  auto & cur = inp->out_ids;
@@ -935,7 +897,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
935
897
 
936
898
  auto & cur = inp->mean;
937
899
 
938
- cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
900
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
939
901
  ggml_set_input(cur);
940
902
 
941
903
  res->add_input(std::move(inp));
@@ -948,41 +910,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
948
910
 
949
911
  auto & cur = inp->cls;
950
912
 
951
- cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
952
- ggml_set_input(cur);
953
-
954
- res->add_input(std::move(inp));
955
-
956
- return cur;
957
- }
958
-
959
- ggml_tensor * llm_graph_context::build_inp_s_copy() const {
960
- const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
961
-
962
- auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
963
-
964
- const auto n_kv = kv_state->get_n_kv();
965
-
966
- auto & cur = inp->s_copy;
967
-
968
- cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
969
- ggml_set_input(cur);
970
-
971
- res->add_input(std::move(inp));
972
-
973
- return cur;
974
- }
975
-
976
- ggml_tensor * llm_graph_context::build_inp_s_mask() const {
977
- const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
978
-
979
- auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
980
-
981
- const auto n_kv = kv_state->get_n_kv();
982
-
983
- auto & cur = inp->s_mask;
984
-
985
- cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
913
+ cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
986
914
  ggml_set_input(cur);
987
915
 
988
916
  res->add_input(std::move(inp));
@@ -1028,11 +956,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
1028
956
  }
1029
957
 
1030
958
  ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
1031
- const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
959
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
1032
960
 
1033
- auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
961
+ auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
1034
962
 
1035
- const auto n_kv = kv_state->get_n_kv();
963
+ const auto n_kv = mctx_cur->get_n_kv();
1036
964
 
1037
965
  auto & cur = inp->pos_bucket;
1038
966
 
@@ -1059,6 +987,33 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
1059
987
  return pos_bias;
1060
988
  }
1061
989
 
990
+ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
991
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
992
+
993
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
994
+
995
+ {
996
+ GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
997
+
998
+ const auto n_kv = inp->mctx->get_attn()->get_n_kv();
999
+
1000
+ inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1001
+ //cb(inp->self_kq_mask, "KQ_mask", -1);
1002
+ ggml_set_input(inp->self_kq_mask);
1003
+
1004
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1005
+ }
1006
+
1007
+ {
1008
+ const auto n_rs = mctx_cur->get_recr()->get_n_rs();
1009
+
1010
+ inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
1011
+ ggml_set_input(inp->s_copy);
1012
+ }
1013
+
1014
+ return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
1015
+ }
1016
+
1062
1017
  ggml_tensor * llm_graph_context::build_attn_mha(
1063
1018
  ggml_cgraph * gf,
1064
1019
  ggml_tensor * q,
@@ -1234,14 +1189,14 @@ ggml_tensor * llm_graph_context::build_attn(
1234
1189
  }
1235
1190
 
1236
1191
  llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
1237
- const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
1192
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
1238
1193
 
1239
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
1194
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
1240
1195
 
1241
1196
  {
1242
1197
  GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
1243
1198
 
1244
- const auto n_kv = kv_state->get_n_kv();
1199
+ const auto n_kv = mctx_cur->get_n_kv();
1245
1200
 
1246
1201
  inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1247
1202
  //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1271,19 +1226,19 @@ ggml_tensor * llm_graph_context::build_attn(
1271
1226
  ggml_build_forward_expand(gf, k_cur);
1272
1227
  ggml_build_forward_expand(gf, v_cur);
1273
1228
 
1274
- const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
1229
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
1275
1230
 
1276
1231
  // store to KV cache
1277
1232
  {
1278
- ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
1279
- ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
1233
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
1234
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
1280
1235
  }
1281
1236
 
1282
1237
  const auto & kq_mask = inp->get_kq_mask();
1283
1238
 
1284
1239
  ggml_tensor * q = q_cur;
1285
- ggml_tensor * k = kv_state->get_k(ctx0, il);
1286
- ggml_tensor * v = kv_state->get_v(ctx0, il);
1240
+ ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1241
+ ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1287
1242
 
1288
1243
  ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1289
1244
  cb(cur, "kqv_out", il);
@@ -1303,36 +1258,6 @@ ggml_tensor * llm_graph_context::build_attn(
1303
1258
  return cur;
1304
1259
  }
1305
1260
 
1306
- llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
1307
- const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
1308
-
1309
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
1310
-
1311
- {
1312
- const auto n_kv = kv_state->get_base()->get_n_kv();
1313
-
1314
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1315
- //cb(inp->self_kq_mask, "KQ_mask", -1);
1316
- ggml_set_input(inp->self_kq_mask);
1317
-
1318
- inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1319
- }
1320
-
1321
- {
1322
- GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
1323
-
1324
- const auto n_kv = kv_state->get_swa()->get_n_kv();
1325
-
1326
- inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1327
- //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
1328
- ggml_set_input(inp->self_kq_mask_swa);
1329
-
1330
- inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1331
- }
1332
-
1333
- return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
1334
- }
1335
-
1336
1261
  ggml_tensor * llm_graph_context::build_attn(
1337
1262
  llm_graph_input_attn_kv_unified_iswa * inp,
1338
1263
  ggml_cgraph * gf,
@@ -1348,26 +1273,35 @@ ggml_tensor * llm_graph_context::build_attn(
1348
1273
  // these nodes are added to the graph together so that they are not reordered
1349
1274
  // by doing so, the number of splits in the graph is reduced
1350
1275
  ggml_build_forward_expand(gf, q_cur);
1351
- ggml_build_forward_expand(gf, k_cur);
1352
- ggml_build_forward_expand(gf, v_cur);
1353
1276
 
1354
- const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
1277
+ if (k_cur) {
1278
+ ggml_build_forward_expand(gf, k_cur);
1279
+ }
1280
+
1281
+ if (v_cur) {
1282
+ ggml_build_forward_expand(gf, v_cur);
1283
+ }
1284
+
1285
+ const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
1355
1286
 
1356
1287
  const bool is_swa = hparams.is_swa(il);
1357
1288
 
1358
- const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
1289
+ const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
1359
1290
 
1360
- // store to KV cache
1361
- {
1362
- ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
1363
- ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
1291
+ // optionally store to KV cache
1292
+ if (k_cur) {
1293
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
1294
+ }
1295
+
1296
+ if (v_cur) {
1297
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
1364
1298
  }
1365
1299
 
1366
1300
  const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
1367
1301
 
1368
1302
  ggml_tensor * q = q_cur;
1369
- ggml_tensor * k = kv_state->get_k(ctx0, il);
1370
- ggml_tensor * v = kv_state->get_v(ctx0, il);
1303
+ ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1304
+ ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1371
1305
 
1372
1306
  ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1373
1307
  cb(cur, "kqv_out", il);
@@ -1442,56 +1376,182 @@ ggml_tensor * llm_graph_context::build_attn(
1442
1376
  return cur;
1443
1377
  }
1444
1378
 
1445
- ggml_tensor * llm_graph_context::build_copy_mask_state(
1446
- ggml_cgraph * gf,
1447
- ggml_tensor * s,
1448
- ggml_tensor * state_copy,
1449
- ggml_tensor * state_mask,
1450
- int32_t n_state,
1451
- int32_t n_seqs) const {
1452
- const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
1379
+ ggml_tensor * llm_graph_context::build_attn(
1380
+ llm_graph_input_mem_hybrid * inp,
1381
+ ggml_cgraph * gf,
1382
+ ggml_tensor * wo,
1383
+ ggml_tensor * wo_b,
1384
+ ggml_tensor * q_cur,
1385
+ ggml_tensor * k_cur,
1386
+ ggml_tensor * v_cur,
1387
+ ggml_tensor * kq_b,
1388
+ ggml_tensor * v_mla,
1389
+ float kq_scale,
1390
+ int il) const {
1391
+ // these nodes are added to the graph together so that they are not reordered
1392
+ // by doing so, the number of splits in the graph is reduced
1393
+ ggml_build_forward_expand(gf, q_cur);
1394
+ ggml_build_forward_expand(gf, k_cur);
1395
+ ggml_build_forward_expand(gf, v_cur);
1453
1396
 
1454
- const auto n_kv = kv_state->get_n_kv();
1455
- const auto kv_head = kv_state->get_head();
1397
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
1456
1398
 
1457
- ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
1399
+ // store to KV cache
1400
+ {
1401
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
1402
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
1403
+ }
1404
+
1405
+ const auto & kq_mask = inp->get_kq_mask();
1458
1406
 
1459
- // copy states
1460
- // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
1461
- // this shrinks the tensors's ne[1] to n_kv
1462
- states = ggml_get_rows(ctx0, states, state_copy);
1407
+ ggml_tensor * q = q_cur;
1408
+ ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1409
+ ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1463
1410
 
1464
- // clear states of sequences which are starting at the beginning of this batch
1465
- // FIXME: zero-out NANs?
1466
- states = ggml_mul(ctx0, states, state_mask);
1411
+ ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1412
+ cb(cur, "kqv_out", il);
1467
1413
 
1468
- // copy states which won't be changed further (between n_seqs and n_kv)
1414
+ if (wo) {
1415
+ cur = build_lora_mm(wo, cur);
1416
+ if (arch == LLM_ARCH_GLM4) {
1417
+ // GLM4 seems to have numerical issues with half-precision accumulators
1418
+ ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
1419
+ }
1420
+ }
1421
+
1422
+ if (wo_b) {
1423
+ cur = ggml_add(ctx0, cur, wo_b);
1424
+ }
1425
+
1426
+ return cur;
1427
+ }
1428
+
1429
+ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
1430
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
1431
+
1432
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
1433
+
1434
+ {
1435
+ const auto n_kv = mctx_cur->get_base()->get_n_kv();
1436
+
1437
+ inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1438
+ //cb(inp->self_kq_mask, "KQ_mask", -1);
1439
+ ggml_set_input(inp->self_kq_mask);
1440
+
1441
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1442
+ }
1443
+
1444
+ {
1445
+ GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
1446
+
1447
+ const auto n_kv = mctx_cur->get_swa()->get_n_kv();
1448
+
1449
+ inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1450
+ //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
1451
+ ggml_set_input(inp->self_kq_mask_swa);
1452
+
1453
+ inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1454
+ }
1455
+
1456
+ return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
1457
+ }
1458
+
1459
+ ggml_tensor * llm_graph_context::build_rs(
1460
+ ggml_cgraph * gf,
1461
+ ggml_tensor * s,
1462
+ ggml_tensor * state_copy,
1463
+ int32_t state_size,
1464
+ int32_t n_seqs,
1465
+ uint32_t n_kv,
1466
+ uint32_t kv_head,
1467
+ uint32_t kv_size,
1468
+ int32_t rs_zero,
1469
+ bool avoid_copies) const {
1470
+
1471
+ ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
1472
+
1473
+ // Clear a single state which will then be copied to the other cleared states.
1474
+ // Note that this is a no-op when the view is zero-sized.
1475
+ ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
1476
+ ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
1477
+
1478
+ ggml_tensor * output_states;
1479
+
1480
+ if (!avoid_copies) {
1481
+ // copy states
1482
+ // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
1483
+ // {state_size, kv_size} -> {state_size, n_seqs}
1484
+ output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
1485
+ ggml_build_forward_expand(gf, output_states);
1486
+ } else {
1487
+ // FIXME: make the gathering operation happen before the copy below
1488
+ // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
1489
+ output_states = states;
1490
+ }
1491
+
1492
+ // copy extra states which won't be changed further (between n_seqs and n_kv)
1493
+ ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
1469
1494
  ggml_build_forward_expand(gf,
1470
1495
  ggml_cpy(ctx0,
1471
- ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs )*n_state*ggml_element_size(states)),
1472
- ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
1496
+ states_extra,
1497
+ ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
1498
+
1499
+ return output_states;
1500
+ }
1501
+
1502
+ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
1503
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1473
1504
 
1474
- // the part of the states that will be used and modified
1475
- return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
1505
+ auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
1506
+
1507
+ const auto n_rs = mctx_cur->get_n_rs();
1508
+
1509
+ inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
1510
+ ggml_set_input(inp->s_copy);
1511
+
1512
+ return (llm_graph_input_rs *) res->add_input(std::move(inp));
1513
+ }
1514
+
1515
+ ggml_tensor * llm_graph_context::build_rs(
1516
+ llm_graph_input_rs * inp,
1517
+ ggml_cgraph * gf,
1518
+ ggml_tensor * s,
1519
+ int32_t state_size,
1520
+ int32_t n_seqs,
1521
+ bool avoid_copies) const {
1522
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1523
+
1524
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
1525
+ }
1526
+
1527
+ ggml_tensor * llm_graph_context::build_rs(
1528
+ llm_graph_input_mem_hybrid * inp,
1529
+ ggml_cgraph * gf,
1530
+ ggml_tensor * s,
1531
+ int32_t state_size,
1532
+ int32_t n_seqs,
1533
+ bool avoid_copies) const {
1534
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
1535
+
1536
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
1476
1537
  }
1477
1538
 
1478
1539
  ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1479
- ggml_cgraph * gf,
1480
- ggml_tensor * state_copy,
1481
- ggml_tensor * state_mask,
1482
- const llama_ubatch & ubatch,
1540
+ llm_graph_input_rs * inp,
1541
+ ggml_cgraph * gf,
1542
+ const llama_ubatch & ubatch,
1483
1543
  int il) const {
1484
- const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
1544
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1485
1545
 
1486
1546
  const auto token_shift_count = hparams.token_shift_count;
1487
1547
 
1488
1548
  const int64_t n_seqs = ubatch.n_seqs;
1489
1549
 
1490
- ggml_tensor * token_shift_all = kv_state->get_k_l(il);
1550
+ ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
1491
1551
 
1492
- ggml_tensor * token_shift = build_copy_mask_state(
1493
- gf, token_shift_all, state_copy, state_mask,
1494
- hparams.n_embd_k_s(), n_seqs);
1552
+ ggml_tensor * token_shift = build_rs(
1553
+ inp, gf, token_shift_all,
1554
+ hparams.n_embd_r(), n_seqs);
1495
1555
 
1496
1556
  token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
1497
1557
 
@@ -1502,19 +1562,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
1502
1562
  ggml_tensor * token_shift,
1503
1563
  const llama_ubatch & ubatch,
1504
1564
  int il) const {
1505
- const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
1565
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1506
1566
 
1507
1567
  const auto token_shift_count = hparams.token_shift_count;
1508
1568
  const auto n_embd = hparams.n_embd;
1509
1569
 
1510
1570
  const int64_t n_seqs = ubatch.n_seqs;
1511
1571
 
1512
- const auto kv_head = kv_state->get_head();
1572
+ const auto kv_head = mctx_cur->get_head();
1513
1573
 
1514
1574
  return ggml_cpy(
1515
1575
  ctx0,
1516
1576
  ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
1517
- ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il)))
1577
+ ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
1518
1578
  );
1519
1579
  }
1520
1580
 
@@ -1565,23 +1625,30 @@ void llm_graph_context::build_pooling(
1565
1625
  ggml_tensor * inp_cls = build_inp_cls();
1566
1626
  inp = ggml_get_rows(ctx0, inp, inp_cls);
1567
1627
 
1568
- if (cls != nullptr && cls_b != nullptr) {
1628
+ if (cls) {
1569
1629
  // classification head
1570
1630
  // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
1571
- cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
1631
+ cur = ggml_mul_mat(ctx0, cls, inp);
1632
+ if (cls_b) {
1633
+ cur = ggml_add(ctx0, cur, cls_b);
1634
+ }
1572
1635
  cur = ggml_tanh(ctx0, cur);
1573
1636
 
1574
1637
  // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
1575
1638
  // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
1576
1639
  if (cls_out) {
1577
- GGML_ASSERT(cls_out_b != nullptr);
1578
- cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
1640
+ cur = ggml_mul_mat(ctx0, cls_out, cur);
1641
+ if (cls_out_b) {
1642
+ cur = ggml_add(ctx0, cur, cls_out_b);
1643
+ }
1579
1644
  }
1580
1645
  } else if (cls_out) {
1581
1646
  // Single layer classification head (direct projection)
1582
1647
  // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
1583
- GGML_ASSERT(cls_out_b != nullptr);
1584
- cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
1648
+ cur = ggml_mul_mat(ctx0, cls_out, inp);
1649
+ if (cls_out_b) {
1650
+ cur = ggml_add(ctx0, cur, cls_out_b);
1651
+ }
1585
1652
  } else {
1586
1653
  GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
1587
1654
  }