cui-llama.rn 1.7.4 → 1.7.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (276)
  1. package/README.md +217 -17
  2. package/android/src/main/CMakeLists.txt +34 -15
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +79 -5
  4. package/android/src/main/java/com/rnllama/RNLlama.java +237 -0
  5. package/android/src/main/jni.cpp +213 -14
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  14. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +35 -0
  15. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +34 -0
  16. package/cpp/README.md +1 -1
  17. package/cpp/chat-parser.cpp +385 -0
  18. package/cpp/chat-parser.h +120 -0
  19. package/cpp/chat.cpp +726 -596
  20. package/cpp/chat.h +71 -6
  21. package/cpp/common.cpp +56 -38
  22. package/cpp/common.h +9 -3
  23. package/cpp/ggml-backend-reg.cpp +5 -0
  24. package/cpp/ggml-backend.cpp +10 -2
  25. package/cpp/ggml-common.h +4 -0
  26. package/cpp/ggml-cpu/amx/amx.cpp +1 -1
  27. package/cpp/ggml-cpu/amx/mmq.cpp +11 -10
  28. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  29. package/cpp/ggml-cpu/arch/arm/quants.c +4114 -0
  30. package/cpp/ggml-cpu/arch/arm/repack.cpp +2163 -0
  31. package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  32. package/cpp/ggml-cpu/arch/x86/quants.c +4311 -0
  33. package/cpp/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  34. package/cpp/ggml-cpu/arch-fallback.h +184 -0
  35. package/cpp/ggml-cpu/common.h +4 -3
  36. package/cpp/ggml-cpu/ggml-cpu-impl.h +21 -16
  37. package/cpp/ggml-cpu/ggml-cpu.c +123 -104
  38. package/cpp/ggml-cpu/ggml-cpu.cpp +11 -8
  39. package/cpp/ggml-cpu/ops.cpp +330 -148
  40. package/cpp/ggml-cpu/ops.h +1 -0
  41. package/cpp/ggml-cpu/quants.c +1158 -0
  42. package/cpp/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  43. package/cpp/ggml-cpu/repack.cpp +1571 -0
  44. package/cpp/ggml-cpu/repack.h +98 -0
  45. package/cpp/ggml-cpu/simd-mappings.h +330 -38
  46. package/cpp/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  47. package/cpp/ggml-cpu/vec.cpp +87 -18
  48. package/cpp/ggml-cpu/vec.h +249 -94
  49. package/cpp/ggml-cpu.h +1 -0
  50. package/cpp/ggml-impl.h +63 -183
  51. package/cpp/ggml-llama-sim.metallib +0 -0
  52. package/cpp/ggml-llama.metallib +0 -0
  53. package/cpp/ggml-metal.m +152 -45
  54. package/cpp/ggml-quants.c +0 -2
  55. package/cpp/ggml.c +61 -21
  56. package/cpp/ggml.h +22 -3
  57. package/cpp/gguf.cpp +24 -3
  58. package/cpp/json-partial.cpp +256 -0
  59. package/cpp/json-partial.h +38 -0
  60. package/cpp/json-schema-to-grammar.cpp +5 -47
  61. package/cpp/json-schema-to-grammar.h +4 -4
  62. package/cpp/llama-arch.cpp +153 -3
  63. package/cpp/llama-arch.h +27 -1
  64. package/cpp/llama-batch.cpp +741 -272
  65. package/cpp/llama-batch.h +112 -54
  66. package/cpp/llama-chat.cpp +30 -8
  67. package/cpp/llama-chat.h +1 -0
  68. package/cpp/llama-context.cpp +524 -339
  69. package/cpp/llama-context.h +38 -17
  70. package/cpp/llama-cparams.cpp +4 -0
  71. package/cpp/llama-cparams.h +2 -0
  72. package/cpp/llama-grammar.cpp +12 -2
  73. package/cpp/llama-graph.cpp +431 -356
  74. package/cpp/llama-graph.h +126 -58
  75. package/cpp/llama-hparams.cpp +10 -2
  76. package/cpp/llama-hparams.h +19 -2
  77. package/cpp/llama-kv-cache-unified-iswa.cpp +279 -0
  78. package/cpp/llama-kv-cache-unified-iswa.h +128 -0
  79. package/cpp/llama-kv-cache-unified.cpp +1841 -0
  80. package/cpp/llama-kv-cache-unified.h +303 -0
  81. package/cpp/llama-kv-cells.h +439 -0
  82. package/cpp/llama-memory-hybrid.cpp +246 -0
  83. package/cpp/llama-memory-hybrid.h +138 -0
  84. package/cpp/llama-memory-recurrent.cpp +1112 -0
  85. package/cpp/llama-memory-recurrent.h +183 -0
  86. package/cpp/llama-memory.cpp +41 -0
  87. package/cpp/llama-memory.h +86 -5
  88. package/cpp/llama-mmap.cpp +1 -1
  89. package/cpp/llama-model-loader.cpp +42 -17
  90. package/cpp/llama-model-saver.cpp +1 -0
  91. package/cpp/llama-model.cpp +1639 -513
  92. package/cpp/llama-model.h +26 -0
  93. package/cpp/llama-sampling.cpp +2 -2
  94. package/cpp/llama-vocab.cpp +65 -28
  95. package/cpp/llama-vocab.h +1 -0
  96. package/cpp/llama.cpp +11 -7
  97. package/cpp/llama.h +150 -42
  98. package/cpp/minja/chat-template.hpp +1 -1
  99. package/cpp/minja/minja.hpp +1 -1
  100. package/cpp/{json.hpp → nlohmann/json.hpp} +3027 -2267
  101. package/cpp/nlohmann/json_fwd.hpp +187 -0
  102. package/cpp/regex-partial.cpp +204 -0
  103. package/cpp/regex-partial.h +56 -0
  104. package/cpp/rn-llama.cpp +646 -35
  105. package/cpp/rn-llama.h +32 -1
  106. package/cpp/rn-tts.h +39 -0
  107. package/cpp/sampling.cpp +7 -8
  108. package/cpp/tools/mtmd/clip-impl.h +5 -0
  109. package/cpp/tools/mtmd/clip.cpp +572 -436
  110. package/cpp/tools/mtmd/clip.h +14 -4
  111. package/cpp/tools/mtmd/mtmd-audio.cpp +0 -86
  112. package/cpp/tools/mtmd/mtmd-audio.h +2 -17
  113. package/cpp/tools/mtmd/mtmd-helper.cpp +175 -12
  114. package/cpp/tools/mtmd/mtmd-helper.h +91 -0
  115. package/cpp/tools/mtmd/mtmd.cpp +368 -248
  116. package/cpp/tools/mtmd/mtmd.h +6 -70
  117. package/cpp/unicode.cpp +5 -0
  118. package/ios/CMakeLists.txt +26 -6
  119. package/ios/RNLlama.h +1 -1
  120. package/ios/RNLlama.mm +153 -3
  121. package/ios/RNLlamaContext.h +9 -1
  122. package/ios/RNLlamaContext.mm +112 -9
  123. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat-parser.h +120 -0
  124. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +71 -6
  125. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +9 -3
  126. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-common.h +4 -0
  127. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +1 -0
  128. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +63 -183
  129. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +22 -3
  130. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-partial.h +38 -0
  131. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  132. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +27 -1
  133. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +112 -54
  134. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  135. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +38 -17
  136. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +2 -0
  137. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +126 -58
  138. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +19 -2
  139. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  140. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  141. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  142. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  143. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  144. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +86 -5
  145. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +26 -0
  146. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +1 -0
  147. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +150 -42
  148. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  149. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +1 -1
  150. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/{json.hpp → nlohmann/json.hpp} +3027 -2267
  151. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  152. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/regex-partial.h +56 -0
  153. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +32 -1
  154. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-tts.h +39 -0
  155. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  156. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  157. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat-parser.h +120 -0
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +71 -6
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +9 -3
  160. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +4 -0
  161. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +1 -0
  162. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +63 -183
  163. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +22 -3
  164. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-partial.h +38 -0
  165. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  166. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +27 -1
  167. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +112 -54
  168. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  169. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +38 -17
  170. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +2 -0
  171. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +126 -58
  172. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +19 -2
  173. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  174. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  175. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  176. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  177. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  178. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +86 -5
  179. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +26 -0
  180. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +1 -0
  181. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +150 -42
  182. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  183. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +1 -1
  184. package/ios/rnllama.xcframework/{tvos-arm64/rnllama.framework/Headers → ios-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann}/json.hpp +3027 -2267
  185. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  186. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/regex-partial.h +56 -0
  187. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +32 -1
  188. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-tts.h +39 -0
  189. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  190. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  191. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat-parser.h +120 -0
  192. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +71 -6
  193. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +9 -3
  194. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-common.h +4 -0
  195. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +1 -0
  196. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +63 -183
  197. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +22 -3
  198. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-partial.h +38 -0
  199. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  200. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +27 -1
  201. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +112 -54
  202. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  203. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +38 -17
  204. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +2 -0
  205. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +126 -58
  206. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +19 -2
  207. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  208. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  209. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  210. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  211. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  212. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +86 -5
  213. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +26 -0
  214. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +1 -0
  215. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +150 -42
  216. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  217. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +1 -1
  218. package/ios/rnllama.xcframework/{ios-arm64_x86_64-simulator/rnllama.framework/Headers → tvos-arm64/rnllama.framework/Headers/nlohmann}/json.hpp +3027 -2267
  219. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  220. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/regex-partial.h +56 -0
  221. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +32 -1
  222. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-tts.h +39 -0
  223. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  224. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  225. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat-parser.h +120 -0
  226. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +71 -6
  227. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +9 -3
  228. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +4 -0
  229. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +1 -0
  230. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +63 -183
  231. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +22 -3
  232. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-partial.h +38 -0
  233. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  234. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +27 -1
  235. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +112 -54
  236. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  237. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +38 -17
  238. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +2 -0
  239. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +126 -58
  240. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +19 -2
  241. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  242. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  243. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  244. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  245. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  246. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +86 -5
  247. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +26 -0
  248. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +1 -0
  249. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +150 -42
  250. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  251. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +1 -1
  252. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json.hpp +25526 -0
  253. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  254. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/regex-partial.h +56 -0
  255. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +32 -1
  256. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-tts.h +39 -0
  257. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  258. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  259. package/jest/mock.js +24 -0
  260. package/package.json +1 -1
  261. package/src/NativeRNLlama.ts +46 -2
  262. package/src/index.ts +105 -1
  263. package/cpp/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  264. package/cpp/ggml-cpu/ggml-cpu-quants.c +0 -13326
  265. package/cpp/ggml-cpu/sgemm.cpp +0 -3544
  266. package/cpp/ggml-cpu/sgemm.h +0 -14
  267. package/cpp/llama-kv-cache.cpp +0 -2827
  268. package/cpp/llama-kv-cache.h +0 -515
  269. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  270. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  271. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  272. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +0 -24766
  273. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  274. /package/cpp/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  275. /package/cpp/tools/mtmd/{miniaudio.h → miniaudio/miniaudio.h} +0 -0
  276. /package/cpp/tools/mtmd/{stb_image.h → stb/stb_image.h} +0 -0

llama-batch.h:

@@ -2,88 +2,146 @@

  #include "llama.h"

+ #include "llama-cparams.h"
+
  #include <array>
  #include <vector>
+ #include <set>
+ #include <bitset>
+ #include <unordered_map>

- // very similar to llama_batch,
- // but has more metadata about sequences
+ // keep this struct lightweight
+ // it points to data in `llama_batch_allocr`
  struct llama_ubatch {
  bool equal_seqs;
  // TODO: whole_seqs for embeddings?

- uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
- uint32_t n_seq_tokens; // tokens per sequence
- uint32_t n_seqs;
-
- llama_token * token; // [n_tokens]
- float * embd; // [n_embd, n_tokens]
- llama_pos * pos; // [n_tokens]
- int32_t * n_seq_id; // [n_seqs]
- llama_seq_id ** seq_id; // [n_seqs]
- int8_t * output; // [n_tokens]
+ uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
+ uint32_t n_seq_tokens; // tokens per sequence set
+ uint32_t n_seqs; // sequence sets in the ubatch
+ uint32_t n_seqs_unq; // unique sequence ids in the ubatch
+
+ // seq_id_unq: unique sequence ids in the ubatch
+ // seq_idx: indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
+ // used for extracting sequence pooled embeddings
+
+ // // size | idx | val
+ llama_token * token; // [n_tokens] | i | id, token
+ float * embd; // [n_embd, n_tokens] | i | embd
+ llama_pos * pos; // [n_tokens] | i | pos
+ int32_t * n_seq_id; // [n_tokens] | i | -
+ llama_seq_id ** seq_id; // [n_tokens] | s | s0, s1, seq_id
+ llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
+ int32_t * seq_idx; // [LLAMA_MAX_SEQ] | - | seq_idx
+ int8_t * output; // [n_tokens] | i | -
  };

- struct llama_sbatch_seq {
- int32_t n_seq_id;
+ // a helper for sanitizing, fulfilling and splitting a batch
+ class llama_batch_allocr {
+ public:
+ llama_batch_allocr(uint32_t n_pos_per_embd);

- llama_seq_id * seq_id;
+ // sanitize and auto-gen missing data in the input batch
+ // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
+ bool init(
+ const llama_batch & batch_inp,
+ const llama_vocab & vocab,
+ const llama_memory_i * memory,
+ uint32_t n_embd,
+ bool output_all);

- size_t offset;
- size_t length;
- };
+ const llama_batch & get_batch() const;

- // sequence-length-aware batch splitting
- struct llama_sbatch {
- // tokens left in this batch
- size_t n_tokens;
+ uint32_t get_n_tokens() const;
+ uint32_t get_n_outputs() const;

- size_t n_embd;
+ // the array of output indices in the order they were encountered during the ubatch splitting
+ std::vector<int32_t> & get_out_ids();

- bool logits_all; // TODO: remove once lctx.logits_all is removed too
+ // min/max positions of each sequence in the current ubatch
+ llama_pos seq_pos_min(llama_seq_id seq_id) const;
+ llama_pos seq_pos_max(llama_seq_id seq_id) const;

- // sorted indices into the batch
- std::vector<int64_t> ids;
- // batch indices of the output
- std::vector<int64_t> out_ids;
- std::vector<llama_sbatch_seq> seq;
+ // call once before splitting the batch to reset the internal state
+ void split_reset();

- const llama_batch * batch = nullptr;
+ // simple split, unknown number of sequence sets of unequal lengths
+ llama_ubatch split_simple(uint32_t n_ubatch);

- // buffers for the ubatch
- std::vector<llama_token> ubatch_token;
- std::vector<float> ubatch_embd;
- std::vector<llama_pos> ubatch_pos;
- std::vector<int32_t> ubatch_n_seq_id;
- std::vector<llama_seq_id *> ubatch_seq_id;
- std::vector<int8_t> ubatch_output;
+ // make ubatches of equal-length sequences sets
+ llama_ubatch split_equal(uint32_t n_ubatch);

- llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
+ // sequence-set-wise split - each ubatch contains a single sequence-set
+ llama_ubatch split_seq(uint32_t n_ubatch);

- void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
+ // a helper method for creating a well-defined ubatch of tokens
+ // TODO: support embeddings if needed in the future
+ llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);

- // simple split, unknown number of sequences of unequal lengths
- llama_ubatch split_simple(size_t n_ubatch);
+ private:
+ void clear();

- // make batches of equal-length sequences
- llama_ubatch split_equal(size_t n_ubatch);
+ // create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
+ // return llama_ubatch.n_tokens == 0 if the entire batch was consumed
+ llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);

- // sequence-wise split
- llama_ubatch split_seq(size_t n_ubatch);
+ // for debugging, start with LLAMA_BATCH_DEBUG=2
+ void ubatch_print(const llama_ubatch & ubatch, int debug);

- llama_sbatch() = default;
- llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
- };
+ llama_batch batch;
+
+ // only for debugging purposes
+ const llama_vocab * vocab;
+
+ // TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
+ // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
+ const uint32_t n_pos_per_embd;

- // temporary allocate memory for the input batch if needed
- struct llama_batch_allocr {
- struct llama_batch batch;
+ uint32_t n_embd;
+ uint32_t n_outputs;

  std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
+
  std::vector<llama_pos> pos;
  std::vector<int32_t> n_seq_id;
  std::vector<llama_seq_id *> seq_id;
- std::vector<int8_t> logits;
+ std::vector<llama_seq_id> seq_id_unq;
+ std::vector<int32_t> seq_idx;
+ std::vector<int8_t> output;
+
+ using pos_set_t = std::set<llama_pos>;
+ using seq_cpl_t = std::vector<bool>;
+
+ std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
+ std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+
+ using idx_vec_t = std::vector<int32_t>;
+ using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;

- // optionally fulfill the batch returned by llama_batch_get_one
- llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
+ std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
+
+ std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
+
+ // batch indices of the output
+ std::vector<int32_t> out_ids;
+
+ // used[i] indicates if token i has already been used in a previous ubatch
+ std::vector<bool> used;
+
+ // llama_ubatch points to this data:
+ struct ubatch {
+ std::vector<llama_token> token;
+ std::vector<float> embd;
+ std::vector<llama_pos> pos;
+ std::vector<int32_t> n_seq_id;
+ std::vector<llama_seq_id *> seq_id;
+ std::vector<llama_seq_id> seq_id_unq;
+ std::vector<int32_t> seq_idx;
+ std::vector<int8_t> output;
+ };
+
+ // current splitting state:
+ std::vector<ubatch> ubatches;
+
+ int debug;
  };
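
The hunk above replaces the old llama_sbatch splitter with the new llama_batch_allocr helper. Purely as a rough illustration pieced together from the declarations shown in this hunk, here is a usage sketch; the surrounding variables (batch_inp, vocab, memory, n_embd, n_ubatch) and the calling context are assumptions, not part of this diff:

    // sketch: sanitize the input batch, then split it into micro-batches
    llama_batch_allocr balloc(/*n_pos_per_embd=*/1);

    if (!balloc.init(batch_inp, vocab, memory, n_embd, /*output_all=*/false)) {
        return -1; // the input batch failed validation
    }

    balloc.split_reset();

    while (true) {
        llama_ubatch ubatch = balloc.split_simple(n_ubatch);
        if (ubatch.n_tokens == 0) {
            break; // the whole batch has been consumed
        }
        // ... build and compute the graph for this ubatch ...
    }

split_equal() and split_seq() follow the same loop shape but, per the comments above, group tokens into equal-length sequence sets or into a single sequence set per ubatch.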

llama-chat.h:

@@ -43,6 +43,7 @@ enum llm_chat_template {
  LLM_CHAT_TEMPLATE_BAILING,
  LLM_CHAT_TEMPLATE_LLAMA4,
  LLM_CHAT_TEMPLATE_SMOLVLM,
+ LLM_CHAT_TEMPLATE_DOTS1,
  LLM_CHAT_TEMPLATE_UNKNOWN,
  };


llama-context.h:

@@ -1,7 +1,6 @@
  #pragma once

  #include "llama.h"
- #include "llama-batch.h"
  #include "llama-cparams.h"
  #include "llama-graph.h"
  #include "llama-adapter.h"
@@ -13,11 +12,14 @@
  #include <vector>

  struct llama_model;
- struct llama_kv_cache;
+ class llama_batch_allocr;

  class llama_io_read_i;
  class llama_io_write_i;

+ struct llama_memory_i;
+ struct llama_memory_context_i;
+
  struct llama_context {
  // init scheduler and compute buffers, reserve worst-case graphs
  llama_context(
@@ -44,10 +46,12 @@ struct llama_context {
  uint32_t n_threads() const;
  uint32_t n_threads_batch() const;

- llama_kv_cache * get_kv_self();
- const llama_kv_cache * get_kv_self() const;
+ llama_memory_t get_memory() const;

- void kv_self_update();
+ // return true of the KV cache was updated
+ // TODO: remove
+ bool kv_self_update(bool optimize);
+ void kv_self_defrag_sched();

  enum llama_pooling_type pooling_type() const;

@@ -88,8 +92,18 @@ struct llama_context {
  int32_t il_start,
  int32_t il_end);

- int encode(llama_batch & inp_batch);
- int decode(llama_batch & inp_batch);
+ // process a single ubatch with a specific graph type
+ // if memory_context is provided, it will be applied first to the context's memory
+ // ret contains the status of the graph computation
+ // returns nullptr only if ret != LM_GGML_STATUS_SUCCESS
+ llm_graph_result_ptr process_ubatch(
+ const llama_ubatch & ubatch,
+ llm_graph_type gtype,
+ llama_memory_context_i * mctx,
+ lm_ggml_status & ret);
+
+ int encode(const llama_batch & batch_inp);
+ int decode(const llama_batch & batch_inp);

  //
  // state save/load
@@ -167,7 +181,7 @@ private:

  // Make sure enough space is available for outputs.
  // Returns max number of outputs for which space was reserved.
- int32_t output_reserve(int32_t n_outputs);
+ uint32_t output_reserve(int32_t n_outputs);

  //
  // graph
@@ -180,16 +194,18 @@ public:
  lm_ggml_cgraph * graph_init();

  // returns the result of lm_ggml_backend_sched_graph_compute_async execution
- lm_ggml_status graph_compute(
- lm_ggml_cgraph * gf,
- bool batched);
+ lm_ggml_status graph_compute(lm_ggml_cgraph * gf, bool batched);
+
+ // reserve a graph with a dummy ubatch of the specified size
+ lm_ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);

  private:
  llm_graph_result_ptr graph_build(
- lm_ggml_context * ctx,
- lm_ggml_cgraph * gf,
- const llama_ubatch & ubatch,
- llm_graph_type gtype);
+ lm_ggml_context * ctx,
+ lm_ggml_cgraph * gf,
+ const llama_ubatch & ubatch,
+ llm_graph_type gtype,
+ const llama_memory_context_i * mctx);

  llm_graph_cb graph_get_cb() const;

@@ -214,6 +230,9 @@ private:

  std::unique_ptr<llama_memory_i> memory;

+ // TODO: temporary, until the llama_kv_self_defrag() API is removed
+ bool memory_force_optimize = false;
+
  // decode output (2-dimensional array: [n_outputs][n_vocab])
  size_t logits_size = 0; // capacity (of floats) for logits
  float * logits = nullptr;
@@ -227,8 +246,10 @@
  // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
  std::map<llama_seq_id, std::vector<float>> embd_seq;

- int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
- int32_t n_outputs_max = 0; // capacity (of tokens positions) for the output buffers
+ // reuse the batch_allocr to avoid unnecessary memory allocations
+ std::unique_ptr<llama_batch_allocr> balloc;
+
+ uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch

  std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers

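
In these llama-context.h hunks, encode() and decode() now take the batch by const reference and the per-ubatch work is routed through process_ubatch(). A minimal sketch of the error-handling contract stated in the hunk above, assuming an already-prepared ubatch, a memory context mctx, and the LLM_GRAPH_TYPE_DECODER graph type (these surrounding pieces are assumptions, not shown in this diff):

    lm_ggml_status status = LM_GGML_STATUS_SUCCESS;

    llm_graph_result_ptr res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx, status);
    if (!res) {
        // per the comment in the hunk, a null result implies status != LM_GGML_STATUS_SUCCESS
        return status == LM_GGML_STATUS_ABORTED ? 2 : -3; // illustrative error mapping
    }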

llama-cparams.h:

@@ -4,6 +4,8 @@

  #include <cstdint>

+ #define LLAMA_MAX_SEQ 64
+
  struct llama_cparams {
  uint32_t n_ctx; // context size used during inference
  uint32_t n_batch;
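
LLAMA_MAX_SEQ now lives in llama-cparams.h and bounds the per-token sequence sets that llama_batch_allocr tracks (seq_set_t = std::bitset<LLAMA_MAX_SEQ> in the llama-batch.h hunk above). Purely as an illustration of that relationship, with a made-up helper name and assuming <bitset> and the llama typedefs are in scope:

    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>; // one bit per sequence id, as in llama-batch.h

    // illustrative helper: mark the sequences a single token belongs to
    static seq_set_t token_seq_set(const llama_seq_id * seq_ids, int32_t n_seq_id) {
        seq_set_t res;
        for (int32_t s = 0; s < n_seq_id; ++s) {
            res.set(seq_ids[s]); // each id must be < LLAMA_MAX_SEQ
        }
        return res;
    }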

llama-graph.h:

@@ -17,10 +17,12 @@ struct lm_ggml_tensor;
  struct llama_ubatch;
  struct llama_cparams;

- class llama_memory_i;
- class llama_kv_cache_unified;
- class llama_kv_cache_unified_iswa;
- class llama_kv_cache_recurrent;
+ struct llama_memory_context_i;
+
+ class llama_kv_cache_unified_context;
+ class llama_kv_cache_unified_iswa_context;
+ class llama_memory_recurrent_context;
+ class llama_memory_hybrid_context;

  // certain models (typically multi-modal) can produce different types of graphs
  enum llm_graph_type {
@@ -35,6 +37,7 @@ enum llm_ffn_op_type {
  LLM_FFN_RELU,
  LLM_FFN_RELU_SQR,
  LLM_FFN_SWIGLU,
+ LLM_FFN_GEGLU,
  };

  enum llm_ffn_gate_type {
@@ -92,14 +95,14 @@ public:

  class llm_graph_input_pos : public llm_graph_input_i {
  public:
- llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
+ llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
  virtual ~llm_graph_input_pos() = default;

  void set_input(const llama_ubatch * ubatch) override;

  lm_ggml_tensor * pos = nullptr; // I32 [n_batch]

- const int64_t n_pos_per_embd = 1;
+ const uint32_t n_pos_per_embd = 1;
  };

  // temperature tuning, used by llama4
@@ -133,7 +136,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
  public:
  llm_graph_input_pos_bucket_kv(
  const llama_hparams & hparams,
- const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {}
+ const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
  virtual ~llm_graph_input_pos_bucket_kv() = default;

  void set_input(const llama_ubatch * ubatch) override;
@@ -141,7 +144,8 @@ public:
  lm_ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]

  const llama_hparams & hparams;
- const llama_kv_cache_unified * kv_self;
+
+ const llama_kv_cache_unified_context * mctx;
  };

  class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -186,28 +190,16 @@ public:
  const llama_cparams & cparams;
  };

- class llm_graph_input_s_copy : public llm_graph_input_i {
+ class llm_graph_input_rs : public llm_graph_input_i {
  public:
- llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
- virtual ~llm_graph_input_s_copy() = default;
+ llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
+ virtual ~llm_graph_input_rs() = default;

  void set_input(const llama_ubatch * ubatch) override;

  lm_ggml_tensor * s_copy; // I32 [kv_size]

- const llama_kv_cache_recurrent * kv_self;
- };
-
- class llm_graph_input_s_mask : public llm_graph_input_i {
- public:
- llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
- virtual ~llm_graph_input_s_mask() = default;
-
- void set_input(const llama_ubatch * ubatch) override;
-
- lm_ggml_tensor * s_mask; // F32 [1, n_kv]
-
- const llama_kv_cache_recurrent * kv_self;
+ const llama_memory_recurrent_context * mctx;
  };

  class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -247,10 +239,10 @@ public:
  llm_graph_input_attn_kv_unified(
  const llama_hparams & hparams,
  const llama_cparams & cparams,
- const llama_kv_cache_unified * kv_self) :
+ const llama_kv_cache_unified_context * mctx) :
  hparams(hparams),
  cparams(cparams),
- kv_self(kv_self) {
+ mctx(mctx) {
  }
  ~llm_graph_input_attn_kv_unified() = default;

@@ -264,7 +256,7 @@ public:
  const llama_hparams & hparams;
  const llama_cparams & cparams;

- const llama_kv_cache_unified * kv_self;
+ const llama_kv_cache_unified_context * mctx;
  };

  class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
@@ -272,10 +264,10 @@ public:
  llm_graph_input_attn_kv_unified_iswa(
  const llama_hparams & hparams,
  const llama_cparams & cparams,
- const llama_kv_cache_unified_iswa * kv_self) :
+ const llama_kv_cache_unified_iswa_context * mctx) :
  hparams(hparams),
  cparams(cparams),
- kv_self(kv_self) {
+ mctx(mctx) {
  }
  ~llm_graph_input_attn_kv_unified_iswa() = default;

@@ -292,7 +284,7 @@ public:
  const llama_hparams & hparams;
  const llama_cparams & cparams;

- const llama_kv_cache_unified_iswa * kv_self;
+ const llama_kv_cache_unified_iswa_context * mctx;
  };

  class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -310,6 +302,44 @@ public:
  const llama_cross * cross = nullptr;
  };

+ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
+ public:
+ llm_graph_input_mem_hybrid(
+ const llama_hparams & hparams,
+ const llama_cparams & cparams,
+ const llama_memory_hybrid_context * mctx) :
+ hparams(hparams),
+ cparams(cparams),
+ mctx(mctx) {
+ }
+ virtual ~llm_graph_input_mem_hybrid() = default;
+
+ void set_input(const llama_ubatch * ubatch) override;
+
+ lm_ggml_tensor * s_copy; // I32 [kv_size]
+
+ lm_ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+
+ lm_ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
+ lm_ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
+
+ const llama_hparams & hparams;
+ const llama_cparams & cparams;
+
+ const llama_memory_hybrid_context * mctx;
+ };
+
+ // TODO: remove this when lm_ggml_scale_add is implemented
+ class llm_graph_input_one : public llm_graph_input_i {
+ public:
+ llm_graph_input_one() {}
+ virtual ~llm_graph_input_one() = default;
+
+ void set_input(const llama_ubatch *) override;
+
+ lm_ggml_tensor * one = nullptr; // F32
+ };
+
  //
  // llm_graph_result
  //
@@ -383,12 +413,12 @@ struct llm_graph_params {
  lm_ggml_backend_sched_t sched;
  lm_ggml_backend_t backend_cpu;

- const llama_adapter_cvec * cvec;
- const llama_adapter_loras * loras;
- const llama_memory_i * memory;
- const llama_cross * cross;
+ const llama_adapter_cvec * cvec;
+ const llama_adapter_loras * loras;
+ const llama_memory_context_i * mctx;
+ const llama_cross * cross;

- int32_t n_outputs;
+ uint32_t n_outputs;

  const llm_graph_cb & cb;
  };
@@ -422,8 +452,8 @@ struct llm_graph_context {
  const float norm_eps;
  const float norm_rms_eps;

- const int32_t n_tokens;
- const int32_t n_outputs;
+ const int64_t n_tokens;
+ const int64_t n_outputs;
  const int32_t n_ctx_orig; // yarn

  const enum llama_pooling_type pooling_type;
@@ -435,10 +465,10 @@

  lm_ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?

- const llama_adapter_cvec * cvec;
- const llama_adapter_loras * loras;
- const llama_memory_i * memory;
- const llama_cross * cross;
+ const llama_adapter_cvec * cvec;
+ const llama_adapter_loras * loras;
+ const llama_memory_context_i * mctx;
+ const llama_cross * cross;

  const llm_graph_cb & cb_func;

@@ -446,8 +476,6 @@

  llm_graph_context(const llm_graph_params & params);

- int64_t n_pos_per_embd() const;
-
  void cb(lm_ggml_tensor * cur, const char * name, int il) const;

  //
@@ -518,14 +546,14 @@ struct llm_graph_context {
  lm_ggml_tensor * build_inp_out_ids() const;
  lm_ggml_tensor * build_inp_mean() const;
  lm_ggml_tensor * build_inp_cls() const;
- lm_ggml_tensor * build_inp_s_copy() const;
- lm_ggml_tensor * build_inp_s_mask() const;

  lm_ggml_tensor * build_inp_cross_embd() const;
  lm_ggml_tensor * build_inp_pos_bucket_enc() const;
  lm_ggml_tensor * build_inp_pos_bucket_dec() const;
  lm_ggml_tensor * build_pos_bias(lm_ggml_tensor * pos_bucket, lm_ggml_tensor * attn_rel_b) const;

+ llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
+
  //
  // attention
  //
@@ -572,14 +600,15 @@ struct llm_graph_context {

  llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;

+ // note: if k_cur or v_cur are not provided, they will not be stored in the memory
  lm_ggml_tensor * build_attn(
  llm_graph_input_attn_kv_unified_iswa * inp,
  lm_ggml_cgraph * gf,
  lm_ggml_tensor * wo,
  lm_ggml_tensor * wo_b,
  lm_ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
- lm_ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
- lm_ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+ lm_ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
+ lm_ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
  lm_ggml_tensor * kq_b,
  lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
  float kq_scale,
@@ -600,23 +629,62 @@ struct llm_graph_context {
  float kq_scale,
  int il) const;

+ lm_ggml_tensor * build_attn(
+ llm_graph_input_mem_hybrid * inp,
+ lm_ggml_cgraph * gf,
+ lm_ggml_tensor * wo,
+ lm_ggml_tensor * wo_b,
+ lm_ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+ lm_ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+ lm_ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+ lm_ggml_tensor * kq_b,
+ lm_ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+ float kq_scale,
+ int il) const;
  //
  // recurrent
  //

- lm_ggml_tensor * build_copy_mask_state(
- lm_ggml_cgraph * gf,
- lm_ggml_tensor * s,
- lm_ggml_tensor * state_copy,
- lm_ggml_tensor * state_mask,
- int32_t n_state,
- int32_t n_seqs) const;
+ // TODO: avoid notion of "kv"
+ // TODO: move this implementation to llama_memory_recurrent.
+ // this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
+ // when moving, avoid passing `lm_ggml_cgraph` - only pass `lm_ggml_context`. would likely need to split the
+ // implementation in 2 separate methods. the goal is to avoid calling `lm_ggml_build_forward_expand` in
+ // `llama_memory_recurrent`
+ lm_ggml_tensor * build_rs(
+ lm_ggml_cgraph * gf,
+ lm_ggml_tensor * s,
+ lm_ggml_tensor * state_copy,
+ int32_t state_size,
+ int32_t n_seqs,
+ uint32_t n_kv,
+ uint32_t kv_head,
+ uint32_t kv_size,
+ int32_t rs_zero,
+ bool avoid_copies = false) const;
+
+ llm_graph_input_rs * build_rs_inp() const;
+
+ lm_ggml_tensor * build_rs(
+ llm_graph_input_rs * inp,
+ lm_ggml_cgraph * gf,
+ lm_ggml_tensor * s,
+ int32_t state_size,
+ int32_t n_seqs,
+ bool avoid_copies = false) const;
+
+ lm_ggml_tensor * build_rs(
+ llm_graph_input_mem_hybrid * inp,
+ lm_ggml_cgraph * gf,
+ lm_ggml_tensor * s,
+ int32_t state_size,
+ int32_t n_seqs,
+ bool avoid_copies = false) const;

  lm_ggml_tensor * build_rwkv_token_shift_load(
- lm_ggml_cgraph * gf,
- lm_ggml_tensor * state_copy,
- lm_ggml_tensor * state_mask,
- const llama_ubatch & ubatch,
+ llm_graph_input_rs * inp,
+ lm_ggml_cgraph * gf,
+ const llama_ubatch & ubatch,
  int il) const;

  lm_ggml_tensor * build_rwkv_token_shift_store(