cui-llama.rn 1.7.3 → 1.7.6

This diff compares the contents of publicly available package versions as they were released to a supported public registry. It is provided for informational purposes only.
Files changed (276)
  1. package/README.md +217 -17
  2. package/android/src/main/CMakeLists.txt +34 -15
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +94 -8
  4. package/android/src/main/java/com/rnllama/RNLlama.java +247 -0
  5. package/android/src/main/jni.cpp +213 -14
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  14. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +35 -0
  15. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +34 -0
  16. package/cpp/README.md +1 -1
  17. package/cpp/chat-parser.cpp +385 -0
  18. package/cpp/chat-parser.h +120 -0
  19. package/cpp/chat.cpp +726 -596
  20. package/cpp/chat.h +71 -6
  21. package/cpp/common.cpp +56 -38
  22. package/cpp/common.h +9 -3
  23. package/cpp/ggml-backend-reg.cpp +5 -0
  24. package/cpp/ggml-backend.cpp +10 -2
  25. package/cpp/ggml-common.h +4 -0
  26. package/cpp/ggml-cpu/amx/amx.cpp +1 -1
  27. package/cpp/ggml-cpu/amx/mmq.cpp +11 -10
  28. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  29. package/cpp/ggml-cpu/arch/arm/quants.c +4114 -0
  30. package/cpp/ggml-cpu/arch/arm/repack.cpp +2163 -0
  31. package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  32. package/cpp/ggml-cpu/arch/x86/quants.c +4311 -0
  33. package/cpp/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  34. package/cpp/ggml-cpu/arch-fallback.h +184 -0
  35. package/cpp/ggml-cpu/common.h +4 -3
  36. package/cpp/ggml-cpu/ggml-cpu-impl.h +21 -16
  37. package/cpp/ggml-cpu/ggml-cpu.c +123 -104
  38. package/cpp/ggml-cpu/ggml-cpu.cpp +11 -8
  39. package/cpp/ggml-cpu/ops.cpp +330 -148
  40. package/cpp/ggml-cpu/ops.h +1 -0
  41. package/cpp/ggml-cpu/quants.c +1158 -0
  42. package/cpp/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  43. package/cpp/ggml-cpu/repack.cpp +1571 -0
  44. package/cpp/ggml-cpu/repack.h +98 -0
  45. package/cpp/ggml-cpu/simd-mappings.h +330 -38
  46. package/cpp/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  47. package/cpp/ggml-cpu/vec.cpp +87 -18
  48. package/cpp/ggml-cpu/vec.h +249 -94
  49. package/cpp/ggml-cpu.h +1 -0
  50. package/cpp/ggml-impl.h +63 -183
  51. package/cpp/ggml-llama-sim.metallib +0 -0
  52. package/cpp/ggml-llama.metallib +0 -0
  53. package/cpp/ggml-metal.m +152 -45
  54. package/cpp/ggml-quants.c +0 -2
  55. package/cpp/ggml.c +61 -21
  56. package/cpp/ggml.h +22 -3
  57. package/cpp/gguf.cpp +24 -3
  58. package/cpp/json-partial.cpp +256 -0
  59. package/cpp/json-partial.h +38 -0
  60. package/cpp/json-schema-to-grammar.cpp +5 -47
  61. package/cpp/json-schema-to-grammar.h +4 -4
  62. package/cpp/llama-arch.cpp +153 -3
  63. package/cpp/llama-arch.h +27 -1
  64. package/cpp/llama-batch.cpp +741 -272
  65. package/cpp/llama-batch.h +112 -54
  66. package/cpp/llama-chat.cpp +30 -8
  67. package/cpp/llama-chat.h +1 -0
  68. package/cpp/llama-context.cpp +524 -339
  69. package/cpp/llama-context.h +38 -17
  70. package/cpp/llama-cparams.cpp +4 -0
  71. package/cpp/llama-cparams.h +2 -0
  72. package/cpp/llama-grammar.cpp +12 -2
  73. package/cpp/llama-graph.cpp +431 -356
  74. package/cpp/llama-graph.h +126 -58
  75. package/cpp/llama-hparams.cpp +10 -2
  76. package/cpp/llama-hparams.h +19 -2
  77. package/cpp/llama-kv-cache-unified-iswa.cpp +279 -0
  78. package/cpp/llama-kv-cache-unified-iswa.h +128 -0
  79. package/cpp/llama-kv-cache-unified.cpp +1841 -0
  80. package/cpp/llama-kv-cache-unified.h +303 -0
  81. package/cpp/llama-kv-cells.h +439 -0
  82. package/cpp/llama-memory-hybrid.cpp +246 -0
  83. package/cpp/llama-memory-hybrid.h +138 -0
  84. package/cpp/llama-memory-recurrent.cpp +1112 -0
  85. package/cpp/llama-memory-recurrent.h +183 -0
  86. package/cpp/llama-memory.cpp +41 -0
  87. package/cpp/llama-memory.h +86 -5
  88. package/cpp/llama-mmap.cpp +1 -1
  89. package/cpp/llama-model-loader.cpp +42 -17
  90. package/cpp/llama-model-saver.cpp +1 -0
  91. package/cpp/llama-model.cpp +1639 -513
  92. package/cpp/llama-model.h +26 -0
  93. package/cpp/llama-sampling.cpp +2 -2
  94. package/cpp/llama-vocab.cpp +65 -28
  95. package/cpp/llama-vocab.h +1 -0
  96. package/cpp/llama.cpp +11 -7
  97. package/cpp/llama.h +150 -42
  98. package/cpp/minja/chat-template.hpp +1 -1
  99. package/cpp/minja/minja.hpp +1 -1
  100. package/cpp/{json.hpp → nlohmann/json.hpp} +3027 -2267
  101. package/cpp/nlohmann/json_fwd.hpp +187 -0
  102. package/cpp/regex-partial.cpp +204 -0
  103. package/cpp/regex-partial.h +56 -0
  104. package/cpp/rn-llama.cpp +646 -35
  105. package/cpp/rn-llama.h +32 -1
  106. package/cpp/rn-tts.h +39 -0
  107. package/cpp/sampling.cpp +7 -8
  108. package/cpp/tools/mtmd/clip-impl.h +5 -0
  109. package/cpp/tools/mtmd/clip.cpp +572 -436
  110. package/cpp/tools/mtmd/clip.h +14 -4
  111. package/cpp/tools/mtmd/mtmd-audio.cpp +0 -86
  112. package/cpp/tools/mtmd/mtmd-audio.h +2 -17
  113. package/cpp/tools/mtmd/mtmd-helper.cpp +175 -12
  114. package/cpp/tools/mtmd/mtmd-helper.h +91 -0
  115. package/cpp/tools/mtmd/mtmd.cpp +368 -248
  116. package/cpp/tools/mtmd/mtmd.h +6 -70
  117. package/cpp/unicode.cpp +5 -0
  118. package/ios/CMakeLists.txt +26 -6
  119. package/ios/RNLlama.h +1 -1
  120. package/ios/RNLlama.mm +153 -3
  121. package/ios/RNLlamaContext.h +9 -1
  122. package/ios/RNLlamaContext.mm +112 -9
  123. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat-parser.h +120 -0
  124. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +71 -6
  125. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +9 -3
  126. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-common.h +4 -0
  127. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +1 -0
  128. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +63 -183
  129. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +22 -3
  130. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-partial.h +38 -0
  131. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  132. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +27 -1
  133. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +112 -54
  134. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  135. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +38 -17
  136. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +2 -0
  137. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +126 -58
  138. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +19 -2
  139. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  140. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  141. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  142. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  143. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  144. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +86 -5
  145. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +26 -0
  146. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +1 -0
  147. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +150 -42
  148. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  149. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +1 -1
  150. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/{json.hpp → nlohmann/json.hpp} +3027 -2267
  151. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  152. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/regex-partial.h +56 -0
  153. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +32 -1
  154. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-tts.h +39 -0
  155. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  156. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  157. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat-parser.h +120 -0
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +71 -6
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +9 -3
  160. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +4 -0
  161. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +1 -0
  162. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +63 -183
  163. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +22 -3
  164. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-partial.h +38 -0
  165. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  166. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +27 -1
  167. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +112 -54
  168. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  169. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +38 -17
  170. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +2 -0
  171. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +126 -58
  172. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +19 -2
  173. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  174. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  175. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  176. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  177. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  178. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +86 -5
  179. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +26 -0
  180. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +1 -0
  181. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +150 -42
  182. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  183. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +1 -1
  184. package/ios/rnllama.xcframework/{tvos-arm64/rnllama.framework/Headers → ios-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann}/json.hpp +3027 -2267
  185. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  186. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/regex-partial.h +56 -0
  187. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +32 -1
  188. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-tts.h +39 -0
  189. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  190. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  191. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat-parser.h +120 -0
  192. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +71 -6
  193. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +9 -3
  194. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-common.h +4 -0
  195. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +1 -0
  196. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +63 -183
  197. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +22 -3
  198. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-partial.h +38 -0
  199. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  200. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +27 -1
  201. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +112 -54
  202. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
  203. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +38 -17
  204. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +2 -0
  205. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +126 -58
  206. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +19 -2
  207. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  208. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  209. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  210. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  211. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  212. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +86 -5
  213. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +26 -0
  214. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +1 -0
  215. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +150 -42
  216. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  217. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +1 -1
  218. package/ios/rnllama.xcframework/{ios-arm64_x86_64-simulator/rnllama.framework/Headers → tvos-arm64/rnllama.framework/Headers/nlohmann}/json.hpp +3027 -2267
  219. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  220. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/regex-partial.h +56 -0
  221. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +32 -1
  222. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-tts.h +39 -0
  223. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  224. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  225. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat-parser.h +120 -0
  226. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +71 -6
  227. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +9 -3
  228. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +4 -0
  229. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +1 -0
  230. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +63 -183
  231. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +22 -3
  232. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-partial.h +38 -0
  233. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
  234. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +27 -1
  235. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +112 -54
  236. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
  237. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +38 -17
  238. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +2 -0
  239. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +126 -58
  240. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +19 -2
  241. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
  242. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
  243. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cells.h +439 -0
  244. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
  245. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
  246. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +86 -5
  247. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +26 -0
  248. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +1 -0
  249. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +150 -42
  250. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
  251. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +1 -1
  252. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json.hpp +25526 -0
  253. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
  254. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/regex-partial.h +56 -0
  255. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +32 -1
  256. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-tts.h +39 -0
  257. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  258. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  259. package/jest/mock.js +24 -0
  260. package/package.json +1 -1
  261. package/src/NativeRNLlama.ts +46 -2
  262. package/src/index.ts +105 -1
  263. package/cpp/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  264. package/cpp/ggml-cpu/ggml-cpu-quants.c +0 -13326
  265. package/cpp/ggml-cpu/sgemm.cpp +0 -3544
  266. package/cpp/ggml-cpu/sgemm.h +0 -14
  267. package/cpp/llama-kv-cache.cpp +0 -2827
  268. package/cpp/llama-kv-cache.h +0 -515
  269. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  270. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  271. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  272. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +0 -24766
  273. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +0 -515
  274. /package/cpp/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  275. /package/cpp/tools/mtmd/{miniaudio.h → miniaudio/miniaudio.h} +0 -0
  276. /package/cpp/tools/mtmd/{stb_image.h → stb/stb_image.h} +0 -0
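The largest single change in this release is the rework of package/cpp/llama-context.cpp (item 68 above), whose diff follows. It replaces the old llama_kv_cache accessors with a generic memory interface, routes every batch through a llama_batch_allocr, and drives decoding through a per-ubatch memory context. The snippet below is only a condensed, hypothetical illustration of that control flow; the memory_status, ubatch, memory_context, init_batch, and process_ubatch names here are simplified stand-ins, not the real cui-llama.rn / llama.cpp types, which appear in the diff itself.

// Hypothetical sketch of the new decode flow: the batch is split into
// ubatches by a memory context, and each ubatch is applied and computed
// in a do/while loop. All types are simplified stand-ins for illustration.
#include <algorithm>
#include <cstdio>
#include <memory>
#include <vector>

enum class memory_status { success, no_update, failed_prepare, failed_compute };

struct ubatch {                 // stand-in for llama_ubatch
    std::vector<int> tokens;
};

struct memory_context {         // stand-in for the memory-context interface
    std::vector<ubatch> ubatches;
    size_t cur = 0;

    memory_status get_status() const { return memory_status::success; }
    const ubatch & get_ubatch() const { return ubatches[cur]; }
    bool apply() { return true; }                  // e.g. reserve KV-cache slots
    bool next()  { return ++cur < ubatches.size(); }
};

// stand-in for init_batch(): split the input batch into fixed-size ubatches
static std::unique_ptr<memory_context> init_batch(const std::vector<int> & batch, size_t n_ubatch) {
    auto mctx = std::make_unique<memory_context>();
    for (size_t i = 0; i < batch.size(); i += n_ubatch) {
        ubatch ub;
        ub.tokens.assign(batch.begin() + i, batch.begin() + std::min(batch.size(), i + n_ubatch));
        mctx->ubatches.push_back(std::move(ub));
    }
    return mctx;
}

// stand-in for process_ubatch(): apply the memory context, then "compute"
static bool process_ubatch(memory_context & mctx, const ubatch & ub) {
    if (!mctx.apply()) {
        return false;
    }
    std::printf("computed ubatch of %zu tokens\n", ub.tokens.size());
    return true;
}

int main() {
    const std::vector<int> batch = {1, 2, 3, 4, 5, 6, 7};

    auto mctx = init_batch(batch, /*n_ubatch=*/3);
    if (!mctx || mctx->get_status() != memory_status::success) {
        return 1;
    }

    // mirrors the do { ... } while (mctx->next()) loop introduced in decode()
    do {
        if (!process_ubatch(*mctx, mctx->get_ubatch())) {
            return 1;
        }
    } while (mctx->next());

    return 0;
}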
@@ -1,14 +1,16 @@
  #include "llama-context.h"

  #include "llama-impl.h"
+ #include "llama-batch.h"
  #include "llama-io.h"
+ #include "llama-memory.h"
  #include "llama-mmap.h"
  #include "llama-model.h"
- #include "llama-kv-cache.h"

+ #include <cinttypes>
  #include <cstring>
+ #include <limits>
  #include <stdexcept>
- #include <cinttypes>

  //
  // llama_context
@@ -17,7 +19,8 @@
  llama_context::llama_context(
  const llama_model & model,
  llama_context_params params) :
- model(model) {
+ model(model),
+ balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
  LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);

  t_start_us = model.t_start_us;
@@ -25,7 +28,11 @@ llama_context::llama_context(

  const auto & hparams = model.hparams;

- cparams.n_seq_max = std::max(1u, params.n_seq_max);
+ cparams.n_seq_max = std::max(1u, params.n_seq_max);
+ if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
+ throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
+ }
+
  cparams.n_threads = params.n_threads;
  cparams.n_threads_batch = params.n_threads_batch;
  cparams.yarn_ext_factor = params.yarn_ext_factor;
@@ -118,6 +125,11 @@ llama_context::llama_context(
  __func__, n_ctx_per_seq, hparams.n_ctx_train);
  }

+ if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) {
+ LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
+ __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
+ }
+
  if (!hparams.vocab_only) {
  // GPU backends
  for (auto * dev : model.devices) {
@@ -255,15 +267,9 @@ llama_context::llama_context(

  // reserve worst-case graph
  if (!hparams.vocab_only && memory) {
- const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+ const uint32_t n_seqs = cparams.n_seq_max;
  const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

- llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-
- // restore later
- // TODO: something cleaner
- const auto n_outputs_save = n_outputs;
-
  LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);

  int n_splits_pp = -1;
@@ -273,25 +279,18 @@ llama_context::llama_context(
  int n_nodes_tg = -1;

  // simulate full KV cache
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());

- kv_self->set_full();
+ const auto mctx = memory->init_full();
+ if (!mctx) {
+ throw std::runtime_error("failed to initialize KV cache");
+ }

  cross.v_embd.clear();

  // reserve pp graph first so that buffers are only allocated once
  {
- llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-
- // max number of outputs
- n_outputs = ubatch_pp.n_tokens;
-
- LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
-
- auto * gf = graph_init();
- graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
-
- if (!lm_ggml_backend_sched_reserve(sched.get(), gf)) {
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+ if (!gf) {
  throw std::runtime_error("failed to allocate compute pp buffers");
  }

@@ -301,16 +300,8 @@ llama_context::llama_context(

  // reserve with tg graph to get the number of splits and nodes
  {
- llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-
- n_outputs = ubatch_tg.n_tokens;
-
- LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
-
- auto * gf = graph_init();
- graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
-
- if (!lm_ggml_backend_sched_reserve(sched.get(), gf)) {
+ auto * gf = graph_reserve(1, 1, 1, mctx.get());
+ if (!gf) {
  throw std::runtime_error("failed to allocate compute tg buffers");
  }

@@ -320,22 +311,12 @@ llama_context::llama_context(

  // reserve again with pp graph to avoid ggml-alloc reallocations during inference
  {
- llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-
- n_outputs = ubatch_pp.n_tokens;
-
- LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
-
- auto * gf = graph_init();
- graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
-
- if (!lm_ggml_backend_sched_reserve(sched.get(), gf)) {
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+ if (!gf) {
  throw std::runtime_error("failed to allocate compute pp buffers");
  }
  }

- n_outputs = n_outputs_save;
-
  for (size_t i = 0; i < backend_ptrs.size(); ++i) {
  lm_ggml_backend_t backend = backend_ptrs[i];
  lm_ggml_backend_buffer_type_t buft = backend_buft[i];
@@ -439,46 +420,71 @@ uint32_t llama_context::n_threads_batch() const {
  return cparams.n_threads_batch;
  }

- llama_kv_cache * llama_context::get_kv_self() {
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
- return kv_self;
- }
-
- const llama_kv_cache * llama_context::get_kv_self() const {
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
- return kv_self;
+ llama_memory_t llama_context::get_memory() const {
+ return memory.get();
  }

- void llama_context::kv_self_update() {
- bool need_reserve = false;
+ // deprecated
+ void llama_context::kv_self_defrag_sched() {
+ if (!memory) {
+ return;
+ }

- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+ memory_force_optimize = true;
+ }

- need_reserve = kv_self->update(*this);
+ // deprecated
+ bool llama_context::kv_self_update(bool optimize) {
+ if (!memory) {
+ return false;
+ }

- // reserve a worst case graph if needed
- if (need_reserve) {
- LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
+ {
+ // TODO: remove in the future
+ optimize |= memory_force_optimize;
+ memory_force_optimize = false;

- // build worst-case graph
- uint32_t n_seqs = 1; // TODO: worst-case number of sequences
- uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+ const auto mctx = memory->init_update(this, optimize);
+ switch (mctx->get_status()) {
+ case LLAMA_MEMORY_STATUS_SUCCESS:
+ {
+ // noop
+ } break;
+ case LLAMA_MEMORY_STATUS_NO_UPDATE:
+ {
+ // no updates need to be performed
+ return false;
+ }
+ case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+ case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+ {
+ LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
+ return false;
+ }
+ }

- // simulate full KV cache
- kv_self->set_full();
+ if (!mctx->apply()) {
+ LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
+ }
+ }

- llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
- llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+ // if the memory module did any computation, we have to reserve a new worst-case graph
+ {
+ const auto mctx = memory->init_full();
+ if (!mctx) {
+ throw std::runtime_error("failed to initialize memory context");
+ }

- auto * gf = graph_init();
- graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+ const uint32_t n_seqs = cparams.n_seq_max;
+ const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

- // initialize scheduler with the worst-case graph
- lm_ggml_backend_sched_reset(sched.get());
- if (!lm_ggml_backend_sched_reserve(sched.get(), gf)) {
- LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+ if (!gf) {
+ LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
  }
  }
+
+ return true;
  }

  enum llama_pooling_type llama_context::pooling_type() const {
@@ -490,7 +496,7 @@ float * llama_context::get_logits() {
  }

  float * llama_context::get_logits_ith(int32_t i) {
- int32_t j = -1;
+ int64_t j = -1;

  try {
  if (logits == nullptr) {
@@ -513,7 +519,7 @@ float * llama_context::get_logits_ith(int32_t i) {
  }
  if (j >= n_outputs) {
  // This should not happen
- throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
+ throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
  }

  return logits + j*model.vocab.n_tokens();
@@ -532,7 +538,7 @@ float * llama_context::get_embeddings() {
  }

  float * llama_context::get_embeddings_ith(int32_t i) {
- int32_t j = -1;
+ int64_t j = -1;

  try {
  if (embd == nullptr) {
@@ -555,7 +561,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
  }
  if (j >= n_outputs) {
  // This should not happen
- throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
+ throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
  }

  return embd + j*model.hparams.n_embd;
@@ -672,63 +678,95 @@ bool llama_context::apply_adapter_cvec(
  return cvec.apply(model, data, len, n_embd, il_start, il_end);
  }

- int llama_context::encode(llama_batch & inp_batch) {
- if (inp_batch.n_tokens == 0) {
- LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
- return -1;
+ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, lm_ggml_status & ret) {
+ if (mctx && !mctx->apply()) {
+ LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
+ ret = LM_GGML_STATUS_FAILED;
+ return nullptr;
  }

- // temporary allocate memory for the input batch if needed
- // note: during encode, we always pass the full sequence starting from pos = 0
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
+ auto * gf = graph_init();
+ if (!gf) {
+ LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
+ ret = LM_GGML_STATUS_FAILED;
+ return nullptr;
+ }
+
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
+ if (!res) {
+ LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
+ ret = LM_GGML_STATUS_FAILED;
+ return nullptr;
+ }
+
+ // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (lm_ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+
+ if (!lm_ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+ LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+ ret = LM_GGML_STATUS_ALLOC_FAILED;
+ return nullptr;
+ }
+
+ res->set_inputs(&ubatch);
+
+ const auto status = graph_compute(gf, ubatch.n_tokens > 1);
+ if (status != LM_GGML_STATUS_SUCCESS) {
+ LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
+ ret = status;
+ return nullptr;
+ }
+
+ ret = LM_GGML_STATUS_SUCCESS;

- const llama_batch & batch = batch_allocr.batch;
- const int32_t n_tokens = batch.n_tokens;
+ return res;
+ }
+
+ int llama_context::encode(const llama_batch & batch_inp) {
+ LM_GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
+
+ if (batch_inp.n_tokens == 0) {
+ LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+ return -1;
+ }

  const auto & hparams = model.hparams;

- LM_GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+ const int64_t n_embd = hparams.n_embd;

- if (batch.token) {
- for (int32_t i = 0; i < n_tokens; ++i) {
- if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
- LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
- return -1;
- }
- }
+ // note: during encode, we always pass the full sequence starting from pos = 0
+ if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+ return -1;
  }

+ const uint32_t n_tokens = balloc->get_n_tokens();
+
+ const llama_ubatch ubatch = balloc->split_simple(n_tokens);
+
  // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
- LM_GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
+ LM_GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");

  if (t_compute_start_us == 0) {
  t_compute_start_us = lm_ggml_time_us();
  }

+ // TODO: this clear of the buffer can easily be forgotten - need something better
  embd_seq.clear();

  n_queued_tokens += n_tokens;

- const int64_t n_embd = hparams.n_embd;
-
- llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
-
- const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
-
  // reserve output buffer
  if (output_reserve(n_tokens) < n_tokens) {
  LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
  return -2;
  };

- for (int32_t i = 0; i < n_tokens; ++i) {
+ for (uint32_t i = 0; i < n_tokens; ++i) {
  output_ids[i] = i;
  }

  n_outputs = n_tokens;

- //batch_manager->prepare(ubatch);
-
  lm_ggml_backend_sched_reset(sched.get());
  lm_ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

@@ -739,26 +777,18 @@ int llama_context::encode(llama_batch & inp_batch) {
  // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
  cparams.causal_attn = false;

- auto * gf = graph_init();
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
-
- lm_ggml_backend_sched_alloc_graph(sched.get(), gf);
-
- res->set_inputs(&ubatch);
+ lm_ggml_status status;
+ const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);

  cparams.causal_attn = causal_attn_org;

- const auto compute_status = graph_compute(gf, n_tokens > 1);
- switch (compute_status) {
- case LM_GGML_STATUS_SUCCESS:
- break;
- case LM_GGML_STATUS_ABORTED:
- return 2;
- case LM_GGML_STATUS_ALLOC_FAILED:
- return -2;
- case LM_GGML_STATUS_FAILED:
- default:
- return -3;
+ if (!res) {
+ switch (status) {
+ case LM_GGML_STATUS_ABORTED: return 2;
+ case LM_GGML_STATUS_ALLOC_FAILED: return -2;
+ case LM_GGML_STATUS_FAILED: return -3;
+ case LM_GGML_STATUS_SUCCESS: LM_GGML_ABORT("should not happen");
+ }
  }

  auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
@@ -783,31 +813,28 @@ int llama_context::encode(llama_batch & inp_batch) {
  {
  // extract sequence embeddings
  auto & embd_seq_out = embd_seq;
- embd_seq_out.clear();

- LM_GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];

- for (int32_t i = 0; i < n_tokens; i++) {
- const llama_seq_id seq_id = ubatch.seq_id[i][0];
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
- continue;
- }
  embd_seq_out[seq_id].resize(n_embd);
- lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
+ lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
  }
  } break;
  case LLAMA_POOLING_TYPE_RANK:
  {
- // extract the rerank score - a single float per sequence
+ // extract the rerank score - n_cls_out floats per sequence
  auto & embd_seq_out = embd_seq;

- for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
- continue;
- }
- embd_seq_out[seq_id].resize(1);
- lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+ const uint32_t n_cls_out = hparams.n_cls_out;
+
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
+
+ embd_seq_out[seq_id].resize(n_cls_out);
+ lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
  }
  } break;
  case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -832,12 +859,16 @@ int llama_context::encode(llama_batch & inp_batch) {
  cross.v_embd.resize(cross.n_embd*cross.n_enc);
  memcpy(cross.v_embd.data(), embd, lm_ggml_nbytes(t_embd));

+ const auto & batch = balloc->get_batch();
+
  // remember the sequence ids used during the encoding - needed for cross attention later
  cross.seq_ids_enc.resize(n_tokens);
- for (int32_t i = 0; i < n_tokens; i++) {
+ for (uint32_t i = 0; i < n_tokens; i++) {
  cross.seq_ids_enc[i].clear();
- for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
- llama_seq_id seq_id = ubatch.seq_id[i][s];
+
+ for (int s = 0; s < batch.n_seq_id[i]; s++) {
+ const llama_seq_id seq_id = batch.seq_id[i][s];
+
  cross.seq_ids_enc[i].insert(seq_id);
  }
  }
@@ -846,49 +877,42 @@ int llama_context::encode(llama_batch & inp_batch) {
  return 0;
  }

- int llama_context::decode(llama_batch & inp_batch) {
+ int llama_context::decode(const llama_batch & batch_inp) {
+ LM_GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
+
  if (!memory) {
- LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
- return encode(inp_batch);
+ LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
+ return encode(batch_inp);
  }

- if (inp_batch.n_tokens == 0) {
+ if (batch_inp.n_tokens == 0) {
  LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
  return -1;
  }

- if (!inp_batch.pos) {
- if (inp_batch.seq_id) {
- LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
- return -1;
- }
- }
-
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
- // temporary allocate memory for the input batch if needed
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
-
- const llama_batch & batch = batch_allocr.batch;
-
  const auto & vocab = model.vocab;
  const auto & hparams = model.hparams;

  const int32_t n_vocab = vocab.n_tokens();
+ const int64_t n_embd = hparams.n_embd;

- const int64_t n_tokens_all = batch.n_tokens;
- const int64_t n_embd = hparams.n_embd;
+ // when computing embeddings, all tokens are output
+ const bool output_all = cparams.embeddings;

- llama_kv_cache_guard kv_guard(kv_self);
+ if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+ return -1;
+ }

- LM_GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+ const uint32_t n_tokens_all = balloc->get_n_tokens();
+ const uint32_t n_outputs_all = balloc->get_n_outputs();

- if (batch.token) {
- for (int64_t i = 0; i < n_tokens_all; ++i) {
- if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
- LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
- throw std::runtime_error("invalid token");
- }
+ if (output_all) {
+ // require that all tokens are output
+ if (n_outputs_all != n_tokens_all) {
+ LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
+ __func__, n_outputs_all, n_tokens_all);
+ return -1;
  }
  }

@@ -901,49 +925,77 @@ int llama_context::decode(llama_batch & inp_batch) {
  }
  n_queued_tokens += n_tokens_all;

- // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
- const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
+ // TODO: this clear of the buffer can easily be forgotten - need something better
  embd_seq.clear();

- int64_t n_outputs_all = 0;
+ bool did_optimize = false;
+
+ // handle any pending defrags/shifts
+ kv_self_update(false);
+
+ llama_memory_context_ptr mctx;

- // count outputs
- if (batch.logits && !embd_pooled) {
- for (uint32_t i = 0; i < n_tokens_all; ++i) {
- n_outputs_all += batch.logits[i] != 0;
+ while (true) {
+ mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+ if (!mctx) {
+ return -2;
  }
- } else if (embd_pooled) {
- n_outputs_all = n_tokens_all;
- } else {
- // keep last output only
- n_outputs_all = 1;
- }

- llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
+ switch (mctx->get_status()) {
+ case LLAMA_MEMORY_STATUS_SUCCESS:
+ {
+ } break;
+ case LLAMA_MEMORY_STATUS_NO_UPDATE:
+ {
+ LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
+
+ return -2;
+ }
+ case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+ {
+ if (!did_optimize) {
+ did_optimize = true;
+
+ if (kv_self_update(true)) {
+ LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
+
+ continue;
+ }
+ }
+
+ LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());
+
+ return 1;
+ }
+ case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+ {
+ LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());
+
+ return -2;
+ }
+ }
+
+ break;
+ }

  // reserve output buffer
  if (output_reserve(n_outputs_all) < n_outputs_all) {
- LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
  return -2;
  };

- // handle any pending defrags/shifts
- kv_self_update();
-
  int64_t n_outputs_prev = 0;

- while (sbatch.n_tokens > 0) {
- llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+ do {
+ const auto & ubatch = mctx->get_ubatch();

- // count the outputs in this u_batch
+ // count the outputs in this ubatch
  {
  int32_t n_outputs_new = 0;

  if (n_outputs_all == n_tokens_all) {
  n_outputs_new = ubatch.n_tokens;
  } else {
- LM_GGML_ASSERT(ubatch.output);
  for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
  n_outputs_new += (int32_t) (ubatch.output[i] != 0);
  }
@@ -953,33 +1005,40 @@ int llama_context::decode(llama_batch & inp_batch) {
  n_outputs = n_outputs_new;
  }

- // find KV slot
- if (!kv_self->find_slot(ubatch)) {
- return 1;
- }
-
  lm_ggml_backend_sched_reset(sched.get());
  lm_ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

- auto * gf = graph_init();
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER);
+ lm_ggml_status status;
+ const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);

- // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (lm_ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+ if (!res) {
+ // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
+ llama_pos pos_min[LLAMA_MAX_SEQ];
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+ pos_min[s] = std::numeric_limits<llama_pos>::max();
+ }

- lm_ggml_backend_sched_alloc_graph(sched.get(), gf);
+ for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+ const auto & seq_id = ubatch.seq_id[i][0];

- res->set_inputs(&ubatch);
+ pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
+ }

- const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
- if (compute_status != LM_GGML_STATUS_SUCCESS) {
- switch (compute_status) {
- case LM_GGML_STATUS_ABORTED:
- return 2;
- case LM_GGML_STATUS_ALLOC_FAILED:
- return -2;
- case LM_GGML_STATUS_FAILED:
- default:
- return -3;
+ for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+ if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
+ continue;
+ }
+
+ LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
+
+ memory->seq_rm(s, pos_min[s], -1);
+ }
+
+ switch (status) {
+ case LM_GGML_STATUS_ABORTED: return 2;
+ case LM_GGML_STATUS_ALLOC_FAILED: return -2;
+ case LM_GGML_STATUS_FAILED: return -3;
+ case LM_GGML_STATUS_SUCCESS: LM_GGML_ABORT("should not happen");
  }
  }

@@ -988,7 +1047,7 @@ int llama_context::decode(llama_batch & inp_batch) {
  // lm_ggml_graph_dump_dot(gf, NULL, "llama.dot");
  //}

- auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
+ auto * t_logits = res->get_logits();
  auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;

  if (t_embd && res->get_embd_pooled()) {
@@ -1035,27 +1094,27 @@ int llama_context::decode(llama_batch & inp_batch) {
  // extract sequence embeddings (cleared before processing each batch)
  auto & embd_seq_out = embd_seq;

- for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
- continue;
- }
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
+
  embd_seq_out[seq_id].resize(n_embd);
- lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
+ lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
  }
  } break;
  case LLAMA_POOLING_TYPE_RANK:
  {
- // extract the rerank score - a single float per sequence
+ // extract the rerank score - n_cls_out floats per sequence
  auto & embd_seq_out = embd_seq;

- for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
- continue;
- }
- embd_seq_out[seq_id].resize(1);
- lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+ const uint32_t n_cls_out = hparams.n_cls_out;
+
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
+
+ embd_seq_out[seq_id].resize(n_cls_out);
+ lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
  }
  } break;
  case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -1066,23 +1125,20 @@ int llama_context::decode(llama_batch & inp_batch) {
  }

  n_outputs_prev += n_outputs;
- }
-
- // finalize the batch processing
- kv_guard.commit();
+ } while (mctx->next());

  // set to total number of outputs in the batch, for use in llama_get_logits_ith
  n_outputs = n_outputs_all;

  // set output mappings
- {
+ if (n_outputs > 0) {
  bool sorted_output = true;

- auto & out_ids = sbatch.out_ids;
+ auto & out_ids = balloc->get_out_ids();

- LM_GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
+ LM_GGML_ASSERT(out_ids.size() == (size_t) n_outputs);

- for (int64_t i = 0; i < n_outputs_all; ++i) {
+ for (int64_t i = 0; i < n_outputs; ++i) {
  int64_t out_id = out_ids[i];
  output_ids[out_id] = i;
  if (out_id != i) {
@@ -1094,20 +1150,22 @@ int llama_context::decode(llama_batch & inp_batch) {
  // note: this is mostly relevant for recurrent models atm
  if (!sorted_output) {
  const uint32_t n_vocab = model.vocab.n_tokens();
- const uint32_t n_embd = model.hparams.n_embd;
+ const uint64_t n_embd = model.hparams.n_embd;

  LM_GGML_ASSERT((size_t) n_outputs == out_ids.size());

  // TODO: is there something more efficient which also minimizes swaps?
  // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
- for (int32_t i = 0; i < n_outputs - 1; ++i) {
- int32_t j_min = i;
- for (int32_t j = i + 1; j < n_outputs; ++j) {
+ for (uint32_t i = 0; i < n_outputs - 1; ++i) {
+ uint32_t j_min = i;
+ for (uint32_t j = i + 1; j < n_outputs; ++j) {
  if (out_ids[j] < out_ids[j_min]) {
  j_min = j;
  }
  }
- if (j_min == i) { continue; }
+ if (j_min == i) {
+ continue;
+ }
  std::swap(out_ids[i], out_ids[j_min]);
  if (logits_size > 0) {
  for (uint32_t k = 0; k < n_vocab; k++) {
@@ -1120,8 +1178,10 @@ int llama_context::decode(llama_batch & inp_batch) {
  }
  }
  }
+
  std::fill(output_ids.begin(), output_ids.end(), -1);
- for (int32_t i = 0; i < n_outputs; ++i) {
+
+ for (uint32_t i = 0; i < n_outputs; ++i) {
  output_ids[out_ids[i]] = i;
  }
  }
@@ -1130,11 +1190,6 @@ int llama_context::decode(llama_batch & inp_batch) {
  // wait for the computation to finish (automatically done when obtaining the model output)
  //synchronize();

- // decide if we need to defrag the kv cache
- if (cparams.defrag_thold > 0.0f) {
- kv_self->defrag_sched(cparams.defrag_thold);
- }
-
  // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
  // overlap with device computation.
  lm_ggml_backend_sched_reset(sched.get());
@@ -1146,7 +1201,7 @@ int llama_context::decode(llama_batch & inp_batch) {
  // output
  //

- int32_t llama_context::output_reserve(int32_t n_outputs) {
+ uint32_t llama_context::output_reserve(int32_t n_outputs) {
  const auto & hparams = model.hparams;
  const auto & vocab = model.vocab;

@@ -1156,9 +1211,8 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
  const auto n_vocab = vocab.n_tokens();
  const auto n_embd = hparams.n_embd;

- // TODO: use a per-batch flag for logits presence instead
- bool has_logits = !cparams.embeddings;
- bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+ bool has_logits = true;
+ bool has_embd = cparams.embeddings;

  // TODO: hacky enc-dec support
  if (model.arch == LLM_ARCH_T5) {
@@ -1212,8 +1266,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
1212
1266
  // set all ids as invalid (negative)
1213
1267
  std::fill(output_ids.begin(), output_ids.end(), -1);
1214
1268
 
1215
- this->n_outputs = 0;
1216
- this->n_outputs_max = n_outputs_max;
1269
+ this->n_outputs = 0;
1217
1270
 
1218
1271
  return n_outputs_max;
1219
1272
  }
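
With this change, output_reserve() always keeps a logits buffer and reserves an embeddings buffer whenever embeddings are enabled, instead of treating the two as mutually exclusive. A hedged sizing sketch under the assumption that each output row holds n_vocab logits and n_embd embedding values (the actual buffer layout in llama-context.cpp may differ):

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical sizing helper mirroring the new output_reserve() flags.
struct output_sizes {
    std::size_t logits_floats;
    std::size_t embd_floats;
};

static output_sizes reserve_sizes(uint32_t n_outputs_max,
                                  uint32_t n_vocab,
                                  uint32_t n_embd,
                                  bool     embeddings_enabled) {
    const bool has_logits = true;                // 1.7.6: no longer gated on !embeddings
    const bool has_embd   = embeddings_enabled;  // 1.7.6: no longer requires POOLING_TYPE_NONE

    return {
        has_logits ? (std::size_t) n_outputs_max * n_vocab : 0,
        has_embd   ? (std::size_t) n_outputs_max * n_embd  : 0,
    };
}
```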
@@ -1238,11 +1291,52 @@ lm_ggml_cgraph * llama_context::graph_init() {
  return lm_ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
  }

+ lm_ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
+ LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
+
+ if (n_tokens % n_seqs != 0) {
+ n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
+ n_outputs = std::min(n_outputs, n_tokens);
+
+ LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
+ }
+
+ // store the n_outputs as it is, and restore it afterwards
+ // TODO: not sure if needed, might simplify in the future by removing this
+ const auto save_n_outputs = this->n_outputs;
+
+ this->n_outputs = n_outputs;
+
+ llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
+ llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
+
+ auto * gf = graph_init();
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
+
+ this->n_outputs = save_n_outputs;
+
+ if (!res) {
+ LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
+ return nullptr;
+ }
+
+ lm_ggml_backend_sched_reset(sched.get());
+
+ // initialize scheduler with the specified graph
+ if (!lm_ggml_backend_sched_reserve(sched.get(), gf)) {
+ LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+ return nullptr;
+ }
+
+ return gf;
+ }
+
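
graph_reserve() pads the worst-case ubatch so it splits evenly across sequences before building and reserving the graph. The rounding rule in isolation, as a standalone illustration:

```cpp
#include <algorithm>
#include <cstdint>

// Pad n_tokens up to the next multiple of n_seqs so the worst-case ubatch
// splits evenly, and cap n_outputs at the padded token count.
static void round_reserve_dims(uint32_t & n_tokens, uint32_t n_seqs, uint32_t & n_outputs) {
    if (n_tokens % n_seqs != 0) {
        n_tokens  = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs;
        n_outputs = std::min(n_outputs, n_tokens);
    }
}
// e.g. n_tokens = 10, n_seqs = 4  ->  n_tokens = 12
```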
  llm_graph_result_ptr llama_context::graph_build(
- lm_ggml_context * ctx,
- lm_ggml_cgraph * gf,
- const llama_ubatch & ubatch,
- llm_graph_type gtype) {
+ lm_ggml_context * ctx,
+ lm_ggml_cgraph * gf,
+ const llama_ubatch & ubatch,
+ llm_graph_type gtype,
+ const llama_memory_context_i * mctx) {
  return model.build_graph(
  {
  /*.ctx =*/ ctx,
@@ -1254,7 +1348,7 @@ llm_graph_result_ptr llama_context::graph_build(
  /*.backend_cpu =*/ backend_cpu,
  /*.cvec =*/ &cvec,
  /*.loras =*/ &loras,
- /*.memory =*/ memory.get(),
+ /*.mctx =*/ mctx,
  /*.cross =*/ &cross,
  /*.n_outputs =*/ n_outputs,
  /*.cb =*/ graph_get_cb(),
@@ -1663,14 +1757,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {

  std::vector<int32_t> w_output_pos;

- LM_GGML_ASSERT(n_outputs <= n_outputs_max);
-
  w_output_pos.resize(n_outputs);

  // build a more compact representation of the output ids
  for (size_t i = 0; i < n_batch(); ++i) {
  // map an output id to a position in the batch
- int32_t pos = output_ids[i];
+ int64_t pos = output_ids[i];
  if (pos >= 0) {
  LM_GGML_ASSERT(pos < n_outputs);
  w_output_pos[pos] = i;
@@ -1710,11 +1802,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
  }
  }

- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
- if (kv_self != nullptr) {
+ if (memory != nullptr) {
  LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
- kv_self->state_write(io);
+ memory->state_write(io);
  }

  return io.n_bytes();
@@ -1801,9 +1891,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
  if (memory) {
  LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);

- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
- kv_self->state_read(io);
+ memory->state_read(io);
  }

  return io.n_bytes();
@@ -1813,9 +1901,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
  LM_GGML_UNUSED(seq_id);

  if (memory) {
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
- kv_self->state_write(io, seq_id);
+ memory->state_write(io, seq_id);
  }

  return io.n_bytes();
@@ -1825,9 +1911,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq
  LM_GGML_UNUSED(seq_id);

  if (memory) {
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
- kv_self->state_read(io, seq_id);
+ memory->state_read(io, seq_id);
  }

  return io.n_bytes();
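
The per-sequence state read/write paths now go through the memory interface directly; from the caller's side the public state API is unchanged. A small usage sketch for saving and restoring one sequence, assuming the standard llama.h state helpers and an already-created context:

```cpp
#include "llama.h"

#include <cstdint>
#include <vector>

// Save and restore the cached state of a single sequence through the public
// state API; signatures are assumed to match the current llama.h.
std::vector<uint8_t> save_sequence(llama_context * ctx, llama_seq_id seq_id) {
    std::vector<uint8_t> buf(llama_state_seq_get_size(ctx, seq_id));
    const size_t written = llama_state_seq_get_data(ctx, buf.data(), buf.size(), seq_id);
    buf.resize(written);
    return buf;
}

bool restore_sequence(llama_context * ctx, llama_seq_id seq_id, const std::vector<uint8_t> & buf) {
    return llama_state_seq_set_data(ctx, buf.data(), buf.size(), seq_id) != 0;
}
```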
@@ -1932,10 +2016,7 @@ void llama_context::opt_epoch_iter(
  const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
  const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);

- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
- kv_self->clear();
- llama_kv_cache_guard kv_guard(kv_self);
+ memory->clear(true);

  for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
  batch.n_tokens = n_batch;
@@ -1947,39 +2028,44 @@ void llama_context::opt_epoch_iter(
  batch.logits [pos_batch] = true;
  }

- const auto n_tokens_all = batch.n_tokens;
+ if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+ return;
+ }

- n_queued_tokens += n_tokens_all;
+ const uint32_t n_tokens_all = balloc->get_n_tokens();

- // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
- const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+ n_queued_tokens += n_tokens_all;

  embd_seq.clear();

- int64_t n_outputs_all = n_tokens_all;
+ uint32_t n_outputs_all = n_tokens_all;

- llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
+ auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
+ if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+ LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
+ break;
+ }

  // reserve output buffer
  if (output_reserve(n_outputs_all) < n_outputs_all) {
- LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
  LM_GGML_ABORT("TODO: handle this error");
  };

- for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
- llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+ uint32_t pos_batch = 0;
+ do {
+ const auto & ubatch = mctx->get_ubatch();

  n_outputs = ubatch.n_tokens;

- // TODO: not sure if this is needed
- if (!kv_self->find_slot(ubatch)) {
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
- LM_GGML_ABORT("TODO: handle this error");
+ if (!mctx->apply()) {
+ LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
+ break;
  }

  auto * gf = graph_init();
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());

  struct lm_ggml_context * ctx_compute_opt;
  {
@@ -1994,6 +2080,7 @@ void llama_context::opt_epoch_iter(
  }
  lm_ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
  lm_ggml_opt_alloc(opt_ctx, train);
+
  res->set_inputs(&ubatch);
  {
  struct lm_ggml_tensor * labels = lm_ggml_opt_labels(opt_ctx);
@@ -2011,10 +2098,10 @@ void llama_context::opt_epoch_iter(
  callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
  }
  lm_ggml_free(ctx_compute_opt);
- }
- }

- kv_guard.commit();
+ pos_batch += ubatch.n_tokens;
+ } while (mctx->next());
+ }
  }
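
opt_epoch_iter() now drives ubatches through a memory context (init_batch() → get_ubatch()/apply() → next()) instead of the old sbatch/find_slot loop. A mock of that control flow with hypothetical stub types; only the loop shape mirrors the diff, the real interface lives in the library's memory headers:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

struct ubatch_stub {
    uint32_t n_tokens;
};

struct memory_context_stub {
    uint32_t n_ubatch;
    uint32_t total;
    uint32_t processed;

    ubatch_stub get_ubatch() const { return { std::min(n_ubatch, total - processed) }; }
    bool apply() { return true; }                     // e.g. reserve KV cells for the ubatch
    bool next()  { processed += n_ubatch; return processed < total; }
};

int main() {
    memory_context_stub mctx = { /*n_ubatch=*/512, /*total=*/2048, /*processed=*/0 };

    uint32_t pos_batch = 0;
    do {
        const ubatch_stub ubatch = mctx.get_ubatch();
        if (!mctx.apply()) {
            break;                                    // failed to update the memory context
        }
        // ... graph_build(...) + compute for this ubatch would go here ...
        pos_batch += ubatch.n_tokens;
    } while (mctx.next());

    std::printf("processed %u tokens\n", pos_batch);
    return 0;
}
```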

  void llama_context::opt_epoch(
@@ -2174,12 +2261,14 @@ const llama_model * llama_get_model(const llama_context * ctx) {
  return &ctx->get_model();
  }

+ // deprecated
  llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
- return ctx->get_kv_self();
+ return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
  }

+ // deprecated
  void llama_kv_self_update(llama_context * ctx) {
- ctx->kv_self_update();
+ ctx->kv_self_update(false);
  }

  enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
@@ -2294,13 +2383,118 @@ int32_t llama_apply_adapter_cvec(
  return res ? 0 : -1;
  }

+ //
+ // memory
+ //
+
+ llama_memory_t llama_get_memory(const struct llama_context * ctx) {
+ return ctx->get_memory();
+ }
+
+ void llama_memory_clear(llama_memory_t mem, bool data) {
+ if (!mem) {
+ return;
+ }
+
+ mem->clear(data);
+ }
+
+ bool llama_memory_seq_rm(
+ llama_memory_t mem,
+ llama_seq_id seq_id,
+ llama_pos p0,
+ llama_pos p1) {
+ if (!mem) {
+ return true;
+ }
+
+ return mem->seq_rm(seq_id, p0, p1);
+ }
+
+ void llama_memory_seq_cp(
+ llama_memory_t mem,
+ llama_seq_id seq_id_src,
+ llama_seq_id seq_id_dst,
+ llama_pos p0,
+ llama_pos p1) {
+ if (!mem) {
+ return;
+ }
+
+ mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+ }
+
+ void llama_memory_seq_keep(
+ llama_memory_t mem,
+ llama_seq_id seq_id) {
+ if (!mem) {
+ return;
+ }
+
+ mem->seq_keep(seq_id);
+ }
+
+ void llama_memory_seq_add(
+ llama_memory_t mem,
+ llama_seq_id seq_id,
+ llama_pos p0,
+ llama_pos p1,
+ llama_pos delta) {
+ if (!mem) {
+ return;
+ }
+
+ mem->seq_add(seq_id, p0, p1, delta);
+ }
+
+ void llama_memory_seq_div(
+ llama_memory_t mem,
+ llama_seq_id seq_id,
+ llama_pos p0,
+ llama_pos p1,
+ int d) {
+ if (!mem) {
+ return;
+ }
+
+ mem->seq_div(seq_id, p0, p1, d);
+ }
+
+ llama_pos llama_memory_seq_pos_min(
+ llama_memory_t mem,
+ llama_seq_id seq_id) {
+ if (!mem) {
+ return -1;
+ }
+
+ return mem->seq_pos_min(seq_id);
+ }
+
+ llama_pos llama_memory_seq_pos_max(
+ llama_memory_t mem,
+ llama_seq_id seq_id) {
+ if (!mem) {
+ return -1;
+ }
+
+ return mem->seq_pos_max(seq_id);
+ }
+
+ bool llama_memory_can_shift(llama_memory_t mem) {
+ if (!mem) {
+ return false;
+ }
+
+ return mem->get_can_shift();
+ }
+
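
These new llama_memory_* entry points expose the context's memory module directly. A minimal usage sketch, assuming the standard llama.h loading helpers and a placeholder model path:

```cpp
#include "llama.h"

#include <cstdio>

int main() {
    // placeholder model path and default parameters
    llama_model * model = llama_model_load_from_file("model.gguf", llama_model_default_params());
    if (!model) {
        return 1;
    }

    llama_context * ctx = llama_init_from_model(model, llama_context_default_params());

    // ... llama_decode() some tokens into sequence 0 ...

    llama_memory_t mem = llama_get_memory(ctx);

    // truncate sequence 0 after position 128 (p1 = -1 means "to the end")
    llama_memory_seq_rm(mem, /*seq_id=*/0, /*p0=*/128, /*p1=*/-1);

    // query the remaining position range of sequence 0
    const llama_pos p_min = llama_memory_seq_pos_min(mem, 0);
    const llama_pos p_max = llama_memory_seq_pos_max(mem, 0);
    std::printf("seq 0 spans positions [%d, %d]\n", p_min, p_max);

    // wipe both metadata and data for all sequences
    llama_memory_clear(mem, /*data=*/true);

    llama_free(ctx);
    llama_model_free(model);
    return 0;
}
```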
  //
  // kv cache
  //

  // deprecated
  int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
- const auto * kv = ctx->get_kv_self();
+ const auto * kv = llama_get_memory(ctx);
  if (!kv) {
  return 0;
  }
@@ -2322,7 +2516,7 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
  // deprecated
  // note: this is the same as above - will be removed anyway, so it's ok
  int32_t llama_kv_self_used_cells(const llama_context * ctx) {
- const auto * kv = ctx->get_kv_self();
+ const auto * kv = llama_get_memory(ctx);
  if (!kv) {
  return 0;
  }
@@ -2341,114 +2535,119 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
  return res;
  }

+ // deprecated
  void llama_kv_self_clear(llama_context * ctx) {
- auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
  if (!kv) {
  return;
  }

- kv->clear();
+ llama_memory_clear(kv, true);
  }

+ // deprecated
  bool llama_kv_self_seq_rm(
  llama_context * ctx,
  llama_seq_id seq_id,
  llama_pos p0,
  llama_pos p1) {
- auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
  if (!kv) {
  return true;
  }

- return kv->seq_rm(seq_id, p0, p1);
+ return llama_memory_seq_rm(kv, seq_id, p0, p1);
  }

+ // deprecated
  void llama_kv_self_seq_cp(
  llama_context * ctx,
  llama_seq_id seq_id_src,
  llama_seq_id seq_id_dst,
  llama_pos p0,
  llama_pos p1) {
- auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
  if (!kv) {
  return;
  }

- kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+ llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
  }

+ // deprecated
  void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
- auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
  if (!kv) {
  return;
  }

- kv->seq_keep(seq_id);
+ llama_memory_seq_keep(kv, seq_id);
  }

+ // deprecated
  void llama_kv_self_seq_add(
  llama_context * ctx,
  llama_seq_id seq_id,
  llama_pos p0,
  llama_pos p1,
  llama_pos delta) {
- auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
  if (!kv) {
  return;
  }

- kv->seq_add(seq_id, p0, p1, delta);
+ llama_memory_seq_add(kv, seq_id, p0, p1, delta);
  }

+ // deprecated
  void llama_kv_self_seq_div(
  llama_context * ctx,
  llama_seq_id seq_id,
  llama_pos p0,
  llama_pos p1,
  int d) {
- auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
  if (!kv) {
  return;
  }

- kv->seq_div(seq_id, p0, p1, d);
+ llama_memory_seq_div(kv, seq_id, p0, p1, d);
  }

+ // deprecated
  llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
- const auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
  if (!kv) {
  return -1;
  }

- return kv->seq_pos_min(seq_id);
+ return llama_memory_seq_pos_min(kv, seq_id);
  }

+ // deprecated
  llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
- const auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
  if (!kv) {
  return -1;
  }

- return kv->seq_pos_max(seq_id);
+ return llama_memory_seq_pos_max(kv, seq_id);
  }

+ // deprecated
  void llama_kv_self_defrag(llama_context * ctx) {
- auto * kv = ctx->get_kv_self();
- if (!kv) {
- return;
- }
-
  // force defrag
- kv->defrag_sched(-1.0f);
+ ctx->kv_self_defrag_sched();
  }

+ // deprecated
  bool llama_kv_self_can_shift(const llama_context * ctx) {
- const auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
  if (!kv) {
  return false;
  }

- return kv->get_can_shift();
+ return llama_memory_can_shift(kv);
  }
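
Since every llama_kv_self_* function is now a deprecated thin wrapper over the memory API, downstream code can migrate mechanically. A sketch of the one-to-one mapping for a few common calls, assuming `ctx` is an existing llama_context *:

```cpp
#include "llama.h"

// Mechanical migration from the deprecated llama_kv_self_* calls to the
// llama_memory_* equivalents introduced in this diff.
void migrate_kv_calls(llama_context * ctx, llama_seq_id seq_id) {
    // before (deprecated):
    // llama_kv_self_clear(ctx);
    // llama_kv_self_seq_rm(ctx, seq_id, 0, -1);
    // llama_pos pmax = llama_kv_self_seq_pos_max(ctx, seq_id);

    // after:
    llama_memory_t mem = llama_get_memory(ctx);

    llama_memory_clear(mem, /*data=*/true);
    llama_memory_seq_rm(mem, seq_id, 0, -1);
    llama_pos pmax = llama_memory_seq_pos_max(mem, seq_id);

    (void) pmax;
}
```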

  // llama state API
@@ -2573,22 +2772,8 @@ int32_t llama_encode(
  int32_t llama_decode(
  llama_context * ctx,
  llama_batch batch) {
- int ret = ctx->decode(batch);
-
- // defrag and try again
- // TODO: distinguish return code when we are sure that even after defrag there is no space available
- if (ret == 1) {
- llama_kv_self_defrag(ctx);
- ret = ctx->decode(batch);
-
- if (ret == 1) {
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
-
- return ret;
- }
- }
-
- if (ret != 0) {
+ const int ret = ctx->decode(batch);
+ if (ret != 0 && ret != 1) {
  LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
  }
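
llama_decode() no longer defrags and retries when no cache slot is found; it simply forwards the context's return code and only logs values other than 0 and 1. A hedged sketch of caller-side handling (the recovery policy is application-specific, not prescribed by the library):

```cpp
#include "llama.h"

#include <cstdio>

// Caller-side handling under the new llama_decode() contract: a return value
// of 1 ("no memory slot found for the batch") is no longer retried internally,
// so the application decides how to react.
static int decode_checked(llama_context * ctx, llama_batch batch) {
    const int ret = llama_decode(ctx, batch);

    if (ret == 1) {
        // possible reactions (application-specific):
        //  - free space, e.g. llama_memory_seq_rm() on a finished sequence
        //  - split the batch into smaller chunks and resubmit
        std::fprintf(stderr, "decode: no memory slot for %d tokens\n", batch.n_tokens);
    } else if (ret != 0) {
        // other non-zero codes are already reported as errors by llama_decode itself
        std::fprintf(stderr, "decode: failed with ret = %d\n", ret);
    }

    return ret;
}
```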