cui-llama.rn 1.4.6 → 1.6.0

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (366)
  1. package/LICENSE +20 -20
  2. package/README.md +317 -319
  3. package/android/build.gradle +116 -116
  4. package/android/gradle.properties +5 -5
  5. package/android/src/main/AndroidManifest.xml +4 -4
  6. package/android/src/main/CMakeLists.txt +124 -117
  7. package/android/src/main/java/com/rnllama/LlamaContext.java +645 -645
  8. package/android/src/main/java/com/rnllama/RNLlama.java +695 -695
  9. package/android/src/main/java/com/rnllama/RNLlamaPackage.java +48 -48
  10. package/android/src/main/jni-utils.h +100 -100
  11. package/android/src/main/jni.cpp +1263 -1245
  12. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  13. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  14. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  15. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  16. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  17. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  20. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +135 -135
  21. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +136 -136
  22. package/cpp/README.md +4 -4
  23. package/cpp/binary-ops.cpp +158 -0
  24. package/cpp/binary-ops.h +16 -0
  25. package/cpp/chat.cpp +1769 -1779
  26. package/cpp/chat.h +9 -1
  27. package/cpp/common.cpp +20 -522
  28. package/cpp/common.h +13 -36
  29. package/cpp/cpu-common.h +72 -0
  30. package/cpp/ggml-common.h +12 -6
  31. package/cpp/ggml-cpu-aarch64.cpp +1557 -80
  32. package/cpp/ggml-cpu-impl.h +2 -21
  33. package/cpp/ggml-cpu-quants.c +904 -405
  34. package/cpp/ggml-cpu.c +909 -13237
  35. package/cpp/ggml-impl.h +50 -23
  36. package/cpp/ggml-llama-sim.metallib +0 -0
  37. package/cpp/ggml-llama.metallib +0 -0
  38. package/cpp/ggml-metal-impl.h +597 -523
  39. package/cpp/ggml-metal.m +798 -580
  40. package/cpp/ggml.c +92 -3
  41. package/cpp/ggml.h +30 -6
  42. package/cpp/gguf.cpp +1 -0
  43. package/cpp/llama-adapter.cpp +55 -20
  44. package/cpp/llama-adapter.h +11 -9
  45. package/cpp/llama-arch.cpp +217 -16
  46. package/cpp/llama-arch.h +25 -0
  47. package/cpp/llama-batch.h +2 -2
  48. package/cpp/llama-chat.cpp +54 -2
  49. package/cpp/llama-chat.h +3 -0
  50. package/cpp/llama-context.cpp +2294 -1238
  51. package/cpp/llama-context.h +214 -77
  52. package/cpp/llama-cparams.h +1 -0
  53. package/cpp/llama-graph.cpp +1695 -0
  54. package/cpp/llama-graph.h +592 -0
  55. package/cpp/llama-hparams.cpp +8 -0
  56. package/cpp/llama-hparams.h +17 -0
  57. package/cpp/llama-io.cpp +15 -0
  58. package/cpp/llama-io.h +35 -0
  59. package/cpp/llama-kv-cache.cpp +965 -303
  60. package/cpp/llama-kv-cache.h +145 -151
  61. package/cpp/llama-memory.cpp +1 -0
  62. package/cpp/llama-memory.h +21 -0
  63. package/cpp/llama-mmap.cpp +1 -1
  64. package/cpp/llama-model-loader.cpp +10 -5
  65. package/cpp/llama-model-loader.h +5 -3
  66. package/cpp/llama-model.cpp +9194 -201
  67. package/cpp/llama-model.h +40 -1
  68. package/cpp/llama-sampling.cpp +5 -0
  69. package/cpp/llama-vocab.cpp +36 -5
  70. package/cpp/llama.cpp +51 -9984
  71. package/cpp/llama.h +102 -22
  72. package/cpp/log.cpp +34 -0
  73. package/cpp/minja/chat-template.hpp +15 -7
  74. package/cpp/minja/minja.hpp +120 -94
  75. package/cpp/ops.cpp +8723 -0
  76. package/cpp/ops.h +128 -0
  77. package/cpp/rn-llama.cpp +873 -882
  78. package/cpp/rn-llama.h +138 -148
  79. package/cpp/sampling.cpp +3 -0
  80. package/cpp/sampling.h +107 -107
  81. package/cpp/sgemm.cpp +533 -88
  82. package/cpp/simd-mappings.h +888 -0
  83. package/cpp/speculative.cpp +4 -4
  84. package/cpp/unary-ops.cpp +186 -0
  85. package/cpp/unary-ops.h +28 -0
  86. package/cpp/unicode-data.cpp +7034 -7034
  87. package/cpp/unicode-data.h +20 -20
  88. package/cpp/unicode.cpp +849 -849
  89. package/cpp/unicode.h +66 -66
  90. package/cpp/vec.cpp +258 -0
  91. package/cpp/vec.h +802 -0
  92. package/ios/CMakeLists.txt +116 -105
  93. package/ios/RNLlama.h +7 -7
  94. package/ios/RNLlama.mm +418 -405
  95. package/ios/RNLlamaContext.h +57 -57
  96. package/ios/RNLlamaContext.mm +835 -819
  97. package/ios/rnllama.xcframework/Info.plist +74 -74
  98. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +16 -0
  99. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +143 -0
  100. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +677 -0
  101. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/cpu-common.h +72 -0
  102. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-alloc.h +76 -0
  103. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  104. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +354 -0
  105. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-common.h +1857 -0
  106. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +39 -0
  107. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
  108. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
  109. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
  110. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
  111. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +138 -0
  112. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +594 -0
  113. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  114. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal.h +66 -0
  115. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +216 -0
  116. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-quants.h +100 -0
  117. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-threading.h +14 -0
  118. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +2222 -0
  119. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/gguf.h +202 -0
  120. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  121. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json.hpp +24766 -0
  122. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-adapter.h +76 -0
  123. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +428 -0
  124. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +88 -0
  125. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +56 -0
  126. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +265 -0
  127. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +38 -0
  128. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cpp.h +30 -0
  129. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-grammar.h +173 -0
  130. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +592 -0
  131. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +156 -0
  132. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-impl.h +61 -0
  133. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-io.h +35 -0
  134. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +213 -0
  135. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +21 -0
  136. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-mmap.h +68 -0
  137. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-loader.h +169 -0
  138. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +409 -0
  139. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-sampling.h +32 -0
  140. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +125 -0
  141. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +1434 -0
  142. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/log.h +132 -0
  143. package/{cpp → ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja}/chat-template.hpp +15 -7
  144. package/{cpp → ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja}/minja.hpp +120 -94
  145. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +128 -0
  146. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +138 -0
  147. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sampling.h +107 -0
  148. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +14 -0
  149. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +888 -0
  150. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/speculative.h +28 -0
  151. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unary-ops.h +28 -0
  152. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unicode-data.h +20 -0
  153. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unicode.h +66 -0
  154. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +802 -0
  155. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
  156. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  157. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +16 -0
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +143 -0
  160. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +677 -0
  161. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/cpu-common.h +72 -0
  162. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-alloc.h +76 -0
  163. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  164. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +354 -0
  165. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +1857 -0
  166. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +39 -0
  167. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
  168. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
  169. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
  170. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
  171. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +138 -0
  172. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +594 -0
  173. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  174. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal.h +66 -0
  175. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +216 -0
  176. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-quants.h +100 -0
  177. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-threading.h +14 -0
  178. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +2222 -0
  179. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/gguf.h +202 -0
  180. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  181. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +24766 -0
  182. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-adapter.h +76 -0
  183. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +428 -0
  184. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +88 -0
  185. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +56 -0
  186. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +265 -0
  187. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +38 -0
  188. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cpp.h +30 -0
  189. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-grammar.h +173 -0
  190. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +592 -0
  191. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +156 -0
  192. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-impl.h +61 -0
  193. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-io.h +35 -0
  194. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +213 -0
  195. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +21 -0
  196. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-mmap.h +68 -0
  197. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-loader.h +169 -0
  198. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +409 -0
  199. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-sampling.h +32 -0
  200. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +125 -0
  201. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +1434 -0
  202. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/log.h +132 -0
  203. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
  204. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +2941 -0
  205. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +128 -0
  206. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +138 -0
  207. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sampling.h +107 -0
  208. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +14 -0
  209. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +888 -0
  210. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/speculative.h +28 -0
  211. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +28 -0
  212. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unicode-data.h +20 -0
  213. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unicode.h +66 -0
  214. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +802 -0
  215. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  216. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +101 -0
  217. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  218. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  219. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +16 -0
  220. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +143 -0
  221. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +677 -0
  222. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/cpu-common.h +72 -0
  223. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-alloc.h +76 -0
  224. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  225. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +354 -0
  226. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-common.h +1857 -0
  227. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +39 -0
  228. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
  229. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
  230. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
  231. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
  232. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +138 -0
  233. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +594 -0
  234. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  235. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal.h +66 -0
  236. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +216 -0
  237. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-quants.h +100 -0
  238. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-threading.h +14 -0
  239. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +2222 -0
  240. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/gguf.h +202 -0
  241. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  242. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json.hpp +24766 -0
  243. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-adapter.h +76 -0
  244. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +428 -0
  245. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +88 -0
  246. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +56 -0
  247. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +265 -0
  248. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +38 -0
  249. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cpp.h +30 -0
  250. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-grammar.h +173 -0
  251. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +592 -0
  252. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +156 -0
  253. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-impl.h +61 -0
  254. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-io.h +35 -0
  255. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +213 -0
  256. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +21 -0
  257. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-mmap.h +68 -0
  258. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-loader.h +169 -0
  259. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +409 -0
  260. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-sampling.h +32 -0
  261. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +125 -0
  262. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +1434 -0
  263. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/log.h +132 -0
  264. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
  265. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +2941 -0
  266. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +128 -0
  267. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +138 -0
  268. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sampling.h +107 -0
  269. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +14 -0
  270. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +888 -0
  271. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/speculative.h +28 -0
  272. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +28 -0
  273. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unicode-data.h +20 -0
  274. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unicode.h +66 -0
  275. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/vec.h +802 -0
  276. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
  277. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  278. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  279. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +16 -0
  280. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +143 -0
  281. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +677 -0
  282. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/cpu-common.h +72 -0
  283. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-alloc.h +76 -0
  284. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
  285. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +354 -0
  286. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +1857 -0
  287. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +39 -0
  288. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +8 -0
  289. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +512 -0
  290. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +63 -0
  291. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +38 -0
  292. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +138 -0
  293. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +594 -0
  294. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
  295. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal.h +66 -0
  296. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +216 -0
  297. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-quants.h +100 -0
  298. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-threading.h +14 -0
  299. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +2222 -0
  300. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/gguf.h +202 -0
  301. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
  302. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +24766 -0
  303. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-adapter.h +76 -0
  304. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +428 -0
  305. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +88 -0
  306. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +56 -0
  307. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +265 -0
  308. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +38 -0
  309. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cpp.h +30 -0
  310. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-grammar.h +173 -0
  311. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +592 -0
  312. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +156 -0
  313. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-impl.h +61 -0
  314. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-io.h +35 -0
  315. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +213 -0
  316. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +21 -0
  317. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-mmap.h +68 -0
  318. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-loader.h +169 -0
  319. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +409 -0
  320. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-sampling.h +32 -0
  321. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +125 -0
  322. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +1434 -0
  323. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/log.h +132 -0
  324. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
  325. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +2941 -0
  326. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +128 -0
  327. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +138 -0
  328. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sampling.h +107 -0
  329. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +14 -0
  330. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +888 -0
  331. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/speculative.h +28 -0
  332. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +28 -0
  333. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unicode-data.h +20 -0
  334. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unicode.h +66 -0
  335. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +802 -0
  336. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  337. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +101 -0
  338. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  339. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  340. package/jest/mock.js +203 -203
  341. package/lib/commonjs/NativeRNLlama.js +1 -2
  342. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  343. package/lib/commonjs/chat.js.map +1 -1
  344. package/lib/commonjs/grammar.js +12 -31
  345. package/lib/commonjs/grammar.js.map +1 -1
  346. package/lib/commonjs/index.js +47 -47
  347. package/lib/commonjs/index.js.map +1 -1
  348. package/lib/commonjs/package.json +1 -0
  349. package/lib/module/NativeRNLlama.js +2 -0
  350. package/lib/module/NativeRNLlama.js.map +1 -1
  351. package/lib/module/chat.js +2 -0
  352. package/lib/module/chat.js.map +1 -1
  353. package/lib/module/grammar.js +14 -31
  354. package/lib/module/grammar.js.map +1 -1
  355. package/lib/module/index.js +47 -45
  356. package/lib/module/index.js.map +1 -1
  357. package/lib/module/package.json +1 -0
  358. package/lib/typescript/NativeRNLlama.d.ts +6 -4
  359. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  360. package/lib/typescript/index.d.ts.map +1 -1
  361. package/llama-rn.podspec +48 -48
  362. package/package.json +233 -233
  363. package/src/NativeRNLlama.ts +426 -424
  364. package/src/chat.ts +44 -44
  365. package/src/grammar.ts +854 -854
  366. package/src/index.ts +495 -485
package/cpp/llama-context.cpp
@@ -1,1404 +1,1729 @@
  #include "llama-context.h"
 
  #include "llama-impl.h"
+ #include "llama-io.h"
  #include "llama-mmap.h"
+ #include "llama-model.h"
+ #include "llama-kv-cache.h"
 
  #include <cassert>
- #include <cmath>
  #include <cstring>
  #include <stdexcept>
+ #include <cinttypes>
 
- void llama_set_k_shift(struct llama_context & lctx) {
- const int64_t kv_size = lctx.kv_self.size;
-
- assert(lm_ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
-
- int32_t * data = (int32_t *) lctx.inp_K_shift->data;
+ //
+ // llama_context
+ //
 
- for (int i = 0; i < kv_size; ++i) {
- data[i] = lctx.kv_self.cells[i].delta;
- }
- }
+ llama_context::llama_context(
+ const llama_model & model,
+ llama_context_params params) :
+ model(model) {
+ LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
 
- void llama_set_s_copy(struct llama_context & lctx) {
- const int64_t kv_size = lctx.kv_self.size;
+ t_start_us = model.t_start_us;
+ t_load_us = model.t_load_us;
 
- assert(lm_ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
+ const auto & hparams = model.hparams;
 
- int32_t * data = (int32_t *) lctx.inp_s_copy->data;
+ cparams.n_seq_max = std::max(1u, params.n_seq_max);
+ cparams.n_threads = params.n_threads;
+ cparams.n_threads_batch = params.n_threads_batch;
+ cparams.yarn_ext_factor = params.yarn_ext_factor;
+ cparams.yarn_attn_factor = params.yarn_attn_factor;
+ cparams.yarn_beta_fast = params.yarn_beta_fast;
+ cparams.yarn_beta_slow = params.yarn_beta_slow;
+ cparams.defrag_thold = params.defrag_thold;
+ cparams.embeddings = params.embeddings;
+ cparams.offload_kqv = params.offload_kqv;
+ cparams.flash_attn = params.flash_attn;
+ cparams.no_perf = params.no_perf;
+ cparams.pooling_type = params.pooling_type;
+ cparams.warmup = false;
 
- for (int i = 0; i < kv_size; ++i) {
- data[i] = lctx.kv_self.cells[i].src;
- }
- }
+ cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
+ cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
+ cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
- // llama input
+ cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx :
+ hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
+ hparams.n_ctx_train;
 
- static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
- // TODO move to hparams if a T5 variant appears that uses a different value
- const int64_t max_distance = 128;
+ cparams.cb_eval = params.cb_eval;
+ cparams.cb_eval_user_data = params.cb_eval_user_data;
 
- if (bidirectional) {
- n_buckets >>= 1;
+ auto rope_scaling_type = params.rope_scaling_type;
+ if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
+ rope_scaling_type = hparams.rope_scaling_type_train;
  }
 
- const int64_t max_exact = n_buckets >> 1;
-
- int32_t relative_position = x - y;
- int32_t relative_bucket = 0;
- if (bidirectional) {
- relative_bucket += (relative_position > 0) * n_buckets;
- relative_position = abs(relative_position);
- } else {
- relative_position = -std::min<int32_t>(relative_position, 0);
+ if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) {
+ cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
  }
- int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
- relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
- relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
- return relative_bucket;
- }
-
- void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
- //
- // set input data
- //
-
- const auto & hparams = lctx.model.hparams;
- const auto & cparams = lctx.cparams;
- const auto & kv_self = lctx.kv_self;
-
- if (ubatch.token) {
- const int64_t n_tokens = ubatch.n_tokens;
 
- lm_ggml_backend_tensor_set(lctx.inp_tokens, ubatch.token, 0, n_tokens*lm_ggml_element_size(lctx.inp_tokens));
+ if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set'
+ cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
  }
 
- if (ubatch.embd) {
- const int64_t n_embd = hparams.n_embd;
- const int64_t n_tokens = ubatch.n_tokens;
+ cparams.yarn_attn_factor *= hparams.rope_attn_factor;
 
- lm_ggml_backend_tensor_set(lctx.inp_embd, ubatch.embd, 0, n_tokens*n_embd*lm_ggml_element_size(lctx.inp_embd));
+ if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
+ if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
+ cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
+ } else {
+ cparams.pooling_type = hparams.pooling_type;
+ }
  }
 
- if (ubatch.pos && lctx.inp_pos) {
- const int64_t n_tokens = ubatch.n_tokens;
- auto n_pos = lctx.n_pos_per_token;
- lm_ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*n_pos*lm_ggml_element_size(lctx.inp_pos));
+ if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
+ cparams.causal_attn = hparams.causal_attn;
+ } else {
+ cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
  }
 
- if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
- //LM_GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
+ // with causal attention, the batch size is limited by the context size
+ cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
 
- if (!lctx.inp_out_ids) {
- LLAMA_LOG_WARN("%s: 'lctx.inp_out_ids' is not created\n", __func__);
- } else {
- const int64_t n_tokens = ubatch.n_tokens;
+ // the batch has to be at least LM_GGML_KQ_MASK_PAD because we will be padding the KQ_mask
+ // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. lm_ggml_flash_attn_ext)
+ // ref: https://github.com/ggerganov/llama.cpp/pull/5021
+ // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
+ if (cparams.n_batch < LM_GGML_KQ_MASK_PAD) {
+ LLAMA_LOG_WARN("%s: n_batch is less than LM_GGML_KQ_MASK_PAD - increasing to %d\n", __func__, LM_GGML_KQ_MASK_PAD);
+ cparams.n_batch = LM_GGML_KQ_MASK_PAD;
+ }
 
- LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
- int32_t * data = (int32_t *) lctx.inp_out_ids->data;
+ cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
- if (lctx.n_outputs == n_tokens) {
- for (int i = 0; i < n_tokens; ++i) {
- data[i] = i;
- }
- } else if (ubatch.output) {
- int32_t n_outputs = 0;
- for (int i = 0; i < n_tokens; ++i) {
- if (ubatch.output[i]) {
- data[n_outputs++] = i;
- }
- }
- // the graph needs to have been passed the correct number of outputs
- LM_GGML_ASSERT(lctx.n_outputs == n_outputs);
- } else if (lctx.n_outputs == 1) {
- // only keep last output
- data[0] = n_tokens - 1;
- } else {
- LM_GGML_ASSERT(lctx.n_outputs == 0);
- }
- }
- }
+ const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
- LM_GGML_ASSERT(
- // (!a || b) is a logical implication (a -> b)
- // !hparams.causal_attn -> !cparams.causal_attn
- (hparams.causal_attn || !cparams.causal_attn) &&
- "causal attention is not supported by this model"
- );
+ LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
+ LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
+ LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
+ LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
+ LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
+ LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
+ LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
+ LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
+ LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
 
- if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
- // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
- if (cparams.causal_attn && !lctx.is_encoding) {
- const int64_t n_kv = kv_self.n;
- const int64_t n_tokens = ubatch.n_tokens;
- const int64_t n_seq_tokens = ubatch.n_seq_tokens;
- const int64_t n_seqs = ubatch.n_seqs;
+ if (n_ctx_per_seq < hparams.n_ctx_train) {
+ LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
+ __func__, n_ctx_per_seq, hparams.n_ctx_train);
+ }
 
+ if (n_ctx_per_seq > hparams.n_ctx_train) {
+ LLAMA_LOG_WARN("%s: n_ctx_pre_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
+ __func__, n_ctx_per_seq, hparams.n_ctx_train);
+ }
 
- float * data = nullptr;
- float * data_swa = nullptr;
+ logits_all = params.logits_all;
 
- if (lctx.inp_KQ_mask) {
- LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
- data = (float *) lctx.inp_KQ_mask->data;
+ if (!hparams.vocab_only) {
+ // GPU backends
+ for (auto * dev : model.devices) {
+ lm_ggml_backend_t backend = lm_ggml_backend_dev_init(dev, nullptr);
+ if (backend == nullptr) {
+ throw std::runtime_error(format("failed to initialize %s backend", lm_ggml_backend_dev_name(dev)));
  }
+ backends.emplace_back(backend);
+ }
 
- if (lctx.inp_KQ_mask_swa) {
- LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
- data_swa = (float *) lctx.inp_KQ_mask_swa->data;
+ // add ACCEL backends (such as BLAS)
+ for (size_t i = 0; i < lm_ggml_backend_dev_count(); ++i) {
+ lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_get(i);
+ if (lm_ggml_backend_dev_type(dev) == LM_GGML_BACKEND_DEVICE_TYPE_ACCEL) {
+ lm_ggml_backend_t backend = lm_ggml_backend_dev_init(dev, nullptr);
+ if (backend == nullptr) {
+ throw std::runtime_error(format("failed to initialize %s backend", lm_ggml_backend_dev_name(dev)));
+ }
+ backends.emplace_back(backend);
  }
+ }
 
- // For causal attention, use only the previous KV cells
- // of the correct sequence for each token of the ubatch.
- // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
- for (int h = 0; h < 1; ++h) {
- for (int s = 0; s < n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
-
- for (int j = 0; j < n_seq_tokens; ++j) {
- const llama_pos pos = ubatch.pos[s*n_seq_tokens + j];
-
- for (int i = 0; i < n_kv; ++i) {
- float f;
- if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
- f = -INFINITY;
- } else {
- if (hparams.use_alibi) {
- f = -std::abs(kv_self.cells[i].pos - pos);
- } else {
- f = 0.0f;
- }
- }
-
- if (data) {
- data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
- }
-
- // may need to cut off old tokens for sliding window
- if (data_swa) {
- if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
- f = -INFINITY;
- }
- data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
- }
- }
- }
+ // add CPU backend
+ backend_cpu = lm_ggml_backend_init_by_type(LM_GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
+ if (backend_cpu == nullptr) {
+ throw std::runtime_error("failed to initialize CPU backend");
+ }
+ backends.emplace_back(backend_cpu);
+
+ // create a list of the set_n_threads functions in the backends
+ for (auto & backend : backends) {
+ lm_ggml_backend_dev_t dev = lm_ggml_backend_get_device(backend.get());
+ lm_ggml_backend_reg_t reg = dev ? lm_ggml_backend_dev_backend_reg(dev) : nullptr;
+ if (reg) {
+ auto lm_ggml_backend_set_n_threads_fn = (lm_ggml_backend_set_n_threads_t) lm_ggml_backend_reg_get_proc_address(reg, "lm_ggml_backend_set_n_threads");
+ if (lm_ggml_backend_set_n_threads_fn) {
+ set_n_threads_fns.emplace_back(backend.get(), lm_ggml_backend_set_n_threads_fn);
  }
+ }
+ }
 
- if (data) {
- for (int i = n_tokens; i < LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD); ++i) {
- for (int j = 0; j < n_kv; ++j) {
- data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
- }
- }
- }
+ llama_set_abort_callback(this, params.abort_callback, params.abort_callback_data);
 
- if (data_swa) {
- for (int i = n_tokens; i < LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD); ++i) {
- for (int j = 0; j < n_kv; ++j) {
- data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
- }
- }
- }
+ // graph outputs buffer
+ {
+ // resized during inference when a batch uses more outputs
+ if ((uint32_t) output_reserve(params.n_seq_max) < params.n_seq_max) {
+ throw std::runtime_error("failed to reserve initial output buffer");
  }
- } else {
- const int64_t n_tokens = ubatch.n_tokens;
- const int64_t n_seq_tokens = ubatch.n_seq_tokens;
- const int64_t n_seqs = ubatch.n_seqs;
- // when using kv cache, the mask needs to match the kv cache size
- const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens;
-
- LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
-
- float * data = (float *) lctx.inp_KQ_mask->data;
-
- for (int h = 0; h < 1; ++h) {
- for (int s1 = 0; s1 < n_seqs; ++s1) {
- const llama_seq_id seq_id = ubatch.seq_id[s1][0];
-
- for (int j = 0; j < n_seq_tokens; ++j) {
- const int32_t tj = s1*n_seq_tokens + j;
-
- for (int s0 = 0; s0 < n_seqs; ++s0) {
- for (int i = 0; i < n_seq_tokens; ++i) {
- const int32_t ti = s0*n_seq_tokens + i;
- float f = -INFINITY;
-
- for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) {
- if (ubatch.seq_id[s0][s] == seq_id) {
- if (hparams.use_alibi) {
- f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]);
- } else {
- f = 0.0f;
- }
- break;
- }
- }
-
- data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
- }
- }
 
- for (int i = n_tokens; i < n_stride; ++i) {
- data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
- }
- }
- }
- }
+ LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__,
+ lm_ggml_backend_buffer_name (buf_output.get()),
+ lm_ggml_backend_buffer_get_size(buf_output.get()) / 1024.0 / 1024.0);
  }
  }
 
- if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
- const int64_t n_tokens = ubatch.n_tokens;
- const int64_t n_seq_tokens = ubatch.n_seq_tokens;
- const int64_t n_seqs = ubatch.n_seqs;
+ // init the memory module
+ // TODO: for now, always create a unified KV cache
+ if (!hparams.vocab_only) {
+ kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
 
- LM_GGML_ASSERT(lctx.inp_mean);
- LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
+ LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
 
- float * data = (float *) lctx.inp_mean->data;
- memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * lm_ggml_element_size(lctx.inp_mean));
+ cparams.n_ctx = LM_GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams));
 
- std::vector<uint64_t> sum(n_tokens, 0);
+ LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
- for (int s = 0; s < n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
+ uint32_t kv_size = cparams.n_ctx;
+ lm_ggml_type type_k = params.type_k;
+ lm_ggml_type type_v = params.type_v;
 
- // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
- LM_GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
-
- sum[seq_id] += ubatch.n_seq_tokens;
+ if (llama_model_is_recurrent(&model)) {
+ // Mamba needs at least as many KV cells as there are sequences kept at any time
+ kv_size = std::max((uint32_t) 1, params.n_seq_max);
+ // it's probably best to keep as much precision as possible for the states
+ type_k = LM_GGML_TYPE_F32; // required by lm_ggml_ssm_conv for Mamba's conv_states
+ type_v = LM_GGML_TYPE_F32; // required by lm_ggml_ssm_scan for Mamba's ssm_states
  }
 
- std::vector<float> div(n_tokens, 0.0f);
- for (int i = 0; i < n_tokens; ++i) {
- const uint64_t s = sum[i];
- if (s > 0) {
- div[i] = 1.0f/float(s);
- }
+ LM_GGML_ASSERT(hparams.n_embd_head_k % lm_ggml_blck_size(type_k) == 0);
+ LM_GGML_ASSERT(hparams.n_embd_head_v % lm_ggml_blck_size(type_v) == 0);
+
+ if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
+ throw std::runtime_error("failed to initialize self-attention cache");
  }
 
- for (int s = 0; s < n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
+ {
+ const size_t memory_size_k = kv_self->size_k_bytes();
+ const size_t memory_size_v = kv_self->size_v_bytes();
 
- for (int i = 0; i < n_seq_tokens; ++i) {
- data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
- }
+ LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+ (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+ lm_ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+ lm_ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
  }
  }
 
- if (cparams.embeddings && (
- cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
- cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
- const int64_t n_tokens = ubatch.n_tokens;
- const int64_t n_seq_tokens = ubatch.n_seq_tokens;
- const int64_t n_seqs = ubatch.n_seqs;
+ // init backends
+ if (!hparams.vocab_only) {
+ LLAMA_LOG_DEBUG("%s: enumerating backends\n", __func__);
 
- LM_GGML_ASSERT(lctx.inp_cls);
- LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
+ backend_buft.clear();
+ backend_ptrs.clear();
 
- uint32_t * data = (uint32_t *) lctx.inp_cls->data;
- memset(lctx.inp_cls->data, 0, n_tokens * lm_ggml_element_size(lctx.inp_cls));
+ for (auto & backend : backends) {
+ auto * buft = lm_ggml_backend_get_default_buffer_type(backend.get());
+ auto backend_type = lm_ggml_backend_dev_type(lm_ggml_backend_get_device(backend.get()));
 
- for (int s = 0; s < n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
-
- // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
- LM_GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
-
- for (int i = 0; i < n_seq_tokens; ++i) {
- const llama_pos pos = ubatch.pos[s*n_seq_tokens + i];
-
- if (pos == 0) {
- data[seq_id] = s*n_seq_tokens + i;
+ if (backend_type == LM_GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) {
+ // use the host buffer of the first device CPU for faster transfer of the intermediate state
+ auto * dev = model.devices[0];
+ auto * host_buft = lm_ggml_backend_dev_host_buffer_type(dev);
+ if (host_buft) {
+ buft = host_buft;
  }
  }
- }
- }
-
- if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
- const int64_t n_tokens = ubatch.n_tokens;
- const int64_t n_seq_tokens = ubatch.n_seq_tokens;
- const int64_t n_seqs = ubatch.n_seqs;
 
- LM_GGML_ASSERT(lctx.inp_cls);
- LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
+ backend_buft.push_back(buft);
+ backend_ptrs.push_back(backend.get());
+ }
 
- uint32_t * data = (uint32_t *) lctx.inp_cls->data;
- memset(lctx.inp_cls->data, 0, n_tokens * lm_ggml_element_size(lctx.inp_cls));
+ LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
 
- std::vector<int> last_pos(n_tokens, -1);
- std::vector<int> last_row(n_tokens, -1);
+ const size_t max_nodes = this->graph_max_nodes();
 
- for (int s = 0; s < n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
+ LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
 
- // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
- LM_GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
+ // buffer used to store the computation graph and the tensor meta data
+ buf_compute_meta.resize(lm_ggml_tensor_overhead()*max_nodes + lm_ggml_graph_overhead_custom(max_nodes, false));
 
- for (int i = 0; i < n_seq_tokens; ++i) {
- const llama_pos pos = ubatch.pos[s*n_seq_tokens + i];
+ // TODO: move these checks to lm_ggml_backend_sched
+ // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
+ bool pipeline_parallel =
+ model.n_devices() > 1 &&
+ model.params.n_gpu_layers > (int) model.hparams.n_layer &&
+ model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
+ cparams.offload_kqv &&
+ !model.has_tensor_overrides();
 
- if (pos >= last_pos[seq_id]) {
- last_pos[seq_id] = pos;
- last_row[seq_id] = s*n_seq_tokens + i;
+ // pipeline parallelism requires support for async compute and events in all devices
+ if (pipeline_parallel) {
+ for (auto & backend : backends) {
+ auto dev_type = lm_ggml_backend_dev_type(lm_ggml_backend_get_device(backend.get()));
+ if (dev_type == LM_GGML_BACKEND_DEVICE_TYPE_CPU) {
+ // ignore CPU backend
+ continue;
+ }
+ auto * dev = lm_ggml_backend_get_device(backend.get());
+ lm_ggml_backend_dev_props props;
+ lm_ggml_backend_dev_get_props(dev, &props);
+ if (!props.caps.async || !props.caps.events) {
+ // device does not support async compute or events
+ pipeline_parallel = false;
+ break;
  }
  }
  }
 
- for (int i = 0; i < n_tokens; ++i) {
- if (last_row[i] >= 0) {
- data[i] = last_row[i];
- }
+ sched.reset(lm_ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
+
+ if (pipeline_parallel) {
+ LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, lm_ggml_backend_sched_get_n_copies(sched.get()));
  }
  }
 
- if (kv_self.recurrent) {
- const int64_t n_kv = kv_self.n;
+ // reserve worst-case graph
+ if (!hparams.vocab_only) {
+ const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+ const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
- if (lctx.inp_s_mask) {
- LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
- float * data = (float *) lctx.inp_s_mask->data;
+ llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
 
- // clear unused states
- for (int i = 0; i < n_kv; ++i) {
- const uint32_t cell_id = i + kv_self.head;
- llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
+ // restore later
+ // TODO: something cleaner
+ const auto n_outputs_save = n_outputs;
 
- data[i] = (float) (kv_cell.src >= 0);
-
- // only clear once
- if (kv_cell.src < 0) {
- kv_cell.src = cell_id;
- }
- }
- }
+ LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
 
- if (lctx.inp_s_copy) {
- LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
- int32_t * data = (int32_t *) lctx.inp_s_copy->data;
+ int n_splits_pp = -1;
+ int n_nodes_pp = -1;
 
- // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
- for (uint32_t i = 0; i < n_kv; ++i) {
- const uint32_t cell_id = i + kv_self.head;
- llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
+ int n_splits_tg = -1;
+ int n_nodes_tg = -1;
 
- // prevent out-of-bound sources
- if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
- kv_cell.src = cell_id;
- }
+ // simulate full KV cache
+ kv_self->n = kv_self->size;
 
- data[i] = kv_cell.src;
+ cross.v_embd.clear();
 
- // ensure copy only happens once
- if (kv_cell.src != (int32_t) cell_id) {
- kv_cell.src = cell_id;
- }
- }
- }
- }
+ // reserve pp graph first so that buffers are only allocated once
+ {
+ llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
 
- if (lctx.inp_pos_bucket) {
- const int64_t n_tokens = ubatch.n_tokens;
+ // max number of outputs
+ n_outputs = ubatch_pp.n_tokens;
 
- LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer));
- LM_GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
+ LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
 
- int32_t * data = (int32_t *) lctx.inp_pos_bucket->data;
+ auto * gf = graph_init();
+ graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
 
- if (!lctx.is_encoding) {
- const int64_t n_kv = kv_self.n;
- for (int h = 0; h < 1; ++h) {
- for (int j = 0; j < n_tokens; ++j) {
- for (int i = 0; i < n_kv; ++i) {
- data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
- }
- }
- }
- } else {
- for (int h = 0; h < 1; ++h) {
- for (int j = 0; j < n_tokens; ++j) {
- for (int i = 0; i < n_tokens; ++i) {
- data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
- }
- }
+ if (!lm_ggml_backend_sched_reserve(sched.get(), gf)) {
+ throw std::runtime_error("failed to allocate compute pp buffers");
  }
- }
- }
 
- if (!lctx.is_encoding && lctx.inp_embd_enc) {
- assert(lctx.inp_embd_enc->type == LM_GGML_TYPE_F32);
- assert((size_t) lm_ggml_nelements(lctx.inp_embd_enc) == lctx.embd_enc.size());
+ n_splits_pp = lm_ggml_backend_sched_get_n_splits(sched.get());
+ n_nodes_pp = lm_ggml_graph_n_nodes(gf);
+ }
 
- lm_ggml_backend_tensor_set(lctx.inp_embd_enc, lctx.embd_enc.data(), 0, lm_ggml_nbytes(lctx.inp_embd_enc));
- }
+ // reserve with tg graph to get the number of splits and nodes
+ {
+ llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
 
- if (!lctx.is_encoding && lctx.inp_KQ_mask_cross) {
- const int64_t n_output_enc = lctx.embd_enc.size() / hparams.n_embd;
- const int64_t n_tokens = ubatch.n_tokens;
+ n_outputs = ubatch_tg.n_tokens;
 
- LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer));
- LM_GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
+ LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
 
- float * data = (float *) lctx.inp_KQ_mask_cross->data;
+ auto * gf = graph_init();
+ graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
 
- for (int h = 0; h < 1; ++h) {
446
- for (int j = 0; j < n_tokens; ++j) {
447
- for (int i = 0; i < n_output_enc; ++i) {
448
- float f = -INFINITY;
449
- for (int s = 0; s < ubatch.n_seq_id[j]; ++s) {
450
- const llama_seq_id seq_id = ubatch.seq_id[j][s];
451
- if (lctx.seq_ids_enc[i].find(seq_id) != lctx.seq_ids_enc[i].end()) {
452
- f = 0.0f;
453
- }
454
- }
455
- data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f;
456
- }
342
+ if (!lm_ggml_backend_sched_reserve(sched.get(), gf)) {
343
+ throw std::runtime_error("failed to allocate compute tg buffers");
457
344
  }
458
345
 
459
- for (int i = n_tokens; i < LM_GGML_PAD(n_tokens, LM_GGML_KQ_MASK_PAD); ++i) {
460
- for (int j = 0; j < n_output_enc; ++j) {
461
- data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY;
462
- }
463
- }
346
+ n_splits_tg = lm_ggml_backend_sched_get_n_splits(sched.get());
347
+ n_nodes_tg = lm_ggml_graph_n_nodes(gf);
464
348
  }
465
- }
466
- }
467
349
 
468
- // llama output
350
+ // reserve again with pp graph to avoid ggml-alloc reallocations during inference
351
+ {
352
+ llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
469
353
 
470
- size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
471
- const auto & cparams = lctx.cparams;
472
- const auto & hparams = lctx.model.hparams;
473
- const auto & vocab = lctx.model.vocab;
354
+ n_outputs = ubatch_pp.n_tokens;
474
355
 
475
- const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
356
+ LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
476
357
 
477
- const auto n_batch = cparams.n_batch;
478
- const auto n_vocab = vocab.n_tokens();
479
- const auto n_embd = hparams.n_embd;
480
-
481
- // TODO: use a per-batch flag for logits presence instead
482
- const bool has_logits = !cparams.embeddings;
483
- const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
484
-
485
- const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
486
- const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
358
+ auto * gf = graph_init();
359
+ graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
487
360
 
488
- if (lctx.output_ids.empty()) {
489
- // init, never resized afterwards
490
- lctx.output_ids.resize(n_batch);
491
- }
361
+ if (!lm_ggml_backend_sched_reserve(sched.get(), gf)) {
362
+ throw std::runtime_error("failed to allocate compute pp buffers");
363
+ }
364
+ }
492
365
 
493
- const size_t prev_size = lctx.buf_output ? lm_ggml_backend_buffer_get_size(lctx.buf_output.get()) : 0;
494
- const size_t new_size = (logits_size + embd_size) * sizeof(float);
366
+ n_outputs = n_outputs_save;
495
367
 
496
- // alloc only when more than the current capacity is required
497
- // TODO: also consider shrinking the buffer
498
- if (!lctx.buf_output || prev_size < new_size) {
499
- if (lctx.buf_output) {
500
- #ifndef NDEBUG
501
- // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
502
- LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
503
- #endif
504
- lctx.buf_output = nullptr;
505
- lctx.logits = nullptr;
506
- lctx.embd = nullptr;
368
+ for (size_t i = 0; i < backend_ptrs.size(); ++i) {
369
+ lm_ggml_backend_t backend = backend_ptrs[i];
370
+ lm_ggml_backend_buffer_type_t buft = backend_buft[i];
371
+ size_t size = lm_ggml_backend_sched_get_buffer_size(sched.get(), backend);
372
+ if (size > 1) {
373
+ LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
374
+ lm_ggml_backend_buft_name(buft),
375
+ size / 1024.0 / 1024.0);
376
+ }
507
377
  }
508
378
 
509
- auto * buft = lm_ggml_backend_cpu_buffer_type();
510
- // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
511
- auto * output_dev = lctx.model.dev_output();
512
- auto * output_dev_host_buft = output_dev ? lm_ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
513
- if (output_dev_host_buft) {
514
- buft = output_dev_host_buft;
379
+ if (n_nodes_pp == n_nodes_tg) {
380
+ LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp);
381
+ } else {
382
+ LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
515
383
  }
516
- lctx.buf_output.reset(lm_ggml_backend_buft_alloc_buffer(buft, new_size));
517
- if (lctx.buf_output == nullptr) {
518
- LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
519
- return 0;
384
+
385
+ if (n_splits_pp == n_splits_tg) {
386
+ LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
387
+ } else {
388
+ LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
520
389
  }
521
390
  }
522
-
523
- float * output_base = (float *) lm_ggml_backend_buffer_get_base(lctx.buf_output.get());
524
-
525
- lctx.logits = has_logits ? output_base : nullptr;
526
- lctx.embd = has_embd ? output_base + logits_size : nullptr;
527
-
528
- lctx.output_size = n_outputs_max;
529
- lctx.logits_size = logits_size;
530
- lctx.embd_size = embd_size;
531
-
532
- // set all ids as invalid (negative)
533
- std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1);
534
-
535
- lm_ggml_backend_buffer_clear(lctx.buf_output.get(), 0);
536
-
537
- lctx.n_outputs = 0;
538
-
539
- return n_outputs_max;
540
391
  }
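
The reservation above runs prompt processing (pp), then token generation (tg), then pp again, so that ggml-alloc sizes the compute buffers once for the larger graph and never reallocates mid-inference, while still recording the tg split/node counts for the log lines. Condensed, the pattern is roughly the sketch below; graph_for stands in for graph_init + graph_build and is not a function in this diff.

    // worst case: one micro-batch of min(n_ctx, n_ubatch) tokens (pp) and a
    // single-token batch (tg); reserving pp -> tg -> pp keeps the larger buffers
    const uint32_t n_tokens_pp = std::min(cparams.n_ctx, cparams.n_ubatch);

    lm_ggml_cgraph * gf = graph_for(n_tokens_pp);                // pp graph (stand-in helper)
    if (!lm_ggml_backend_sched_reserve(sched.get(), gf)) { /* allocation failed */ }

    gf = graph_for(1);                                           // tg graph
    if (!lm_ggml_backend_sched_reserve(sched.get(), gf)) { /* allocation failed */ }

    gf = graph_for(n_tokens_pp);                                 // pp again, so its buffers stay allocated
    if (!lm_ggml_backend_sched_reserve(sched.get(), gf)) { /* allocation failed */ }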
541
392
 
542
- void llama_output_reorder(struct llama_context & ctx) {
543
- std::vector<size_t> & out_ids = ctx.sbatch.out_ids;
544
- if (!out_ids.empty()) {
545
- const uint32_t n_vocab = ctx.model.vocab.n_tokens();
546
- const uint32_t n_embd = ctx.model.hparams.n_embd;
393
+ llama_context::~llama_context() = default;
547
394
 
548
- const int32_t n_outputs = ctx.n_outputs;
549
- LM_GGML_ASSERT((size_t) n_outputs == out_ids.size());
395
+ void llama_context::synchronize() {
396
+ lm_ggml_backend_sched_synchronize(sched.get());
550
397
 
551
- // TODO: is there something more efficient which also minimizes swaps?
552
- // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
553
- for (int32_t i = 0; i < n_outputs - 1; ++i) {
554
- int32_t j_min = i;
555
- for (int32_t j = i + 1; j < n_outputs; ++j) {
556
- if (out_ids[j] < out_ids[j_min]) {
557
- j_min = j;
558
- }
559
- }
560
- if (j_min == i) { continue; }
561
- std::swap(out_ids[i], out_ids[j_min]);
562
- if (ctx.logits_size > 0) {
563
- for (uint32_t k = 0; k < n_vocab; k++) {
564
- std::swap(ctx.logits[i*n_vocab + k], ctx.logits[j_min*n_vocab + k]);
565
- }
566
- }
567
- if (ctx.embd_size > 0) {
568
- for (uint32_t k = 0; k < n_embd; k++) {
569
- std::swap(ctx.embd[i*n_embd + k], ctx.embd[j_min*n_embd + k]);
570
- }
571
- }
398
+ // FIXME: if multiple single tokens are evaluated without a synchronization,
399
+ // the stats will be added to the prompt evaluation stats
400
+ // this should only happen when using batch size 1 to evaluate a batch
401
+
402
+ // add the evaluation to the stats
403
+ if (n_queued_tokens == 1) {
404
+ if (!cparams.no_perf) {
405
+ t_eval_us += lm_ggml_time_us() - t_compute_start_us;
572
406
  }
573
- std::fill(ctx.output_ids.begin(), ctx.output_ids.end(), -1);
574
- for (int32_t i = 0; i < n_outputs; ++i) {
575
- ctx.output_ids[out_ids[i]] = i;
407
+ n_eval++;
408
+ } else if (n_queued_tokens > 1) {
409
+ if (!cparams.no_perf) {
410
+ t_p_eval_us += lm_ggml_time_us() - t_compute_start_us;
576
411
  }
577
- out_ids.clear();
412
+ n_p_eval += n_queued_tokens;
578
413
  }
579
- }
580
414
 
581
- //
582
- // interface implementation
583
- //
415
+ // get a more accurate load time, upon first eval
416
+ if (n_queued_tokens > 0 && !has_evaluated_once) {
417
+ t_load_us = lm_ggml_time_us() - t_start_us;
418
+ has_evaluated_once = true;
419
+ }
584
420
 
585
- void llama_free(struct llama_context * ctx) {
586
- delete ctx;
421
+ n_queued_tokens = 0;
422
+ t_compute_start_us = 0;
587
423
  }
588
424
 
589
- uint32_t llama_n_ctx(const struct llama_context * ctx) {
590
- return ctx->cparams.n_ctx;
425
+ const llama_model & llama_context::get_model() const {
426
+ return model;
591
427
  }
592
428
 
593
- uint32_t llama_n_batch(const struct llama_context * ctx) {
594
- return ctx->cparams.n_batch;
429
+ uint32_t llama_context::n_ctx() const {
430
+ return cparams.n_ctx;
595
431
  }
596
432
 
597
- uint32_t llama_n_ubatch(const struct llama_context * ctx) {
598
- return ctx->cparams.n_ubatch;
433
+ uint32_t llama_context::n_ctx_per_seq() const {
434
+ return cparams.n_ctx / cparams.n_seq_max;
599
435
  }
600
436
 
601
- uint32_t llama_n_seq_max(const struct llama_context * ctx) {
602
- return ctx->kv_self.size;
437
+ uint32_t llama_context::n_batch() const {
438
+ return cparams.n_batch;
603
439
  }
604
440
 
605
- const struct llama_model * llama_get_model(const struct llama_context * ctx) {
606
- return &ctx->model;
441
+ uint32_t llama_context::n_ubatch() const {
442
+ return cparams.n_ubatch;
607
443
  }
608
444
 
609
- enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
610
- return ctx->cparams.pooling_type;
445
+ uint32_t llama_context::n_seq_max() const {
446
+ return cparams.n_seq_max;
611
447
  }
612
448
 
613
- void llama_attach_threadpool(
614
- struct llama_context * ctx,
615
- lm_ggml_threadpool_t threadpool,
616
- lm_ggml_threadpool_t threadpool_batch) {
617
- ctx->threadpool = threadpool;
618
- ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
449
+ uint32_t llama_context::n_threads() const {
450
+ return cparams.n_threads;
619
451
  }
620
452
 
621
- void llama_detach_threadpool(struct llama_context * ctx) {
622
- ctx->threadpool = nullptr;
623
- ctx->threadpool_batch = nullptr;
453
+ uint32_t llama_context::n_threads_batch() const {
454
+ return cparams.n_threads_batch;
624
455
  }
625
456
 
626
- void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) {
627
- ctx->cparams.n_threads = n_threads;
628
- ctx->cparams.n_threads_batch = n_threads_batch;
457
+ llama_kv_cache * llama_context::get_kv_self() {
458
+ return kv_self.get();
629
459
  }
630
460
 
631
- int32_t llama_n_threads(struct llama_context * ctx) {
632
- return ctx->cparams.n_threads;
461
+ const llama_kv_cache * llama_context::get_kv_self() const {
462
+ return kv_self.get();
633
463
  }
634
464
 
635
- int32_t llama_n_threads_batch(struct llama_context * ctx) {
636
- return ctx->cparams.n_threads_batch;
637
- }
465
+ lm_ggml_tensor * llama_context::build_rope_shift(
466
+ lm_ggml_context * ctx0,
467
+ lm_ggml_tensor * cur,
468
+ lm_ggml_tensor * shift,
469
+ lm_ggml_tensor * factors,
470
+ float freq_base,
471
+ float freq_scale,
472
+ lm_ggml_backend_buffer * bbuf) const {
473
+ const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
474
+
475
+ const auto & yarn_ext_factor = cparams.yarn_ext_factor;
476
+ const auto & yarn_attn_factor = cparams.yarn_attn_factor;
477
+ const auto & yarn_beta_fast = cparams.yarn_beta_fast;
478
+ const auto & yarn_beta_slow = cparams.yarn_beta_slow;
479
+
480
+ const auto & hparams = model.hparams;
481
+
482
+ const auto & n_rot = hparams.n_rot;
483
+ const auto & rope_type = hparams.rope_type;
484
+
485
+ lm_ggml_tensor * tmp;
486
+
487
+ if (lm_ggml_is_quantized(cur->type)) {
488
+ // dequantize to f32 -> RoPE -> quantize back
489
+ tmp = lm_ggml_cast(ctx0, cur, LM_GGML_TYPE_F32);
490
+
491
+ if (bbuf) {
492
+ for (const auto & backend : backends) {
493
+ // Figure out which backend KV cache belongs to
494
+ if (lm_ggml_backend_supports_buft(backend.get(), lm_ggml_backend_buffer_get_type(bbuf))) {
495
+ lm_ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get());
496
+ break;
497
+ }
498
+ }
499
+ }
638
500
 
639
- void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
640
- ctx->abort_callback = abort_callback;
641
- ctx->abort_callback_data = abort_callback_data;
501
+ tmp = lm_ggml_rope_ext_inplace(ctx0, tmp,
502
+ shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
503
+ yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
642
504
 
643
- for (auto & backend : ctx->backends) {
644
- auto * reg = lm_ggml_backend_dev_backend_reg(lm_ggml_backend_get_device(backend.get()));
645
- auto * set_abort_callback_fn = (lm_ggml_backend_set_abort_callback_t) lm_ggml_backend_reg_get_proc_address(reg, "lm_ggml_backend_set_abort_callback");
646
- if (set_abort_callback_fn) {
647
- set_abort_callback_fn(backend.get(), ctx->abort_callback, ctx->abort_callback_data);
648
- }
505
+ tmp = lm_ggml_cpy(ctx0, tmp, cur);
506
+ } else {
507
+ // we rotate only the first n_rot dimensions
508
+ tmp = lm_ggml_rope_ext_inplace(ctx0, cur,
509
+ shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
510
+ yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
649
511
  }
650
- }
651
512
 
652
- void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
653
- ctx->cparams.embeddings = embeddings;
513
+ return tmp;
654
514
  }
655
515
 
656
- void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
657
- ctx->cparams.causal_attn = causal_attn;
658
- }
516
+ class llm_graph_input_k_shift : public llm_graph_input_i {
+ public:
+     llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+     virtual ~llm_graph_input_k_shift() = default;

- void llama_synchronize(struct llama_context * ctx) {
-     lm_ggml_backend_sched_synchronize(ctx->sched.get());
+     void set_input(const llama_ubatch * ubatch) override;

-     // FIXME: if multiple single tokens are evaluated without a synchronization,
-     //        the stats will be added to the prompt evaluation stats
-     //        this should only happen when using batch size 1 to evaluate a batch
+     lm_ggml_tensor * k_shift; // I32 [kv_size]

-     // add the evaluation to the stats
-     if (ctx->n_queued_tokens == 1) {
-         if (!ctx->cparams.no_perf) {
-             ctx->t_eval_us += lm_ggml_time_us() - ctx->t_compute_start_us;
-         }
-         ctx->n_eval++;
-     } else if (ctx->n_queued_tokens > 1) {
-         if (!ctx->cparams.no_perf) {
-             ctx->t_p_eval_us += lm_ggml_time_us() - ctx->t_compute_start_us;
+     const llama_kv_cache_unified * kv_self;
+ };
+
+ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
+     LM_GGML_UNUSED(ubatch);
+
+     if (k_shift) {
+         assert(lm_ggml_backend_buffer_is_host(k_shift->buffer));
+
+         int32_t * data = (int32_t *) k_shift->data;
+
+         for (uint32_t i = 0; i < kv_self->size; ++i) {
+             data[i] = kv_self->cells[i].delta;
          }
-         ctx->n_p_eval += ctx->n_queued_tokens;
      }
+ }
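
llm_graph_input_k_shift is one instance of the graph-input pattern introduced in this release: the graph builder allocates an input tensor, and set_input() later fills it from live state, here the accumulated position delta of every KV cell. Where those deltas come from can be sketched with a simplified cell type (not the package's exact struct):

    #include <cstdint>
    #include <vector>

    struct kv_cell_sketch {
        int32_t pos   = -1; // absolute position of the cached token
        int32_t delta = 0;  // pending shift, applied later via RoPE on the K tensor
    };

    // Shifting a sequence only edits metadata; the K tensors are rotated in one
    // pass by build_kv_self_shift, driven by these per-cell deltas.
    static void shift_cells(std::vector<kv_cell_sketch> & cells, int32_t d) {
        for (auto & cell : cells) {
            cell.pos   += d;
            cell.delta += d;
        }
    }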
679
541
 
680
- // get a more accurate load time, upon first eval
681
- if (ctx->n_queued_tokens > 0 && !ctx->has_evaluated_once) {
682
- ctx->t_load_us = lm_ggml_time_us() - ctx->t_start_us;
683
- ctx->has_evaluated_once = true;
684
- }
542
+ llm_graph_result_ptr llama_context::build_kv_self_shift(
543
+ lm_ggml_context * ctx0,
544
+ lm_ggml_cgraph * gf) const {
545
+ auto res = std::make_unique<llm_graph_result>();
685
546
 
686
- ctx->n_queued_tokens = 0;
687
- ctx->t_compute_start_us = 0;
688
- }
547
+ const auto & hparams = model.hparams;
689
548
 
690
- float * llama_get_logits(struct llama_context * ctx) {
691
- llama_synchronize(ctx);
549
+ const auto & n_layer = hparams.n_layer;
692
550
 
693
- // reorder logits for backward compatibility
694
- // TODO: maybe deprecate this
695
- llama_output_reorder(*ctx);
551
+ const auto & n_embd_head_k = hparams.n_embd_head_k;
552
+ //const auto & n_embd_head_v = hparams.n_embd_head_v;
696
553
 
697
- return ctx->logits;
698
- }
554
+ //LM_GGML_ASSERT(kv_self->size == n_ctx);
699
555
 
700
- float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
701
- int32_t j = -1;
556
+ auto inp = std::make_unique<llm_graph_input_k_shift>(kv_self.get());
702
557
 
703
- llama_synchronize(ctx);
558
+ inp->k_shift = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, cparams.n_ctx);
559
+ lm_ggml_set_input(inp->k_shift);
704
560
 
705
- try {
706
- if (ctx->logits == nullptr) {
707
- throw std::runtime_error("no logits");
708
- }
561
+ for (uint32_t il = 0; il < n_layer; ++il) {
562
+ const int64_t n_head_kv = hparams.n_head_kv(il);
563
+ const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
709
564
 
710
- if (i < 0) {
711
- j = ctx->n_outputs + i;
712
- if (j < 0) {
713
- throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
714
- }
715
- } else if ((size_t) i >= ctx->output_ids.size()) {
716
- throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size()));
717
- } else {
718
- j = ctx->output_ids[i];
719
- }
565
+ const bool is_swa = hparams.is_swa(il);
720
566
 
721
- if (j < 0) {
722
- throw std::runtime_error(format("batch.logits[%d] != true", i));
723
- }
724
- if (j >= ctx->n_outputs) {
725
- // This should not happen
726
- throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
727
- }
567
+ // note: the swa rope params could become part of the cparams in the future
568
+ // if we decide to make them configurable, like the non-sliding ones
569
+ const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
570
+ const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
728
571
 
729
- return ctx->logits + j*ctx->model.vocab.n_tokens();
730
- } catch (const std::exception & err) {
731
- LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
732
- #ifndef NDEBUG
733
- LM_GGML_ABORT("fatal error");
734
- #else
735
- return nullptr;
736
- #endif
737
- }
738
- }
572
+ lm_ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
739
573
 
740
- float * llama_get_embeddings(struct llama_context * ctx) {
741
- llama_synchronize(ctx);
574
+ lm_ggml_tensor * k =
575
+ lm_ggml_view_3d(ctx0, kv_self->k_l[il],
576
+ n_embd_head_k, n_head_kv, kv_self->size,
577
+ lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
578
+ lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
579
+ 0);
742
580
 
743
- // reorder embeddings for backward compatibility
744
- // TODO: maybe deprecate this
745
- llama_output_reorder(*ctx);
581
+ lm_ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
746
582
 
747
- return ctx->embd;
583
+ lm_ggml_build_forward_expand(gf, cur);
584
+ }
585
+
586
+ res->add_input(std::move(inp));
587
+
588
+ return res;
748
589
  }
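
The lm_ggml_view_3d call above reinterprets the flat per-layer K buffer as [n_embd_head_k, n_head_kv, kv_size] so that the whole layer can be rotated with a single RoPE node. With the strides spelled out, the view amounts to the sketch below; k_l_il stands for kv_self->k_l[il] and kv_self_size for kv_self->size, both assumed to be in scope.

    const size_t nb1 = lm_ggml_row_size(k_l_il->type, n_embd_head_k); // bytes to the next head
    const size_t nb2 = lm_ggml_row_size(k_l_il->type, n_embd_k_gqa);  // bytes to the next cached cell

    lm_ggml_tensor * k = lm_ggml_view_3d(ctx0, k_l_il,
            n_embd_head_k, n_head_kv, kv_self_size, // logical shape of the layer
            nb1, nb2,
            0);                                     // offset into the cache buffer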
749
590
 
750
- float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
751
- int32_t j = -1;
591
+ llm_graph_result_ptr llama_context::build_kv_self_defrag(
592
+ lm_ggml_context * ctx0,
593
+ lm_ggml_cgraph * gf) const {
594
+ auto res = std::make_unique<llm_graph_result>();
752
595
 
753
- llama_synchronize(ctx);
596
+ const auto & hparams = model.hparams;
754
597
 
755
- try {
756
- if (ctx->embd == nullptr) {
757
- throw std::runtime_error("no embeddings");
758
- }
598
+ const auto & ids = kv_self->defrag_info.ids;
759
599
 
760
- if (i < 0) {
761
- j = ctx->n_outputs + i;
762
- if (j < 0) {
763
- throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
764
- }
765
- } else if ((size_t) i >= ctx->output_ids.size()) {
766
- throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size()));
767
- } else {
768
- j = ctx->output_ids[i];
769
- }
600
+ #if 0
601
+ // CPU defrag
602
+ //
603
+ // TODO: optimizations are possible:
604
+ // - multiple threads
605
+ // - avoid copying to the host memory when already there
606
+ //
607
+ // likely not worth the effort, as we have lm_ggml_graph based defrag
608
+ //
770
609
 
771
- if (j < 0) {
772
- throw std::runtime_error(format("batch.logits[%d] != true", i));
773
- }
774
- if (j >= ctx->n_outputs) {
775
- // This should not happen
776
- throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
777
- }
610
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
611
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
778
612
 
779
- return ctx->embd + j*ctx->model.hparams.n_embd;
780
- } catch (const std::exception & err) {
781
- LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
782
- #ifndef NDEBUG
783
- LM_GGML_ABORT("fatal error");
784
- #else
785
- return nullptr;
786
- #endif
787
- }
788
- }
613
+ const uint32_t kv_size = size;
789
614
 
790
- float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
791
- llama_synchronize(ctx);
615
+ std::vector<uint8_t> buf_k;
616
+ std::vector<uint8_t> buf_v;
792
617
 
793
- auto it = ctx->embd_seq.find(seq_id);
794
- if (it == ctx->embd_seq.end()) {
795
- return nullptr;
796
- }
618
+ for (uint32_t il = 0; il < n_layer; ++il) {
619
+ const size_t k_size_row = lm_ggml_row_size(k_l[il]->type, n_embd_k_gqa);
620
+ const size_t k_size = lm_ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
797
621
 
798
- return it->second.data();
799
- }
622
+ const size_t v_size_el = lm_ggml_type_size(v_l[il]->type);
623
+ const size_t v_size = lm_ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
800
624
 
801
- // llama state API
625
+ buf_k.resize(k_size);
626
+ buf_v.resize(v_size);
802
627
 
803
- // deprecated
804
- size_t llama_get_state_size(struct llama_context * ctx) {
805
- return llama_state_get_size(ctx);
806
- }
628
+ lm_ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
629
+ lm_ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
807
630
 
808
- // deprecated
809
- size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
810
- return llama_state_get_data(ctx, dst, -1);
811
- }
631
+ // batch move [i, i+nm) to [id, id+nm)
632
+ // note: cells can move only to a lower index
633
+ for (uint32_t i = 0; i < n_kv; ++i) {
634
+ const uint32_t id = ids[i];
812
635
 
813
- // deprecated
814
- size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
815
- return llama_state_set_data(ctx, src, -1);
816
- }
636
+ if (i == id || id == n_kv) {
637
+ continue;
638
+ }
817
639
 
818
- // deprecated
819
- bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
820
- return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
821
- }
640
+ uint32_t nm = 1;
822
641
 
823
- // deprecated
824
- bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
825
- return llama_state_save_file(ctx, path_session, tokens, n_token_count);
826
- }
642
+ while (i + nm < n_kv && ids[i + nm] == id + nm) {
643
+ nm++;
644
+ }
827
645
 
828
- // TODO: replace all non-fatal assertions with returned errors or exceptions
829
- struct llama_data_write {
830
- virtual void write(const void * src, size_t size) = 0;
831
- virtual void write_tensor_data(const struct lm_ggml_tensor * tensor, size_t offset, size_t size) = 0;
832
- virtual size_t get_size_written() = 0;
833
- virtual ~llama_data_write() = default;
646
+ // move keys
647
+ {
648
+ const int64_t os = i*k_size_row;
649
+ const int64_t od = id*k_size_row;
834
650
 
835
- void write_string(const std::string & str) {
836
- uint32_t str_size = str.size();
651
+ memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
652
+ }
837
653
 
838
- write(&str_size, sizeof(str_size));
839
- write(str.data(), str_size);
840
- }
654
+ // move values (note: they are transposed)
655
+ {
656
+ const int64_t os = i;
657
+ const int64_t od = id;
841
658
 
842
- void write_model_info(const struct llama_context * ctx) {
843
- const std::string arch_str = llm_arch_name(ctx->model.arch);
844
- write_string(arch_str);
845
- // TODO: add more model-specific info which should prevent loading the session file if not identical
659
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
660
+ memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
661
+ }
662
+ }
663
+
664
+ i += nm - 1;
665
+ }
666
+
667
+ lm_ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
668
+ lm_ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
846
669
  }
670
+ #else
671
+ for (uint32_t i = 0; i < ids.size(); ++i) {
672
+ const uint32_t id = ids[i];
847
673
 
848
- //void write_rng(const std::mt19937 & rng) {
849
- // std::ostringstream rng_ss;
850
- // rng_ss << rng;
674
+ if (i == id || id == ids.size()) {
675
+ continue;
676
+ }
851
677
 
852
- // const std::string & rng_str = rng_ss.str();
678
+ uint32_t nm = 1;
853
679
 
854
- // write_string(rng_str);
855
- //}
680
+ while (i + nm < ids.size() && ids[i + nm] == id + nm) {
681
+ nm++;
682
+ }
856
683
 
857
- void write_output_ids(struct llama_context * ctx) {
858
- llama_output_reorder(*ctx);
684
+ for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
685
+ const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
686
+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
687
+
688
+ lm_ggml_tensor * view_k_src = lm_ggml_view_2d(ctx0, kv_self->k_l[il],
689
+ n_embd_k_gqa, nm,
690
+ lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
691
+ lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
692
+
693
+ lm_ggml_tensor * view_k_dst = lm_ggml_view_2d(ctx0, kv_self->k_l[il],
694
+ n_embd_k_gqa, nm,
695
+ lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
696
+ lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
697
+
698
+ lm_ggml_tensor * view_v_src;
699
+ lm_ggml_tensor * view_v_dst;
700
+
701
+ if (cparams.flash_attn) {
702
+ // NOTE: the V cache is not transposed when using flash attention
703
+ view_v_src = lm_ggml_view_2d(ctx0, kv_self->v_l[il],
704
+ n_embd_v_gqa, nm,
705
+ lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
706
+ lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
707
+
708
+ view_v_dst = lm_ggml_view_2d(ctx0, kv_self->v_l[il],
709
+ n_embd_v_gqa, nm,
710
+ lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
711
+ lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
712
+ } else {
713
+ view_v_src = lm_ggml_view_2d(ctx0, kv_self->v_l[il],
714
+ nm, n_embd_v_gqa,
715
+ lm_ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
716
+ lm_ggml_row_size(kv_self->v_l[il]->type, i));
717
+
718
+ view_v_dst = lm_ggml_view_2d(ctx0, kv_self->v_l[il],
719
+ nm, n_embd_v_gqa,
720
+ lm_ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
721
+ lm_ggml_row_size(kv_self->v_l[il]->type, id));
722
+ }
859
723
 
860
- const uint32_t n_outputs = ctx->n_outputs;
724
+ lm_ggml_build_forward_expand(gf, lm_ggml_cpy(ctx0, view_k_src, view_k_dst));
725
+ lm_ggml_build_forward_expand(gf, lm_ggml_cpy(ctx0, view_v_src, view_v_dst));
726
+ }
861
727
 
862
- std::vector<int32_t> output_pos;
728
+ i += nm - 1;
729
+ }
863
730
 
864
- const size_t n_batch = ctx->cparams.n_batch;
865
- const auto & output_ids = ctx->output_ids;
731
+ //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
732
+ #endif
866
733
 
867
- LM_GGML_ASSERT(n_outputs <= ctx->output_size);
734
+ return res;
735
+ }
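
The defrag graph is driven by defrag_info.ids, which maps each occupied cell i to its destination id (with id == ids.size() meaning "not moved"); contiguous runs are coalesced so one lm_ggml_cpy moves nm cells at a time. The run-detection loop, isolated as a self-contained sketch:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Print the coalesced moves implied by an ids mapping of the kind built by
    // defrag_prepare: cell i moves to ids[i]; ids[i] == ids.size() means unused.
    static void plan_moves(const std::vector<uint32_t> & ids) {
        for (uint32_t i = 0; i < ids.size(); ++i) {
            const uint32_t id = ids[i];
            if (i == id || id == ids.size()) {
                continue; // already in place, or not moved at all
            }
            uint32_t nm = 1;
            while (i + nm < ids.size() && ids[i + nm] == id + nm) {
                nm++; // extend the contiguous run
            }
            std::printf("move cells [%u, %u) -> [%u, %u)\n", i, i + nm, id, id + nm);
            i += nm - 1; // skip the cells covered by this run
        }
    }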
868
736
 
869
- output_pos.resize(n_outputs);
737
+ void llama_context::kv_self_update() {
738
+ auto & kv = kv_self;
870
739
 
871
- // build a more compact representation of the output ids
872
- for (size_t i = 0; i < n_batch; ++i) {
873
- // map an output id to a position in the batch
874
- int32_t pos = output_ids[i];
875
- if (pos >= 0) {
876
- LM_GGML_ASSERT((uint32_t) pos < n_outputs);
877
- output_pos[pos] = i;
878
- }
740
+ bool need_reserve = false;
741
+
742
+ if (kv->has_shift) {
743
+ if (!kv->get_can_shift()) {
744
+ LM_GGML_ABORT("The current context does not support K-shift");
879
745
  }
880
746
 
881
- write(&n_outputs, sizeof(n_outputs));
747
+ LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
882
748
 
883
- if (n_outputs) {
884
- write(output_pos.data(), n_outputs * sizeof(int32_t));
885
- }
886
- }
749
+ // apply K-shift if needed
750
+ if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
751
+ lm_ggml_backend_sched_reset(sched.get());
887
752
 
888
- void write_logits(const struct llama_context * ctx) {
889
- const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.vocab.n_tokens());
753
+ auto * gf = graph_init();
890
754
 
891
- write(&logits_size, sizeof(logits_size));
755
+ auto res = build_kv_self_shift(ctx_compute.get(), gf);
892
756
 
893
- if (logits_size) {
894
- write(ctx->logits, logits_size * sizeof(float));
895
- }
896
- }
757
+ lm_ggml_backend_sched_alloc_graph(sched.get(), gf);
897
758
 
898
- void write_embeddings(const struct llama_context * ctx) {
899
- const uint64_t embeddings_size = std::min((uint64_t) ctx->embd_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_embd);
759
+ res->set_inputs(nullptr);
900
760
 
901
- write(&embeddings_size, sizeof(embeddings_size));
761
+ graph_compute(gf, false);
902
762
 
903
- if (embeddings_size) {
904
- write(ctx->embd, embeddings_size * sizeof(float));
763
+ need_reserve = true;
905
764
  }
906
- }
907
-
908
- void write_kv_cache_meta(const llama_kv_cache & kv_self, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {
909
- for (const auto & range : cell_ranges) {
910
- for (uint32_t i = range.first; i < range.second; ++i) {
911
- const auto & cell = kv_self.cells[i];
912
- const llama_pos pos = cell.pos;
913
- const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
914
765
 
915
- write(&pos, sizeof(pos));
916
- write(&n_seq_id, sizeof(n_seq_id));
766
+ {
767
+ kv->has_shift = false;
917
768
 
918
- if (n_seq_id) {
919
- for (auto seq_id : cell.seq_id) {
920
- write(&seq_id, sizeof(seq_id));
921
- }
922
- }
769
+ for (uint32_t i = 0; i < kv->size; ++i) {
770
+ kv->cells[i].delta = 0;
923
771
  }
924
772
  }
925
773
  }
926
774
 
927
- void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
928
- const struct llama_kv_cache & kv_self = ctx->kv_self;
929
- const struct llama_hparams & hparams = ctx->model.hparams;
775
+ // defragment the KV cache if needed
776
+ if (kv->do_defrag) {
777
+ LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
930
778
 
931
- const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
932
- const uint32_t n_layer = hparams.n_layer;
779
+ if (kv->defrag_prepare(graph_max_nodes())) {
780
+ lm_ggml_backend_sched_reset(sched.get());
933
781
 
934
- write(&v_trans, sizeof(v_trans));
935
- write(&n_layer, sizeof(n_layer));
782
+ auto * gf = graph_init();
936
783
 
937
- std::vector<uint8_t> tmp_buf;
784
+ auto res = build_kv_self_defrag(ctx_compute.get(), gf);
938
785
 
939
- // Iterate and write all the keys first, each row is a cell
940
- // Get whole range at a time
941
- for (uint32_t il = 0; il < n_layer; ++il) {
942
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
786
+ lm_ggml_backend_sched_alloc_graph(sched.get(), gf);
943
787
 
944
- // Write key type
945
- const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
946
- write(&k_type_i, sizeof(k_type_i));
788
+ res->set_inputs(nullptr);
947
789
 
948
- // Write row size of key
949
- const uint64_t k_size_row = lm_ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
950
- write(&k_size_row, sizeof(k_size_row));
790
+ graph_compute(gf, false);
951
791
 
952
- // Read each range of cells of k_size length each into tmp_buf and write out
953
- for (const auto & range : cell_ranges) {
954
- const size_t range_size = range.second - range.first;
955
- const size_t buf_size = range_size * k_size_row;
956
- write_tensor_data(kv_self.k_l[il], range.first * k_size_row, buf_size);
957
- }
792
+ need_reserve = true;
958
793
  }
959
794
 
960
- if (!kv_self.v_trans) {
961
- for (uint32_t il = 0; il < n_layer; ++il) {
962
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
795
+ kv->do_defrag = false;
796
+ }
963
797
 
964
- // Write value type
965
- const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
966
- write(&v_type_i, sizeof(v_type_i));
798
+ // reserve a worst case graph if needed
799
+ if (need_reserve) {
800
+ LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
967
801
 
968
- // Write row size of value
969
- const uint64_t v_size_row = lm_ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
970
- write(&v_size_row, sizeof(v_size_row));
802
+ // build worst-case graph
803
+ uint32_t n_seqs = 1; // TODO: worst-case number of sequences
804
+ uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
971
805
 
972
- // Read each range of cells of v_size length each into tmp_buf and write out
973
- for (const auto & range : cell_ranges) {
974
- const size_t range_size = range.second - range.first;
975
- const size_t buf_size = range_size * v_size_row;
976
- write_tensor_data(kv_self.v_l[il], range.first * v_size_row, buf_size);
977
- }
978
- }
979
- } else {
980
- // When v is transposed, we also need the element size and get the element ranges from each row
981
- const uint32_t kv_size = kv_self.size;
982
- for (uint32_t il = 0; il < n_layer; ++il) {
983
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
806
+ // simulate full KV cache
807
+ kv_self->n = kv_self->size;
984
808
 
985
- // Write value type
986
- const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
987
- write(&v_type_i, sizeof(v_type_i));
809
+ llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
810
+ llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
988
811
 
989
- // Write element size
990
- const uint32_t v_size_el = lm_ggml_type_size(kv_self.v_l[il]->type);
991
- write(&v_size_el, sizeof(v_size_el));
812
+ auto * gf = graph_init();
813
+ graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
992
814
 
993
- // Write GQA embedding size
994
- write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
815
+ // initialize scheduler with the worst-case graph
816
+ lm_ggml_backend_sched_reset(sched.get());
817
+ if (!lm_ggml_backend_sched_reserve(sched.get(), gf)) {
818
+ LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
819
+ }
820
+ }
821
+ }
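
kv_self_update is what turns deferred cache edits into tensor work: a pending K-shift or defrag is compiled into a graph, computed, and the scheduler is re-reserved for the worst case. From the public C API, such a shift is typically queued by a context-shift pattern along these lines (a sketch only; it assumes the long-standing llama.cpp helpers llama_kv_cache_seq_rm / llama_kv_cache_seq_add exposed by this package's headers, plus caller-side variables ctx, n_past, and n_keep):

    // Drop the oldest tokens of sequence 0 after n_keep and slide the rest back,
    // so generation can continue past n_ctx. Only metadata changes here; the
    // actual K rotation is applied by kv_self_update before the next decode.
    const int n_discard = (n_past - n_keep) / 2;

    llama_kv_cache_seq_rm (ctx, 0, n_keep,             n_keep + n_discard);
    llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_past, -n_discard);

    n_past -= n_discard;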
995
822
 
996
- // For each row, we get the element values of each cell
997
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
998
- // Read each range of cells of v_size_el length each into tmp_buf and write out
999
- for (const auto & range : cell_ranges) {
1000
- const size_t range_size = range.second - range.first;
1001
- const size_t src_offset = (range.first + j * kv_size) * v_size_el;
1002
- const size_t buf_size = range_size * v_size_el;
1003
- write_tensor_data(kv_self.v_l[il], src_offset, buf_size);
1004
- }
1005
- }
823
+ enum llama_pooling_type llama_context::pooling_type() const {
+     return cparams.pooling_type;
+ }
+
+ float * llama_context::get_logits() {
+     // reorder logits for backward compatibility
+     output_reorder();
+
+     return logits;
+ }
+
+ float * llama_context::get_logits_ith(int32_t i) {
+     int32_t j = -1;
+
+     try {
+         if (logits == nullptr) {
+             throw std::runtime_error("no logits");
+         }
+
+         if (i < 0) {
+             j = n_outputs + i;
+             if (j < 0) {
+                 throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
              }
+         } else if ((size_t) i >= output_ids.size()) {
+             throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
+         } else {
+             j = output_ids[i];
+         }
+
+         if (j < 0) {
+             throw std::runtime_error(format("batch.logits[%d] != true", i));
+         }
+         if (j >= n_outputs) {
+             // This should not happen
+             throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
          }
+
+         return logits + j*model.vocab.n_tokens();
+     } catch (const std::exception & err) {
+         LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
+ #ifndef NDEBUG
+         LM_GGML_ABORT("fatal error");
+ #else
+         return nullptr;
+ #endif
      }
+ }
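
get_logits_ith maps a batch position to a row of the (possibly reordered) output buffer via output_ids and accepts negative indices counted from the end of the outputs. Through the public C API, the usual call sites look like this sketch (assumes a context ctx that has just decoded a batch, and a batch index i):

    // logits of the last output row; -1 counts from the end, so this works no
    // matter how many entries in the batch requested logits
    const float * logits_last = llama_get_logits_ith(ctx, -1);

    // logits for an explicit batch position i; valid only if batch.logits[i]
    // was set, otherwise the call logs an error and (in release builds) returns nullptr
    const float * logits_i = llama_get_logits_ith(ctx, i);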
1009
871
 
1010
- void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
1011
- const struct llama_kv_cache & kv_self = ctx->kv_self;
1012
- std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
1013
- uint32_t cell_count = 0;
872
+ float * llama_context::get_embeddings() {
873
+ // reorder embeddings for backward compatibility
874
+ output_reorder();
1014
875
 
1015
- // Count the number of cells with the specified seq_id
1016
- // Find all the ranges of cells with this seq id (or all, when -1)
1017
- uint32_t cell_range_begin = kv_self.size;
1018
- for (uint32_t i = 0; i < kv_self.size; ++i) {
1019
- const auto & cell = kv_self.cells[i];
1020
- if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
1021
- ++cell_count;
1022
- if (cell_range_begin == kv_self.size) {
1023
- cell_range_begin = i;
1024
- }
1025
- } else {
1026
- if (cell_range_begin != kv_self.size) {
1027
- cell_ranges.emplace_back(cell_range_begin, i);
1028
- cell_range_begin = kv_self.size;
1029
- }
1030
- }
876
+ return embd;
877
+ }
878
+
879
+ float * llama_context::get_embeddings_ith(int32_t i) {
880
+ int32_t j = -1;
881
+
882
+ try {
883
+ if (embd == nullptr) {
884
+ throw std::runtime_error("no embeddings");
1031
885
  }
1032
- if (cell_range_begin != kv_self.size) {
1033
- cell_ranges.emplace_back(cell_range_begin, kv_self.size);
886
+
887
+ if (i < 0) {
888
+ j = n_outputs + i;
889
+ if (j < 0) {
890
+ throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
891
+ }
892
+ } else if ((size_t) i >= output_ids.size()) {
893
+ throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
894
+ } else {
895
+ j = output_ids[i];
1034
896
  }
1035
897
 
1036
- // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
1037
- uint32_t cell_count_check = 0;
1038
- for (const auto & range : cell_ranges) {
1039
- cell_count_check += range.second - range.first;
898
+ if (j < 0) {
899
+ throw std::runtime_error(format("batch.logits[%d] != true", i));
900
+ }
901
+ if (j >= n_outputs) {
902
+ // This should not happen
903
+ throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
1040
904
  }
1041
- LM_GGML_ASSERT(cell_count == cell_count_check);
1042
905
 
1043
- write(&cell_count, sizeof(cell_count));
906
+ return embd + j*model.hparams.n_embd;
907
+ } catch (const std::exception & err) {
908
+ LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
909
+ #ifndef NDEBUG
910
+ LM_GGML_ABORT("fatal error");
911
+ #else
912
+ return nullptr;
913
+ #endif
914
+ }
915
+ }
1044
916
 
1045
- write_kv_cache_meta(kv_self, cell_ranges, seq_id);
1046
- write_kv_cache_data(ctx, cell_ranges);
917
+ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
918
+ auto it = embd_seq.find(seq_id);
919
+ if (it == embd_seq.end()) {
920
+ return nullptr;
1047
921
  }
1048
- };
1049
922
 
1050
- struct llama_data_read {
1051
- virtual const uint8_t * read(size_t size) = 0;
1052
- virtual void read_to(void * dst, size_t size) = 0;
1053
- virtual size_t get_size_read() = 0;
1054
- virtual ~llama_data_read() = default;
923
+ return it->second.data();
924
+ }
1055
925
 
1056
- void read_string(std::string & str) {
1057
- uint32_t str_size;
1058
- read_to(&str_size, sizeof(str_size));
926
+ void llama_context::attach_threadpool(
927
+ lm_ggml_threadpool_t threadpool,
928
+ lm_ggml_threadpool_t threadpool_batch) {
929
+ LLAMA_LOG_DEBUG("%s: call\n", __func__);
1059
930
 
1060
- str.assign((const char *) read(str_size), str_size);
1061
- }
931
+ this->threadpool = threadpool;
932
+ this->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
933
+ }
1062
934
 
1063
- // validate model information
1064
- void read_model_info(const struct llama_context * ctx) {
1065
- const std::string cur_arch_str = llm_arch_name(ctx->model.arch);
935
+ void llama_context::detach_threadpool() {
936
+ LLAMA_LOG_DEBUG("%s: call\n", __func__);
1066
937
 
1067
- std::string arch_str;
1068
- read_string(arch_str);
1069
- if (cur_arch_str != arch_str) {
1070
- throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
938
+ this->threadpool = nullptr;
939
+ this->threadpool_batch = nullptr;
940
+ }
941
+
942
+ void llama_context::set_n_threads(int32_t n_threads, int32_t n_threads_batch) {
943
+ LLAMA_LOG_DEBUG("%s: n_threads = %d, n_threads_batch = %d\n", __func__, n_threads, n_threads_batch);
944
+
945
+ cparams.n_threads = n_threads;
946
+ cparams.n_threads_batch = n_threads_batch;
947
+ }
948
+
949
+ void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) {
950
+ LLAMA_LOG_DEBUG("%s: call\n", __func__);
951
+
952
+ this->abort_callback = abort_callback;
953
+ this->abort_callback_data = abort_callback_data;
954
+
955
+ for (auto & backend : backends) {
956
+ auto * reg = lm_ggml_backend_dev_backend_reg(lm_ggml_backend_get_device(backend.get()));
957
+ auto * set_abort_callback_fn = (lm_ggml_backend_set_abort_callback_t) lm_ggml_backend_reg_get_proc_address(reg, "lm_ggml_backend_set_abort_callback");
958
+ if (set_abort_callback_fn) {
959
+ set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data);
1071
960
  }
1072
- // TODO: add more info which needs to be identical but which is not verified otherwise
1073
961
  }
962
+ }
963
+
964
+ void llama_context::set_embeddings(bool value) {
965
+ LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
1074
966
 
1075
- //void read_rng(std::mt19937 & rng) {
1076
- // std::string rng_str;
1077
- // read_string(rng_str);
967
+ cparams.embeddings = value;
968
+ }
1078
969
 
1079
- // std::istringstream rng_ss(rng_str);
1080
- // rng_ss >> rng;
970
+ void llama_context::set_causal_attn(bool value) {
971
+ LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
1081
972
 
1082
- // if (rng_ss.fail()) {
1083
- // throw std::runtime_error("failed to load RNG state");
1084
- // }
1085
- //}
973
+ cparams.causal_attn = value;
974
+ }
1086
975
 
1087
- void read_output_ids(struct llama_context * ctx) {
1088
- std::vector<int32_t> output_pos;
976
+ void llama_context::set_warmup(bool value) {
977
+ LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
1089
978
 
1090
- uint32_t n_outputs;
1091
- read_to(&n_outputs, sizeof(n_outputs));
979
+ cparams.warmup = value;
980
+ }
1092
981
 
1093
- if (n_outputs > llama_output_reserve(*ctx, n_outputs)) {
1094
- throw std::runtime_error("could not reserve outputs");
1095
- }
982
+ void llama_context::set_adapter_lora(
983
+ llama_adapter_lora * adapter,
984
+ float scale) {
985
+ LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale);
1096
986
 
1097
- if (n_outputs) {
1098
- output_pos.resize(n_outputs);
1099
- read_to(output_pos.data(), n_outputs * sizeof(int32_t));
987
+ loras[adapter] = scale;
988
+ }
1100
989
 
1101
- for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
1102
- int32_t id = output_pos[i];
1103
- if ((uint32_t) id >= ctx->cparams.n_batch) {
1104
- throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->cparams.n_batch));
1105
- }
1106
- ctx->output_ids[id] = i;
1107
- }
990
+ bool llama_context::rm_adapter_lora(
991
+ llama_adapter_lora * adapter) {
992
+ LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter);
1108
993
 
1109
- ctx->n_outputs = n_outputs;
1110
- }
994
+ auto pos = loras.find(adapter);
995
+ if (pos != loras.end()) {
996
+ loras.erase(pos);
997
+ return true;
1111
998
  }
1112
999
 
1113
- void read_logits(struct llama_context * ctx) {
1114
- uint64_t logits_size;
1115
- read_to(&logits_size, sizeof(logits_size));
1000
+ return false;
1001
+ }
1116
1002
 
1117
- if (ctx->logits_size < logits_size) {
1118
- throw std::runtime_error("logits buffer too small");
1119
- }
1003
+ void llama_context::clear_adapter_lora() {
1004
+ LLAMA_LOG_DEBUG("%s: call\n", __func__);
1120
1005
 
1121
- if (logits_size) {
1122
- read_to(ctx->logits, logits_size * sizeof(float));
1123
- }
1006
+ loras.clear();
1007
+ }
1008
+
1009
+ bool llama_context::apply_adapter_cvec(
1010
+ const float * data,
1011
+ size_t len,
1012
+ int32_t n_embd,
1013
+ int32_t il_start,
1014
+ int32_t il_end) {
1015
+ LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end);
1016
+
1017
+ return cvec.apply(model, data, len, n_embd, il_start, il_end);
1018
+ }
1019
+
1020
+ int llama_context::encode(llama_batch & inp_batch) {
1021
+ if (inp_batch.n_tokens == 0) {
1022
+ LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
1023
+ return -1;
1124
1024
  }
1125
1025
 
1126
- void read_embeddings(struct llama_context * ctx) {
1127
- uint64_t embeddings_size;
1128
- read_to(&embeddings_size, sizeof(embeddings_size));
1026
+ // temporary allocate memory for the input batch if needed
1027
+ // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
1028
+ llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
1129
1029
 
1130
- if (ctx->embd_size < embeddings_size) {
1131
- throw std::runtime_error("embeddings buffer too small");
1132
- }
1030
+ const llama_batch & batch = batch_allocr.batch;
1031
+ const int32_t n_tokens = batch.n_tokens;
1133
1032
 
1134
- if (embeddings_size) {
1135
- read_to(ctx->embd, embeddings_size * sizeof(float));
1033
+ const auto & hparams = model.hparams;
1034
+
1035
+ LM_GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
1036
+
1037
+ if (batch.token) {
1038
+ for (int32_t i = 0; i < n_tokens; ++i) {
1039
+ if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
1040
+ LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
1041
+ return -1;
1042
+ }
1136
1043
  }
1137
1044
  }
1138
1045
 
1139
- bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id dest_seq_id = -1) {
1140
- struct llama_kv_cache & kv_self = ctx->kv_self;
1046
+ // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
1047
+ LM_GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
1048
+
1049
+ if (t_compute_start_us == 0) {
1050
+ t_compute_start_us = lm_ggml_time_us();
1051
+ }
1052
+
1053
+ n_queued_tokens += n_tokens;
1141
1054
 
1142
- if (dest_seq_id != -1) {
1143
- // single sequence
1055
+ const int64_t n_embd = hparams.n_embd;
1144
1056
 
1145
- llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
1057
+ sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
1146
1058
 
1147
- llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
1148
- batch.n_tokens = cell_count;
1149
- batch.n_seq_tokens = cell_count;
1150
- batch.n_seqs = 1;
1059
+ const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
1151
1060
 
1152
- for (uint32_t i = 0; i < cell_count; ++i) {
1153
- llama_pos pos;
1154
- uint32_t n_seq_id;
1061
+ // reserve output buffer
1062
+ if (output_reserve(n_tokens) < n_tokens) {
1063
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
1064
+ return -2;
1065
+ };
1155
1066
 
1156
- read_to(&pos, sizeof(pos));
1157
- read_to(&n_seq_id, sizeof(n_seq_id));
1067
+ for (int32_t i = 0; i < n_tokens; ++i) {
1068
+ output_ids[i] = i;
1069
+ }
1158
1070
 
1159
- if (n_seq_id != 0) {
1160
- LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
1161
- return false;
1162
- }
1071
+ n_outputs = n_tokens;
1163
1072
 
1164
- batch.pos[i] = pos;
1165
- }
1166
- batch.n_seq_id[0] = 1;
1167
- batch.seq_id[0] = &dest_seq_id;
1168
- if (!llama_kv_cache_find_slot(kv_self, batch)) {
1169
- LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
1170
- return false;
1171
- }
1073
+ //batch_manager->prepare(ubatch);
1172
1074
 
1173
- // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
1174
- // Assume that this is one contiguous block of cells
1175
- LM_GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
1176
- LM_GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
1177
- LM_GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
1178
- LM_GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
1179
- LM_GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
1180
- } else {
1181
- // whole KV cache restore
1075
+ lm_ggml_backend_sched_reset(sched.get());
1076
+ lm_ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
1182
1077
 
1183
- if (cell_count > kv_self.size) {
1184
- LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
1185
- return false;
1186
- }
1078
+ const auto causal_attn_org = cparams.causal_attn;
1187
1079
 
1188
- llama_kv_cache_clear(kv_self);
1080
+ // always use non-causal attention for encoder graphs
1081
+ // TODO: this is a tmp solution until we have a proper way to support enc-dec models
1082
+ // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
1083
+ cparams.causal_attn = false;
1189
1084
 
1190
- for (uint32_t i = 0; i < cell_count; ++i) {
1191
- llama_kv_cell & cell = kv_self.cells[i];
1085
+ auto * gf = graph_init();
1086
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
1192
1087
 
1193
- llama_pos pos;
1194
- uint32_t n_seq_id;
1088
+ lm_ggml_backend_sched_alloc_graph(sched.get(), gf);
1195
1089
 
1196
- read_to(&pos, sizeof(pos));
1197
- read_to(&n_seq_id, sizeof(n_seq_id));
1090
+ res->set_inputs(&ubatch);
1198
1091
 
1199
- cell.pos = pos;
1092
+ cparams.causal_attn = causal_attn_org;
1200
1093
 
1201
- for (uint32_t j = 0; j < n_seq_id; ++j) {
1202
- llama_seq_id seq_id;
1203
- read_to(&seq_id, sizeof(seq_id));
1094
+ const auto compute_status = graph_compute(gf, n_tokens > 1);
1095
+ switch (compute_status) {
1096
+ case LM_GGML_STATUS_SUCCESS:
1097
+ break;
1098
+ case LM_GGML_STATUS_ABORTED:
1099
+ return 2;
1100
+ case LM_GGML_STATUS_ALLOC_FAILED:
1101
+ return -2;
1102
+ case LM_GGML_STATUS_FAILED:
1103
+ default:
1104
+ return -3;
1105
+ }
1204
1106
 
1205
- if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
1206
- LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
1207
- return false;
1208
- }
1107
+ auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
1108
+
1109
+ // extract embeddings
1110
+ if (t_embd) {
1111
+ lm_ggml_backend_t backend_embd = lm_ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
1112
+ LM_GGML_ASSERT(backend_embd != nullptr);
1113
+
1114
+ LM_GGML_ASSERT(embd != nullptr);
1209
1115
 
1210
- cell.seq_id.insert(seq_id);
1116
+ switch (cparams.pooling_type) {
1117
+ case LLAMA_POOLING_TYPE_NONE:
1118
+ {
1119
+ // extract token embeddings
1120
+ LM_GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
1121
+ lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
1122
+ } break;
1123
+ case LLAMA_POOLING_TYPE_MEAN:
1124
+ case LLAMA_POOLING_TYPE_CLS:
1125
+ case LLAMA_POOLING_TYPE_LAST:
1126
+ {
1127
+ // extract sequence embeddings
1128
+ auto & embd_seq_out = embd_seq;
1129
+ embd_seq_out.clear();
1211
1130
 
1212
- if (kv_self.recurrent) {
1213
- int32_t & tail = kv_self.cells[seq_id].tail;
1214
- if (tail != -1) {
1215
- LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
1216
- return false;
1131
+ LM_GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
1132
+
1133
+ for (int32_t i = 0; i < n_tokens; i++) {
1134
+ const llama_seq_id seq_id = ubatch.seq_id[i][0];
1135
+ if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
1136
+ continue;
1217
1137
  }
1218
- tail = i;
1138
+ embd_seq_out[seq_id].resize(n_embd);
1139
+ lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
1219
1140
  }
1141
+ } break;
1142
+ case LLAMA_POOLING_TYPE_RANK:
1143
+ {
1144
+ // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
1145
+ // wait for an encoder model that requires this pooling type in order to test it
1146
+ // https://github.com/ggerganov/llama.cpp/pull/9510
1147
+ LM_GGML_ABORT("RANK pooling not implemented yet");
1148
+ }
1149
+ case LLAMA_POOLING_TYPE_UNSPECIFIED:
1150
+ {
1151
+ LM_GGML_ABORT("unknown pooling type");
1220
1152
  }
1221
- }
1222
-
1223
- kv_self.head = 0;
1224
- kv_self.used = cell_count;
1225
1153
  }
1154
+ }
1155
+
1156
+ // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
1157
+ // overlap with device computation.
1158
+ lm_ggml_backend_sched_reset(sched.get());
1159
+
1160
+ // TODO: hacky solution
1161
+ if (model.arch == LLM_ARCH_T5 && t_embd) {
1162
+ //cross.t_embd = t_embd;
1163
+
1164
+ synchronize();
1226
1165
 
1227
- if (kv_self.recurrent) {
1228
- for (uint32_t i = 0; i < cell_count; ++i) {
1229
- uint32_t cell_id = kv_self.head + i;
1230
- // make sure the recurrent states will keep their restored state
1231
- kv_self.cells[cell_id].src = cell_id;
1166
+ cross.n_embd = t_embd->ne[0];
1167
+ cross.n_enc = t_embd->ne[1];
1168
+ cross.v_embd.resize(cross.n_embd*cross.n_enc);
1169
+ memcpy(cross.v_embd.data(), embd, lm_ggml_nbytes(t_embd));
1170
+
1171
+ // remember the sequence ids used during the encoding - needed for cross attention later
1172
+ cross.seq_ids_enc.resize(n_tokens);
1173
+ for (int32_t i = 0; i < n_tokens; i++) {
1174
+ cross.seq_ids_enc[i].clear();
1175
+ for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
1176
+ llama_seq_id seq_id = ubatch.seq_id[i][s];
1177
+ cross.seq_ids_enc[i].insert(seq_id);
1232
1178
  }
1233
1179
  }
1180
+ }
1234
1181
 
1235
- return true;
1182
+ return 0;
1183
+ }
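The encoder path above is what backs the public `llama_encode` call for encoder-decoder models such as T5: it temporarily forces non-causal attention, builds an `LLM_GRAPH_TYPE_ENCODER` graph, and stores the encoder output and sequence ids in `cross` for later cross-attention. A minimal caller-side sketch, assuming a context created from such a model and token arrays produced by the usual tokenizer setup (the helper name is hypothetical; signatures follow the llama.h shipped with this package):

```cpp
#include "llama.h"
#include <cstdio>
#include <vector>

// Sketch (not the library's own example): run the encoder pass and check the
// status codes used by llama_context::encode() above.
int run_encoder(llama_context * ctx, std::vector<llama_token> & prompt_tokens) {
    llama_batch batch = llama_batch_get_one(prompt_tokens.data(), (int32_t) prompt_tokens.size());

    // status codes mirror encode():
    //   0 = success, 2 = aborted by the abort callback,
    //  -2 = graph allocation failed, -3 = backend compute failed
    const int32_t status = llama_encode(ctx, batch);
    if (status != 0) {
        fprintf(stderr, "llama_encode failed: %d\n", status);
    }
    return status;
}
```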
1184
+
1185
+ int llama_context::decode(llama_batch & inp_batch) {
1186
+ if (inp_batch.n_tokens == 0) {
1187
+ LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
1188
+ return -1;
1236
1189
  }
1237
1190
 
1238
- bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) {
1239
- const struct llama_hparams & hparams = ctx->model.hparams;
1240
- struct llama_kv_cache & kv_self = ctx->kv_self;
1241
- uint32_t v_trans;
1242
- uint32_t n_layer;
1243
- read_to(&v_trans, sizeof(v_trans));
1244
- read_to(&n_layer, sizeof(n_layer));
1191
+ // temporary allocate memory for the input batch if needed
1192
+ // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
1193
+ llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
1245
1194
 
1246
- if (n_layer != hparams.n_layer) {
1247
- LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
1248
- return false;
1249
- }
1250
- if (cell_count > kv_self.size) {
1251
- LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
1252
- return false;
1253
- }
1254
- if (kv_self.v_trans != (bool) v_trans) {
1255
- LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
1256
- return false;
1257
- }
1195
+ const llama_batch & batch = batch_allocr.batch;
1258
1196
 
1259
- // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
1260
- for (uint32_t il = 0; il < n_layer; ++il) {
1261
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
1197
+ const auto & vocab = model.vocab;
1198
+ const auto & hparams = model.hparams;
1262
1199
 
1263
- // Read type of key
1264
- int32_t k_type_i_ref;
1265
- read_to(&k_type_i_ref, sizeof(k_type_i_ref));
1266
- const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
1267
- if (k_type_i != k_type_i_ref) {
1268
- LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
1269
- return false;
1270
- }
1200
+ const int32_t n_vocab = vocab.n_tokens();
1271
1201
 
1272
- // Read row size of key
1273
- uint64_t k_size_row_ref;
1274
- read_to(&k_size_row_ref, sizeof(k_size_row_ref));
1275
- const size_t k_size_row = lm_ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
1276
- if (k_size_row != k_size_row_ref) {
1277
- LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
1278
- return false;
1279
- }
1202
+ const int64_t n_tokens_all = batch.n_tokens;
1203
+ const int64_t n_embd = hparams.n_embd;
1204
+
1205
+ llama_kv_cache_guard kv_guard(kv_self.get());
1280
1206
 
1281
- if (cell_count) {
1282
- // Read and set the keys for the whole cell range
1283
- lm_ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
1207
+ LM_GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
1208
+
1209
+ if (batch.token) {
1210
+ for (int64_t i = 0; i < n_tokens_all; ++i) {
1211
+ if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
1212
+ LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
1213
+ throw std::runtime_error("invalid token");
1284
1214
  }
1285
1215
  }
1216
+ }
1286
1217
 
1287
- if (!kv_self.v_trans) {
1288
- for (uint32_t il = 0; il < n_layer; ++il) {
1289
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1218
+ LM_GGML_ASSERT(n_tokens_all <= cparams.n_batch);
1290
1219
 
1291
- // Read type of value
1292
- int32_t v_type_i_ref;
1293
- read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1294
- const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
1295
- if (v_type_i != v_type_i_ref) {
1296
- LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1297
- return false;
1298
- }
1220
+ LM_GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
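The assert above encodes a hard requirement: with non-causal attention (the typical embedding/reranking setup) the whole submitted batch must fit in one micro-batch, i.e. `n_ubatch >= n_tokens`. A hedged sketch of context parameters that satisfy it for prompts up to 2048 tokens (the sizes are illustrative, the field names are those of `llama_context_params` shown later in this file):

```cpp
#include "llama.h"

// Sketch: context parameters for an embedding-style (non-causal) workload.
// With causal attention disabled, decode() requires n_ubatch to cover the
// full number of tokens passed per llama_decode call.
llama_context_params make_embedding_cparams() {
    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx      = 2048;
    cparams.n_batch    = 2048;  // logical batch size
    cparams.n_ubatch   = 2048;  // physical batch size; must cover the full prompt
    cparams.embeddings = true;  // request embedding output
    return cparams;
}
```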
1299
1221
 
1300
- // Read row size of value
1301
- uint64_t v_size_row_ref;
1302
- read_to(&v_size_row_ref, sizeof(v_size_row_ref));
1303
- const size_t v_size_row = lm_ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
1304
- if (v_size_row != v_size_row_ref) {
1305
- LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
1306
- return false;
1307
- }
1222
+ if (t_compute_start_us == 0) {
1223
+ t_compute_start_us = lm_ggml_time_us();
1224
+ }
1225
+ n_queued_tokens += n_tokens_all;
1308
1226
 
1309
- if (cell_count) {
1310
- // Read and set the values for the whole cell range
1311
- lm_ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
1312
- }
1313
- }
1314
- } else {
1315
- // For each layer, read the values for each cell (transposed)
1316
- for (uint32_t il = 0; il < n_layer; ++il) {
1317
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
1318
-
1319
- // Read type of value
1320
- int32_t v_type_i_ref;
1321
- read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1322
- const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
1323
- if (v_type_i != v_type_i_ref) {
1324
- LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1325
- return false;
1326
- }
1227
+ // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
1228
+ const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
1327
1229
 
1328
- // Read element size of value
1329
- uint32_t v_size_el_ref;
1330
- read_to(&v_size_el_ref, sizeof(v_size_el_ref));
1331
- const size_t v_size_el = lm_ggml_type_size(kv_self.v_l[il]->type);
1332
- if (v_size_el != v_size_el_ref) {
1333
- LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
1334
- return false;
1335
- }
1230
+ embd_seq.clear();
1336
1231
 
1337
- // Read GQA embedding size
1338
- uint32_t n_embd_v_gqa_ref;
1339
- read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
1340
- if (n_embd_v_gqa != n_embd_v_gqa_ref) {
1341
- LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
1342
- return false;
1343
- }
1232
+ int64_t n_outputs_all = 0;
1344
1233
 
1345
- if (cell_count) {
1346
- // For each row in the transposed matrix, read the values for the whole cell range
1347
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1348
- const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;
1349
- lm_ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
1350
- }
1351
- }
1352
- }
1234
+ // count outputs
1235
+ if (batch.logits && !embd_pooled) {
1236
+ for (uint32_t i = 0; i < n_tokens_all; ++i) {
1237
+ n_outputs_all += batch.logits[i] != 0;
1353
1238
  }
1354
- return true;
1239
+ } else if (logits_all || embd_pooled) {
1240
+ n_outputs_all = n_tokens_all;
1241
+ } else {
1242
+ // keep last output only
1243
+ n_outputs_all = 1;
1355
1244
  }
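`n_outputs_all` is driven by the per-token `batch.logits` flags; pooled embeddings or `logits_all` force every token to produce output, and with no flags only the last token's output is kept. A hedged sketch of building a single-sequence batch that requests logits only for the final token (field names are assumed to match the `llama_batch` struct in this release's llama.h):

```cpp
#include "llama.h"
#include <vector>

// Sketch: prompt batch where only the last token requests an output, matching
// the "keep last output only" default counted above.
// The caller must later release it with llama_batch_free().
llama_batch make_prompt_batch(const std::vector<llama_token> & tokens) {
    llama_batch batch = llama_batch_init((int32_t) tokens.size(), /*embd*/ 0, /*n_seq_max*/ 1);
    batch.n_tokens = (int32_t) tokens.size();
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        batch.token[i]     = tokens[i];
        batch.pos[i]       = i;
        batch.n_seq_id[i]  = 1;
        batch.seq_id[i][0] = 0;
        batch.logits[i]    = (i == batch.n_tokens - 1); // output flag
    }
    return batch;
}
```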
1356
1245
 
1357
- void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
1358
- uint32_t cell_count;
1359
- read_to(&cell_count, sizeof(cell_count));
1246
+ const bool logits_all = n_outputs_all == n_tokens_all;
1247
+
1248
+ sbatch.from_batch(batch, n_embd,
1249
+ /* simple_split */ !kv_self->recurrent,
1250
+ /* logits_all */ logits_all);
1251
+
1252
+ // reserve output buffer
1253
+ if (output_reserve(n_outputs_all) < n_outputs_all) {
1254
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
1255
+ return -2;
1256
+ };
1360
1257
 
1361
- bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);
1258
+ // handle any pending defrags/shifts
1259
+ kv_self_update();
1362
1260
 
1363
- if (!res) {
1364
- if (seq_id == -1) {
1365
- llama_kv_cache_clear(ctx);
1261
+ int64_t n_outputs_prev = 0;
1262
+
1263
+ while (sbatch.n_tokens > 0) {
1264
+ llama_ubatch ubatch = llama_ubatch();
1265
+
1266
+ const auto & n_ubatch = cparams.n_ubatch;
1267
+
1268
+ if (kv_self->recurrent) {
1269
+ if (embd_pooled) {
1270
+ // Pooled embeddings cannot be split across ubatches (yet)
1271
+ ubatch = sbatch.split_seq(cparams.n_ubatch);
1366
1272
  } else {
1367
- llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
1273
+ // recurrent model architectures are easier to implement
1274
+ // with equal-length sequences
1275
+ ubatch = sbatch.split_equal(cparams.n_ubatch);
1368
1276
  }
1369
- throw std::runtime_error("failed to restore kv cache");
1277
+ } else {
1278
+ ubatch = sbatch.split_simple(n_ubatch);
1370
1279
  }
1371
- }
1372
- };
1373
1280
 
1374
- struct llama_data_write_dummy : llama_data_write {
1375
- size_t size_written = 0;
1281
+ // count the outputs in this u_batch
1282
+ {
1283
+ int32_t n_outputs_new = 0;
1376
1284
 
1377
- llama_data_write_dummy() {}
1285
+ if (n_outputs_all == n_tokens_all) {
1286
+ n_outputs_new = ubatch.n_tokens;
1287
+ } else {
1288
+ LM_GGML_ASSERT(ubatch.output);
1289
+ for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
1290
+ n_outputs_new += (int32_t) (ubatch.output[i] != 0);
1291
+ }
1292
+ }
1378
1293
 
1379
- void write(const void * /* src */, size_t size) override {
1380
- size_written += size;
1381
- }
1294
+ // needs to happen before the graph is built
1295
+ n_outputs = n_outputs_new;
1296
+ }
1382
1297
 
1383
- void write_tensor_data(const struct lm_ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
1384
- size_written += size;
1385
- }
1298
+ // find KV slot
1299
+ {
1300
+ if (!kv_self->find_slot(ubatch)) {
1301
+ LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
1386
1302
 
1387
- size_t get_size_written() override {
1388
- return size_written;
1389
- }
1390
- };
1303
+ return 1;
1304
+ }
1391
1305
 
1392
- struct llama_data_write_buffer : llama_data_write {
1393
- uint8_t * ptr;
1394
- size_t buf_size = 0;
1395
- size_t size_written = 0;
1306
+ if (!kv_self->recurrent) {
1307
+ // a heuristic, to avoid attending the full cache if it is not yet utilized
1308
+ // after enough generations, the benefit from this heuristic disappears
1309
+ // if we start defragmenting the cache, the benefit from this will be more important
1310
+ const uint32_t pad = kv_self->get_padding(cparams);
1311
+ kv_self->n = std::min(kv_self->size, std::max(pad, LM_GGML_PAD(kv_self->cell_max(), pad)));
1312
+ }
1313
+ }
1396
1314
 
1397
- llama_data_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
1315
+ //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
1398
1316
 
1399
- void write(const void * src, size_t size) override {
1400
- if (size > buf_size) {
1401
- throw std::runtime_error("unexpectedly reached end of buffer");
1317
+ lm_ggml_backend_sched_reset(sched.get());
1318
+ lm_ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
1319
+
1320
+ auto * gf = graph_init();
1321
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER);
1322
+
1323
+ // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (lm_ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
1324
+
1325
+ lm_ggml_backend_sched_alloc_graph(sched.get(), gf);
1326
+
1327
+ res->set_inputs(&ubatch);
1328
+
1329
+ const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
1330
+ if (compute_status != LM_GGML_STATUS_SUCCESS) {
1331
+ switch (compute_status) {
1332
+ case LM_GGML_STATUS_ABORTED:
1333
+ return 2;
1334
+ case LM_GGML_STATUS_ALLOC_FAILED:
1335
+ return -2;
1336
+ case LM_GGML_STATUS_FAILED:
1337
+ default:
1338
+ return -3;
1339
+ }
1340
+ }
1341
+
1342
+ // plot the computation graph in dot format (for debugging purposes)
1343
+ //if (n_past%100 == 0) {
1344
+ // lm_ggml_graph_dump_dot(gf, NULL, "llama.dot");
1345
+ //}
1346
+
1347
+ auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
1348
+ auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
1349
+
1350
+ if (t_embd && res->get_embd_pooled()) {
1351
+ t_embd = res->get_embd_pooled();
1352
+ }
1353
+
1354
+ // extract logits
1355
+ if (t_logits && n_outputs > 0) {
1356
+ lm_ggml_backend_t backend_res = lm_ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
1357
+ LM_GGML_ASSERT(backend_res != nullptr);
1358
+ LM_GGML_ASSERT(logits != nullptr);
1359
+
1360
+ float * logits_out = logits + n_outputs_prev*n_vocab;
1361
+
1362
+ if (n_outputs) {
1363
+ LM_GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
1364
+ LM_GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
1365
+ lm_ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
1366
+ }
1367
+ }
1368
+
1369
+ // extract embeddings
1370
+ if (t_embd && n_outputs > 0) {
1371
+ lm_ggml_backend_t backend_embd = lm_ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
1372
+ LM_GGML_ASSERT(backend_embd != nullptr);
1373
+
1374
+ switch (cparams.pooling_type) {
1375
+ case LLAMA_POOLING_TYPE_NONE:
1376
+ {
1377
+ // extract token embeddings
1378
+ LM_GGML_ASSERT(embd != nullptr);
1379
+ float * embd_out = embd + n_outputs_prev*n_embd;
1380
+
1381
+ if (n_outputs) {
1382
+ LM_GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
1383
+ LM_GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
1384
+ lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
1385
+ }
1386
+ } break;
1387
+ case LLAMA_POOLING_TYPE_MEAN:
1388
+ case LLAMA_POOLING_TYPE_CLS:
1389
+ case LLAMA_POOLING_TYPE_LAST:
1390
+ {
1391
+ // extract sequence embeddings (cleared before processing each batch)
1392
+ auto & embd_seq_out = embd_seq;
1393
+
1394
+ for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
1395
+ const llama_seq_id seq_id = ubatch.seq_id[s][0];
1396
+ if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
1397
+ continue;
1398
+ }
1399
+ embd_seq_out[seq_id].resize(n_embd);
1400
+ lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
1401
+ }
1402
+ } break;
1403
+ case LLAMA_POOLING_TYPE_RANK:
1404
+ {
1405
+ // extract the rerank score - a single float per sequence
1406
+ auto & embd_seq_out = embd_seq;
1407
+
1408
+ for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
1409
+ const llama_seq_id seq_id = ubatch.seq_id[s][0];
1410
+ if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
1411
+ continue;
1412
+ }
1413
+ embd_seq_out[seq_id].resize(1);
1414
+ lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
1415
+ }
1416
+ } break;
1417
+ case LLAMA_POOLING_TYPE_UNSPECIFIED:
1418
+ {
1419
+ LM_GGML_ABORT("unknown pooling type");
1420
+ }
1421
+ }
1422
+ }
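For MEAN/CLS/LAST pooling the per-sequence result lands in `embd_seq`, keyed by sequence id, and RANK pooling stores a single reranking score per sequence. A hedged sketch of reading these back through the public API after a successful decode on a context created with embeddings enabled (`n_embd` is passed in only for the print-out):

```cpp
#include "llama.h"
#include <cstdio>

// Sketch: fetch the pooled embedding (or rerank score) for sequence 0.
void print_pooled_embedding(llama_context * ctx, int n_embd) {
    const float * e = llama_get_embeddings_seq(ctx, /*seq_id*/ 0);
    if (e == nullptr) {
        fprintf(stderr, "no pooled embedding for seq 0 (pooling type NONE?)\n");
        return;
    }
    if (llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_RANK) {
        printf("rerank score: %f\n", e[0]); // RANK stores a single float per sequence
    } else {
        printf("first embedding component: %f (dim = %d)\n", e[0], n_embd);
    }
}
```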
1423
+
1424
+ n_outputs_prev += n_outputs;
1425
+ }
1426
+
1427
+ // finalize the batch processing
1428
+ kv_guard.commit();
1429
+
1430
+ // set output mappings
1431
+ {
1432
+ bool sorted_output = true;
1433
+
1434
+ LM_GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
1435
+
1436
+ for (int64_t i = 0; i < n_outputs_all; ++i) {
1437
+ int64_t out_id = sbatch.out_ids[i];
1438
+ output_ids[out_id] = i;
1439
+ if (out_id != i) {
1440
+ sorted_output = false;
1441
+ }
1442
+ }
1443
+
1444
+ if (sorted_output) {
1445
+ sbatch.out_ids.clear();
1446
+ }
1447
+ }
1448
+
1449
+ // set to total number of outputs in the batch, for use in llama_get_logits_ith
1450
+ n_outputs = n_outputs_all;
1451
+
1452
+ // wait for the computation to finish (automatically done when obtaining the model output)
1453
+ //synchronize();
1454
+
1455
+ // decide if we need to defrag the kv cache
1456
+ if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
1457
+ // - do not defrag small contexts (i.e. < 2048 tokens)
1458
+ // - count the padding towards the number of used tokens
1459
+ const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
1460
+
1461
+ // queue defragmentation for next llama_kv_cache_update
1462
+ if (fragmentation > cparams.defrag_thold) {
1463
+ LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
1464
+
1465
+ kv_self->defrag();
1466
+ }
1467
+ }
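The fragmentation estimate is `1 - (used + padding) / n`, evaluated only when the active cache window is at least 2048 cells, and a defrag is queued once it exceeds `defrag_thold`. A small self-contained check of that formula (the sample values are made up):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Standalone check of the fragmentation heuristic used above.
static float kv_fragmentation(uint32_t n, uint32_t used, uint32_t pad) {
    return n >= 2048 ? std::max(0.0f, 1.0f - float(used + pad)/float(n)) : 0.0f;
}

int main() {
    // a 4096-cell window with 1024 used cells and padding 32:
    // 1 - 1056/4096 ~= 0.742, above a defrag_thold of e.g. 0.1, so defrag is queued
    printf("%.3f\n", kv_fragmentation(4096, 1024, 32));
    // small windows are never defragmented
    printf("%.3f\n", kv_fragmentation(1024, 10, 32));
    return 0;
}
```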
1468
+
1469
+ // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
1470
+ // overlap with device computation.
1471
+ lm_ggml_backend_sched_reset(sched.get());
1472
+
1473
+ return 0;
1474
+ }
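Callers see these status codes through `llama_decode`: 0 on success, 1 when no KV-cache slot was found (the batch can be retried after shrinking it or freeing cache space), 2 when the abort callback fired, and -1/-2/-3 for invalid input, allocation failure, and compute failure. A hedged handling sketch:

```cpp
#include "llama.h"
#include <cstdio>

// Sketch: interpret llama_decode() results using the codes returned by decode() above.
bool decode_checked(llama_context * ctx, llama_batch batch) {
    switch (llama_decode(ctx, batch)) {
        case 0:  return true;                                                              // success
        case 1:  fprintf(stderr, "no KV slot - try a smaller batch or clear the cache\n"); return false;
        case 2:  fprintf(stderr, "aborted via abort callback\n");                          return false;
        case -1: fprintf(stderr, "invalid input batch\n");                                 return false;
        case -2: fprintf(stderr, "allocation failed\n");                                   return false;
        default: fprintf(stderr, "fatal compute error\n");                                 return false;
    }
}
```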
1475
+
1476
+ //
1477
+ // output
1478
+ //
1479
+
1480
+ int32_t llama_context::output_reserve(int32_t n_outputs) {
1481
+ const auto & hparams = model.hparams;
1482
+ const auto & vocab = model.vocab;
1483
+
1484
+ const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());
1485
+
1486
+ const auto n_batch = cparams.n_batch;
1487
+ const auto n_vocab = vocab.n_tokens();
1488
+ const auto n_embd = hparams.n_embd;
1489
+
1490
+ // TODO: use a per-batch flag for logits presence instead
1491
+ bool has_logits = !cparams.embeddings;
1492
+ bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
1493
+
1494
+ // TODO: hacky enc-dec support
1495
+ if (model.arch == LLM_ARCH_T5) {
1496
+ has_logits = true;
1497
+ has_embd = true;
1498
+ }
1499
+
1500
+ logits_size = has_logits ? n_vocab*n_outputs_max : 0;
1501
+ embd_size = has_embd ? n_embd*n_outputs_max : 0;
1502
+
1503
+ if (output_ids.empty()) {
1504
+ // init, never resized afterwards
1505
+ output_ids.resize(n_batch);
1506
+ }
1507
+
1508
+ const size_t prev_size = buf_output ? lm_ggml_backend_buffer_get_size(buf_output.get()) : 0;
1509
+ const size_t new_size = (logits_size + embd_size) * sizeof(float);
1510
+
1511
+ // alloc only when more than the current capacity is required
1512
+ // TODO: also consider shrinking the buffer
1513
+ if (!buf_output || prev_size < new_size) {
1514
+ if (buf_output) {
1515
+ #ifndef NDEBUG
1516
+ // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
1517
+ LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
1518
+ #endif
1519
+ buf_output = nullptr;
1520
+ logits = nullptr;
1521
+ embd = nullptr;
1522
+ }
1523
+
1524
+ auto * buft = lm_ggml_backend_cpu_buffer_type();
1525
+ // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
1526
+ auto * output_dev = model.dev_output();
1527
+ auto * output_dev_host_buft = output_dev ? lm_ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
1528
+ if (output_dev_host_buft) {
1529
+ buft = output_dev_host_buft;
1530
+ }
1531
+ buf_output.reset(lm_ggml_backend_buft_alloc_buffer(buft, new_size));
1532
+ if (buf_output == nullptr) {
1533
+ LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
1534
+ return 0;
1535
+ }
1536
+ }
1537
+
1538
+ float * output_base = (float *) lm_ggml_backend_buffer_get_base(buf_output.get());
1539
+
1540
+ logits = has_logits ? output_base : nullptr;
1541
+ embd = has_embd ? output_base + logits_size : nullptr;
1542
+
1543
+ // set all ids as invalid (negative)
1544
+ std::fill(output_ids.begin(), output_ids.end(), -1);
1545
+
1546
+ lm_ggml_backend_buffer_clear(buf_output.get(), 0);
1547
+
1548
+ this->n_outputs = 0;
1549
+ this->n_outputs_max = n_outputs_max;
1550
+
1551
+ return n_outputs_max;
1552
+ }
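`output_reserve()` lays out a single host buffer with all logits rows first and all embedding rows after them; `logits` and `embd` are just offsets into that buffer, sized by `n_outputs_max = max(n_outputs, n_seq_max)`. A standalone sketch of the same size arithmetic (the numbers are illustrative):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    // illustrative values: 32000-token vocab, 4096-dim embeddings,
    // 8 requested outputs, n_seq_max = 1
    const int64_t n_vocab = 32000, n_embd = 4096, n_outputs = 8, n_seq_max = 1;
    const int64_t n_outputs_max = std::max(n_outputs, n_seq_max);

    const bool has_logits = true;   // !cparams.embeddings
    const bool has_embd   = false;  // embeddings with LLAMA_POOLING_TYPE_NONE

    const size_t logits_size = has_logits ? n_vocab * n_outputs_max : 0;
    const size_t embd_size   = has_embd   ? n_embd  * n_outputs_max : 0;

    // one buffer: [ logits (logits_size floats) | embeddings (embd_size floats) ]
    printf("output buffer: %.2f MiB\n", (logits_size + embd_size) * sizeof(float) / (1024.0 * 1024.0));
    return 0;
}
```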
1553
+
1554
+ void llama_context::output_reorder() {
1555
+ auto & out_ids = sbatch.out_ids;
1556
+ if (!out_ids.empty()) {
1557
+ const uint32_t n_vocab = model.vocab.n_tokens();
1558
+ const uint32_t n_embd = model.hparams.n_embd;
1559
+
1560
+ LM_GGML_ASSERT((size_t) n_outputs == out_ids.size());
1561
+
1562
+ // TODO: is there something more efficient which also minimizes swaps?
1563
+ // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
1564
+ for (int32_t i = 0; i < n_outputs - 1; ++i) {
1565
+ int32_t j_min = i;
1566
+ for (int32_t j = i + 1; j < n_outputs; ++j) {
1567
+ if (out_ids[j] < out_ids[j_min]) {
1568
+ j_min = j;
1569
+ }
1570
+ }
1571
+ if (j_min == i) { continue; }
1572
+ std::swap(out_ids[i], out_ids[j_min]);
1573
+ if (logits_size > 0) {
1574
+ for (uint32_t k = 0; k < n_vocab; k++) {
1575
+ std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
1576
+ }
1577
+ }
1578
+ if (embd_size > 0) {
1579
+ for (uint32_t k = 0; k < n_embd; k++) {
1580
+ std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
1581
+ }
1582
+ }
1583
+ }
1584
+ std::fill(output_ids.begin(), output_ids.end(), -1);
1585
+ for (int32_t i = 0; i < n_outputs; ++i) {
1586
+ output_ids[out_ids[i]] = i;
1587
+ }
1588
+ out_ids.clear();
1589
+ }
1590
+ }
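`output_reorder()` restores batch order with a selection sort because each swap moves an entire `n_vocab`- or `n_embd`-sized row, so minimizing swaps matters more than the number of comparisons. A toy standalone version of the same row reordering:

```cpp
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    // 3 output rows of 2 "logits" each, produced in order {2, 0, 1}
    std::vector<int>   out_ids = {2, 0, 1};
    std::vector<float> logits  = {2.0f, 2.1f,  0.0f, 0.1f,  1.0f, 1.1f};
    const int n_vocab = 2, n_outputs = 3;

    // selection sort: at most n_outputs - 1 row swaps
    for (int i = 0; i < n_outputs - 1; ++i) {
        int j_min = i;
        for (int j = i + 1; j < n_outputs; ++j) {
            if (out_ids[j] < out_ids[j_min]) j_min = j;
        }
        if (j_min == i) continue;
        std::swap(out_ids[i], out_ids[j_min]);
        for (int k = 0; k < n_vocab; ++k) {
            std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
        }
    }

    for (float v : logits) printf("%.1f ", v); // 0.0 0.1 1.0 1.1 2.0 2.1
    printf("\n");
    return 0;
}
```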
1591
+
1592
+ //
1593
+ // graph
1594
+ //
1595
+
1596
+ int32_t llama_context::graph_max_nodes() const {
1597
+ return std::max<int32_t>(65536, 5*model.n_tensors());
1598
+ }
1599
+
1600
+ lm_ggml_cgraph * llama_context::graph_init() {
1601
+ lm_ggml_init_params params = {
1602
+ /*.mem_size =*/ buf_compute_meta.size(),
1603
+ /*.mem_buffer =*/ buf_compute_meta.data(),
1604
+ /*.no_alloc =*/ true,
1605
+ };
1606
+
1607
+ ctx_compute.reset(lm_ggml_init(params));
1608
+
1609
+ return lm_ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
1610
+ }
1611
+
1612
+ llm_graph_result_ptr llama_context::graph_build(
1613
+ lm_ggml_context * ctx,
1614
+ lm_ggml_cgraph * gf,
1615
+ const llama_ubatch & ubatch,
1616
+ llm_graph_type gtype) {
1617
+ return model.build_graph(
1618
+ {
1619
+ /*.ctx =*/ ctx,
1620
+ /*.arch =*/ model.arch,
1621
+ /*.hparams =*/ model.hparams,
1622
+ /*.cparams =*/ cparams,
1623
+ /*.ubatch =*/ ubatch,
1624
+ /*.sched =*/ sched.get(),
1625
+ /*.backend_cpu =*/ backend_cpu,
1626
+ /*.cvec =*/ &cvec,
1627
+ /*.loras =*/ &loras,
1628
+ /*.memory =*/ kv_self.get(),
1629
+ /*.cross =*/ &cross,
1630
+ /*.n_outputs =*/ n_outputs,
1631
+ /*.cb =*/ graph_get_cb(),
1632
+ }, gf, gtype);
1633
+ }
1634
+
1635
+ lm_ggml_status llama_context::graph_compute(
1636
+ lm_ggml_cgraph * gf,
1637
+ bool batched) {
1638
+ int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads;
1639
+ lm_ggml_threadpool_t tp = batched ? threadpool_batch : threadpool;
1640
+
1641
+ if (backend_cpu != nullptr) {
1642
+ auto * reg = lm_ggml_backend_dev_backend_reg(lm_ggml_backend_get_device(backend_cpu));
1643
+ auto * set_threadpool_fn = (decltype(lm_ggml_backend_cpu_set_threadpool) *) lm_ggml_backend_reg_get_proc_address(reg, "lm_ggml_backend_cpu_set_threadpool");
1644
+ set_threadpool_fn(backend_cpu, tp);
1645
+ }
1646
+
1647
+ // set the number of threads for all the backends
1648
+ for (const auto & set_n_threads_fn : set_n_threads_fns) {
1649
+ set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
1650
+ }
1651
+
1652
+ auto status = lm_ggml_backend_sched_graph_compute_async(sched.get(), gf);
1653
+ if (status != LM_GGML_STATUS_SUCCESS) {
1654
+ LLAMA_LOG_ERROR("%s: lm_ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
1655
+ }
1656
+
1657
+ // fprintf(stderr, "splits: %d\n", lm_ggml_backend_sched_get_n_splits(sched));
1658
+
1659
+ return status;
1660
+ }
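`graph_compute()` uses `n_threads_batch` for multi-token (prompt) work and `n_threads` for single-token generation, then pushes the chosen count to every backend that exposes a set-threads hook. Both values are controlled through the public setter; a hedged sketch with illustrative numbers:

```cpp
#include "llama.h"
#include <algorithm>
#include <thread>

// Sketch: more threads for prompt processing, fewer for token-by-token generation.
// The split is workload- and hardware-dependent.
void configure_threads(llama_context * ctx) {
    const int32_t hw = (int32_t) std::thread::hardware_concurrency();
    const int32_t n_threads       = std::max(1, hw / 2); // per-token decode
    const int32_t n_threads_batch = std::max(1, hw);     // prompt / batched decode
    llama_set_n_threads(ctx, n_threads, n_threads_batch);
}
```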
1661
+
1662
+ llm_graph_cb llama_context::graph_get_cb() const {
1663
+ return [&](const llama_ubatch & ubatch, lm_ggml_tensor * cur, const char * name, int il) {
1664
+ if (il >= 0) {
1665
+ lm_ggml_format_name(cur, "%s-%d", name, il);
1666
+ } else {
1667
+ lm_ggml_set_name(cur, name);
1668
+ }
1669
+
1670
+ if (!cparams.offload_kqv) {
1671
+ if (strcmp(name, "kqv_merged_cont") == 0) {
1672
+ // all nodes between the KV store and the attention output are run on the CPU
1673
+ lm_ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
1674
+ }
1675
+ }
1676
+
1677
+ // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
1678
+ // FIXME: fix in lm_ggml_backend_sched
1679
+ const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer;
1680
+ if (ubatch.n_tokens < 32 || full_offload) {
1681
+ if (il != -1 && strcmp(name, "norm") == 0) {
1682
+ const auto & dev_layer = model.dev_layer(il);
1683
+ for (const auto & backend : backends) {
1684
+ if (lm_ggml_backend_get_device(backend.get()) == dev_layer) {
1685
+ if (lm_ggml_backend_supports_op(backend.get(), cur)) {
1686
+ lm_ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get());
1687
+ }
1688
+ }
1689
+ }
1690
+ }
1691
+ }
1692
+ };
1693
+ }
1694
+
1695
+ //
1696
+ // state save/load
1697
+ //
1698
+
1699
+ class llama_io_write_dummy : public llama_io_write_i {
1700
+ public:
1701
+ llama_io_write_dummy() = default;
1702
+
1703
+ void write(const void * /* src */, size_t size) override {
1704
+ size_written += size;
1705
+ }
1706
+
1707
+ void write_tensor(const lm_ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
1708
+ size_written += size;
1709
+ }
1710
+
1711
+ size_t n_bytes() override {
1712
+ return size_written;
1713
+ }
1714
+
1715
+ private:
1716
+ size_t size_written = 0;
1717
+ };
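`llama_io_write_dummy` only counts bytes, which is how `state_get_size()` computes the exact serialized size without writing anything. The usual two-pass pattern through the public C API (the same functions whose old wrappers are removed further down in this diff) looks roughly like this:

```cpp
#include "llama.h"
#include <vector>

// Sketch: size the state with a dry run, then serialize it into a buffer.
std::vector<uint8_t> snapshot_state(llama_context * ctx) {
    std::vector<uint8_t> buf(llama_state_get_size(ctx));            // dummy-writer pass
    const size_t written = llama_state_get_data(ctx, buf.data(), buf.size());
    buf.resize(written);                                            // equals buf.size() on success, 0 on error
    return buf;
}

// Later, restore into a compatible context:
//   llama_state_set_data(ctx, buf.data(), buf.size());
```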
1718
+
1719
+ class llama_io_write_buffer : public llama_io_write_i {
1720
+ public:
1721
+ llama_io_write_buffer(
1722
+ uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
1723
+
1724
+ void write(const void * src, size_t size) override {
1725
+ if (size > buf_size) {
1726
+ throw std::runtime_error("unexpectedly reached end of buffer");
1402
1727
  }
1403
1728
  memcpy(ptr, src, size);
1404
1729
  ptr += size;
@@ -1406,7 +1731,7 @@ struct llama_data_write_buffer : llama_data_write {
1406
1731
  buf_size -= size;
1407
1732
  }
1408
1733
 
1409
- void write_tensor_data(const struct lm_ggml_tensor * tensor, size_t offset, size_t size) override {
1734
+ void write_tensor(const lm_ggml_tensor * tensor, size_t offset, size_t size) override {
1410
1735
  if (size > buf_size) {
1411
1736
  throw std::runtime_error("unexpectedly reached end of buffer");
1412
1737
  }
@@ -1416,17 +1741,19 @@ struct llama_data_write_buffer : llama_data_write {
1416
1741
  buf_size -= size;
1417
1742
  }
1418
1743
 
1419
- size_t get_size_written() override {
1744
+ size_t n_bytes() override {
1420
1745
  return size_written;
1421
1746
  }
1422
- };
1423
1747
 
1424
- struct llama_data_read_buffer : llama_data_read {
1425
- const uint8_t * ptr;
1748
+ private:
1749
+ uint8_t * ptr;
1426
1750
  size_t buf_size = 0;
1427
- size_t size_read = 0;
1751
+ size_t size_written = 0;
1752
+ };
1428
1753
 
1429
- llama_data_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
1754
+ class llama_io_read_buffer : public llama_io_read_i {
1755
+ public:
1756
+ llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
1430
1757
 
1431
1758
  const uint8_t * read(size_t size) override {
1432
1759
  const uint8_t * base_ptr = ptr;
@@ -1443,40 +1770,44 @@ struct llama_data_read_buffer : llama_data_read {
1443
1770
  memcpy(dst, read(size), size);
1444
1771
  }
1445
1772
 
1446
- size_t get_size_read() override {
1773
+ size_t n_bytes() override {
1447
1774
  return size_read;
1448
1775
  }
1449
- };
1450
1776
 
1451
- struct llama_data_write_file : llama_data_write {
1452
- llama_file * file;
1453
- size_t size_written = 0;
1454
- std::vector<uint8_t> temp_buffer;
1777
+ private:
1778
+ const uint8_t * ptr;
1779
+ size_t buf_size = 0;
1780
+ size_t size_read = 0;
1781
+ };
1455
1782
 
1456
- llama_data_write_file(llama_file * f) : file(f) {}
1783
+ class llama_io_write_file : public llama_io_write_i {
1784
+ public:
1785
+ llama_io_write_file(llama_file * f) : file(f) {}
1457
1786
 
1458
1787
  void write(const void * src, size_t size) override {
1459
1788
  file->write_raw(src, size);
1460
1789
  size_written += size;
1461
1790
  }
1462
1791
 
1463
- void write_tensor_data(const struct lm_ggml_tensor * tensor, size_t offset, size_t size) override {
1792
+ void write_tensor(const lm_ggml_tensor * tensor, size_t offset, size_t size) override {
1464
1793
  temp_buffer.resize(size);
1465
1794
  lm_ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size);
1466
1795
  write(temp_buffer.data(), temp_buffer.size());
1467
1796
  }
1468
1797
 
1469
- size_t get_size_written() override {
1798
+ size_t n_bytes() override {
1470
1799
  return size_written;
1471
1800
  }
1472
- };
1473
1801
 
1474
- struct llama_data_read_file : llama_data_read {
1802
+ private:
1475
1803
  llama_file * file;
1476
- size_t size_read = 0;
1804
+ size_t size_written = 0;
1477
1805
  std::vector<uint8_t> temp_buffer;
1806
+ };
1478
1807
 
1479
- llama_data_read_file(llama_file * f) : file(f) {}
1808
+ class llama_io_read_file : public llama_io_read_i {
1809
+ public:
1810
+ llama_io_read_file(llama_file * f) : file(f) {}
1480
1811
 
1481
1812
  void read_to(void * dst, size_t size) override {
1482
1813
  file->read_raw(dst, size);
@@ -1489,89 +1820,78 @@ struct llama_data_read_file : llama_data_read {
1489
1820
  return temp_buffer.data();
1490
1821
  }
1491
1822
 
1492
- size_t get_size_read() override {
1823
+ size_t n_bytes() override {
1493
1824
  return size_read;
1494
1825
  }
1495
- };
1496
-
1497
- /** copy state data into either a buffer or file depending on the passed in context
1498
- *
1499
- * file context:
1500
- * llama_file file("/path", "wb");
1501
- * llama_data_write_file data_ctx(&file);
1502
- * llama_state_get_data_internal(ctx, data_ctx);
1503
- *
1504
- * buffer context:
1505
- * std::vector<uint8_t> buf(max_size, 0);
1506
- * llama_data_write_buffer data_ctx(buf.data(), max_size);
1507
- * llama_state_get_data_internal(ctx, data_ctx);
1508
- *
1509
- */
1510
- static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx) {
1511
- llama_synchronize(ctx);
1512
-
1513
- data_ctx.write_model_info(ctx);
1514
-
1515
- // copy outputs
1516
- data_ctx.write_output_ids(ctx);
1517
- data_ctx.write_logits(ctx);
1518
- data_ctx.write_embeddings(ctx);
1519
1826
 
1520
- data_ctx.write_kv_cache(ctx);
1827
+ private:
1828
+ llama_file * file;
1829
+ size_t size_read = 0;
1830
+ std::vector<uint8_t> temp_buffer;
1831
+ };
1521
1832
 
1522
- return data_ctx.get_size_written();
1833
+ size_t llama_context::state_get_size() {
1834
+ llama_io_write_dummy io;
1835
+ try {
1836
+ return state_write_data(io);
1837
+ } catch (const std::exception & err) {
1838
+ LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
1839
+ return 0;
1840
+ }
1523
1841
  }
1524
1842
 
1525
- size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) {
1526
- llama_data_write_buffer data_ctx(dst, size);
1843
+ size_t llama_context::state_get_data(uint8_t * dst, size_t size) {
1844
+ llama_io_write_buffer io(dst, size);
1527
1845
  try {
1528
- return llama_state_get_data_internal(ctx, data_ctx);
1846
+ return state_write_data(io);
1529
1847
  } catch (const std::exception & err) {
1530
1848
  LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
1531
1849
  return 0;
1532
1850
  }
1533
1851
  }
1534
1852
 
1535
- // Returns the *actual* size of the state.
1536
- // Intended to be used when saving to state to a buffer.
1537
- size_t llama_state_get_size(struct llama_context * ctx) {
1538
- llama_data_write_dummy data_ctx;
1853
+ size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
1854
+ llama_io_read_buffer io(src, size);
1539
1855
  try {
1540
- return llama_state_get_data_internal(ctx, data_ctx);
1856
+ return state_read_data(io);
1541
1857
  } catch (const std::exception & err) {
1542
- LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
1858
+ LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
1543
1859
  return 0;
1544
1860
  }
1545
1861
  }
1546
1862
 
1547
- static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx) {
1548
- llama_synchronize(ctx);
1549
-
1550
- data_ctx.read_model_info(ctx);
1551
-
1552
- // set outputs
1553
- data_ctx.read_output_ids(ctx);
1554
- data_ctx.read_logits(ctx);
1555
- data_ctx.read_embeddings(ctx);
1556
-
1557
- data_ctx.read_kv_cache(ctx);
1863
+ size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
1864
+ llama_io_write_dummy io;
1865
+ try {
1866
+ return state_seq_write_data(io, seq_id);
1867
+ } catch (const std::exception & err) {
1868
+ LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
1869
+ return 0;
1870
+ }
1871
+ }
1558
1872
 
1559
- return data_ctx.get_size_read();
1873
+ size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
1874
+ llama_io_write_buffer io(dst, size);
1875
+ try {
1876
+ return state_seq_write_data(io, seq_id);
1877
+ } catch (const std::exception & err) {
1878
+ LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
1879
+ return 0;
1880
+ }
1560
1881
  }
1561
1882
 
1562
- // Sets the state reading from the specified source address
1563
- size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) {
1564
- llama_data_read_buffer data_ctx(src, size);
1883
+ size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
1884
+ llama_io_read_buffer io(src, size);
1565
1885
  try {
1566
- return llama_state_set_data_internal(ctx, data_ctx);
1886
+ return state_seq_read_data(io, seq_id);
1567
1887
  } catch (const std::exception & err) {
1568
1888
  LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
1569
1889
  return 0;
1570
1890
  }
1571
1891
  }
1572
1892
 
1573
- static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
1574
- llama_file file(path_session, "rb");
1893
+ bool llama_context::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
1894
+ llama_file file(filepath, "rb");
1575
1895
 
1576
1896
  // sanity checks
1577
1897
  {
@@ -1601,28 +1921,20 @@ static bool llama_state_load_file_internal(struct llama_context * ctx, const cha
1601
1921
  {
1602
1922
  const size_t n_state_size_cur = file.size() - file.tell();
1603
1923
 
1604
- llama_data_read_file data_ctx(&file);
1605
- const size_t n_read = llama_state_set_data_internal(ctx, data_ctx);
1924
+ llama_io_read_file io(&file);
1925
+ const size_t n_read = state_read_data(io);
1606
1926
 
1607
1927
  if (n_read != n_state_size_cur) {
1608
1928
  LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
1609
1929
  return false;
1610
1930
  }
1611
1931
  }
1612
- return true;
1613
- }
1614
1932
 
1615
- bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
1616
- try {
1617
- return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
1618
- } catch (const std::exception & err) {
1619
- LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
1620
- return false;
1621
- }
1933
+ return true;
1622
1934
  }
1623
1935
 
1624
- static bool llama_state_save_file_internal(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
1625
- llama_file file(path_session, "wb");
1936
+ bool llama_context::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) {
1937
+ llama_file file(filepath, "wb");
1626
1938
 
1627
1939
  file.write_u32(LLAMA_SESSION_MAGIC);
1628
1940
  file.write_u32(LLAMA_SESSION_VERSION);
@@ -1632,63 +1944,56 @@ static bool llama_state_save_file_internal(struct llama_context * ctx, const cha
1632
1944
  file.write_raw(tokens, sizeof(llama_token) * n_token_count);
1633
1945
 
1634
1946
  // save the context state using stream saving
1635
- llama_data_write_file data_ctx(&file);
1636
- llama_state_get_data_internal(ctx, data_ctx);
1947
+ llama_io_write_file io(&file);
1948
+ state_write_data(io);
1637
1949
 
1638
1950
  return true;
1639
1951
  }
1640
1952
 
1641
- bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
1642
- try {
1643
- return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count);
1644
- } catch (const std::exception & err) {
1645
- LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
1646
- return false;
1647
- }
1648
- }
1649
-
1650
- static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
1651
- llama_synchronize(ctx);
1953
+ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
1954
+ llama_file file(filepath, "rb");
1652
1955
 
1653
- data_ctx.write_kv_cache(ctx, seq_id);
1956
+ // version checks
1957
+ {
1958
+ const uint32_t magic = file.read_u32();
1959
+ const uint32_t version = file.read_u32();
1654
1960
 
1655
- return data_ctx.get_size_written();
1656
- }
1961
+ if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
1962
+ LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
1963
+ return 0;
1964
+ }
1965
+ }
1657
1966
 
1658
- size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) {
1659
- llama_data_write_dummy data_ctx;
1660
- return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
1661
- }
1662
-
1663
- size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
1664
- llama_data_write_buffer data_ctx(dst, size);
1665
- try {
1666
- return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
1667
- } catch (const std::exception & err) {
1668
- LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what());
1669
- return 0;
1670
- }
1671
- }
1672
-
1673
- static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
1674
- llama_synchronize(ctx);
1967
+ // load the prompt
1968
+ {
1969
+ const uint32_t n_token_count = file.read_u32();
1675
1970
 
1676
- data_ctx.read_kv_cache(ctx, dest_seq_id);
1971
+ if (n_token_count > n_token_capacity) {
1972
+ LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
1973
+ return 0;
1974
+ }
1677
1975
 
1678
- return data_ctx.get_size_read();
1679
- }
1976
+ file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
1977
+ *n_token_count_out = n_token_count;
1978
+ }
1680
1979
 
1681
- size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) {
1682
- llama_data_read_buffer data_ctx(src, size);
1683
- try {
1684
- return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
1685
- } catch (const std::exception & err) {
1686
- LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what());
1687
- return 0;
1980
+ // restore the context state
1981
+ {
1982
+ const size_t state_size = file.size() - file.tell();
1983
+ llama_io_read_file io(&file);
1984
+ const size_t nread = state_seq_read_data(io, seq_id);
1985
+ if (!nread) {
1986
+ LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
1987
+ return 0;
1988
+ }
1989
+ LM_GGML_ASSERT(nread <= state_size);
1990
+ LM_GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
1688
1991
  }
1992
+
1993
+ return file.tell();
1689
1994
  }
1690
1995
 
1691
- static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
1996
+ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) {
1692
1997
  llama_file file(filepath, "wb");
1693
1998
 
1694
1999
  file.write_u32(LLAMA_STATE_SEQ_MAGIC);
@@ -1699,77 +2004,828 @@ static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, con
1699
2004
  file.write_raw(tokens, sizeof(llama_token) * n_token_count);
1700
2005
 
1701
2006
  // save the context state using stream saving
1702
- llama_data_write_file data_ctx(&file);
1703
- llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
2007
+ llama_io_write_file io(&file);
2008
+ state_seq_write_data(io, seq_id);
1704
2009
 
1705
2010
  const size_t res = file.tell();
1706
- LM_GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written());
2011
+ LM_GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
2012
+
1707
2013
  return res;
1708
2014
  }
1709
2015
 
1710
- static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
1711
- llama_file file(filepath, "rb");
2016
+ size_t llama_context::state_write_data(llama_io_write_i & io) {
2017
+ LLAMA_LOG_DEBUG("%s: writing state\n", __func__);
1712
2018
 
1713
- // version checks
2019
+ // write model info
1714
2020
  {
1715
- const uint32_t magic = file.read_u32();
1716
- const uint32_t version = file.read_u32();
2021
+ LLAMA_LOG_DEBUG("%s: - writing model info\n", __func__);
1717
2022
 
1718
- if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
1719
- LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
1720
- return 0;
2023
+ const std::string arch_str = llm_arch_name(model.arch);
2024
+ io.write_string(arch_str);
2025
+ // TODO: add more model-specific info which should prevent loading the session file if not identical
2026
+ }
2027
+
2028
+ // write output ids
2029
+ {
2030
+ LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
2031
+
2032
+ output_reorder();
2033
+
2034
+ const auto n_outputs = this->n_outputs;
2035
+ const auto & output_ids = this->output_ids;
2036
+
2037
+ std::vector<int32_t> w_output_pos;
2038
+
2039
+ LM_GGML_ASSERT(n_outputs <= n_outputs_max);
2040
+
2041
+ w_output_pos.resize(n_outputs);
2042
+
2043
+ // build a more compact representation of the output ids
2044
+ for (size_t i = 0; i < n_batch(); ++i) {
2045
+ // map an output id to a position in the batch
2046
+ int32_t pos = output_ids[i];
2047
+ if (pos >= 0) {
2048
+ LM_GGML_ASSERT(pos < n_outputs);
2049
+ w_output_pos[pos] = i;
2050
+ }
2051
+ }
2052
+
2053
+ io.write(&n_outputs, sizeof(n_outputs));
2054
+
2055
+ if (n_outputs) {
2056
+ io.write(w_output_pos.data(), n_outputs * sizeof(int32_t));
1721
2057
  }
1722
2058
  }
1723
2059
 
1724
- // load the prompt
2060
+ // write logits
1725
2061
  {
1726
- const uint32_t n_token_count = file.read_u32();
2062
+ LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
1727
2063
 
1728
- if (n_token_count > n_token_capacity) {
1729
- LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
1730
- return 0;
2064
+ const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());
2065
+
2066
+ io.write(&logits_size, sizeof(logits_size));
2067
+
2068
+ if (logits_size) {
2069
+ io.write(logits, logits_size * sizeof(float));
1731
2070
  }
2071
+ }
1732
2072
 
1733
- file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
1734
- *n_token_count_out = n_token_count;
2073
+ // write embeddings
2074
+ {
2075
+ LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
2076
+
2077
+ const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd);
2078
+
2079
+ io.write(&embd_size, sizeof(embd_size));
2080
+
2081
+ if (embd_size) {
2082
+ io.write(embd, embd_size * sizeof(float));
2083
+ }
1735
2084
  }
1736
2085
 
1737
- // restore the context state
2086
+ LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
2087
+ kv_self->state_write(io);
2088
+
2089
+ return io.n_bytes();
2090
+ }
2091
+
2092
+ size_t llama_context::state_read_data(llama_io_read_i & io) {
2093
+ LLAMA_LOG_DEBUG("%s: reading state\n", __func__);
2094
+
2095
+ // read model info
1738
2096
  {
1739
- const size_t state_size = file.size() - file.tell();
1740
- llama_data_read_file data_ctx(&file);
1741
- const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
1742
- if (!nread) {
1743
- LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
1744
- return 0;
2097
+ LLAMA_LOG_DEBUG("%s: - reading model info\n", __func__);
2098
+
2099
+ const std::string cur_arch_str = llm_arch_name(model.arch);
2100
+
2101
+ std::string arch_str;
2102
+ io.read_string(arch_str);
2103
+ if (cur_arch_str != arch_str) {
2104
+ throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
1745
2105
  }
1746
- LM_GGML_ASSERT(nread <= state_size);
1747
- LM_GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
2106
+ // TODO: add more info which needs to be identical but which is not verified otherwise
1748
2107
  }
1749
2108
 
1750
- return file.tell();
2109
+ // read output ids
2110
+ {
2111
+ LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__);
2112
+
2113
+ auto n_outputs = this->n_outputs;
2114
+ io.read_to(&n_outputs, sizeof(n_outputs));
2115
+
2116
+ if (n_outputs > output_reserve(n_outputs)) {
2117
+ throw std::runtime_error("could not reserve outputs");
2118
+ }
2119
+
2120
+ std::vector<int32_t> output_pos;
2121
+
2122
+ if (n_outputs) {
2123
+ output_pos.resize(n_outputs);
2124
+ io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
2125
+
2126
+ for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
2127
+ int32_t id = output_pos[i];
2128
+ if ((uint32_t) id >= n_batch()) {
2129
+ throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch()));
2130
+ }
2131
+ this->output_ids[id] = i;
2132
+ }
2133
+
2134
+ this->n_outputs = n_outputs;
2135
+ }
2136
+ }
2137
+
2138
+ // read logits
2139
+ {
2140
+ LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__);
2141
+
2142
+ uint64_t logits_size;
2143
+ io.read_to(&logits_size, sizeof(logits_size));
2144
+
2145
+ if (this->logits_size < logits_size) {
2146
+ throw std::runtime_error("logits buffer too small");
2147
+ }
2148
+
2149
+ if (logits_size) {
2150
+ io.read_to(this->logits, logits_size * sizeof(float));
2151
+ }
2152
+ }
2153
+
2154
+ // read embeddings
2155
+ {
2156
+ LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__);
2157
+
2158
+ uint64_t embd_size;
2159
+ io.read_to(&embd_size, sizeof(embd_size));
2160
+
2161
+ if (this->embd_size < embd_size) {
2162
+ throw std::runtime_error("embeddings buffer too small");
2163
+ }
2164
+
2165
+ if (embd_size) {
2166
+ io.read_to(this->embd, embd_size * sizeof(float));
2167
+ }
2168
+ }
2169
+
2170
+ LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
2171
+ kv_self->state_read(io);
2172
+
2173
+ return io.n_bytes();
2174
+ }
2175
+
2176
+ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
2177
+ LM_GGML_UNUSED(seq_id);
2178
+
2179
+ kv_self->state_write(io, seq_id);
2180
+
2181
+ return io.n_bytes();
2182
+ }
2183
+
2184
+ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
2185
+ LM_GGML_UNUSED(seq_id);
2186
+
2187
+ kv_self->state_read(io, seq_id);
2188
+
2189
+ return io.n_bytes();
2190
+ }
2191
+
2192
+ //
2193
+ // perf
2194
+ //
2195
+
2196
+ llama_perf_context_data llama_context::perf_get_data() const {
2197
+ llama_perf_context_data data = {};
2198
+
2199
+ data.t_start_ms = 1e-3 * t_start_us;
2200
+ data.t_load_ms = 1e-3 * t_load_us;
2201
+ data.t_p_eval_ms = 1e-3 * t_p_eval_us;
2202
+ data.t_eval_ms = 1e-3 * t_eval_us;
2203
+ data.n_p_eval = std::max(1, n_p_eval);
2204
+ data.n_eval = std::max(1, n_eval);
2205
+
2206
+ return data;
1751
2207
  }
1752
2208
 
1753
- size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
2209
+ void llama_context::perf_reset() {
2210
+ t_start_us = lm_ggml_time_us();
2211
+ t_eval_us = n_eval = 0;
2212
+ t_p_eval_us = n_p_eval = 0;
2213
+ }
2214
+
2215
+ //
2216
+ // interface implementation
2217
+ //
2218
+
2219
+ llama_context_params llama_context_default_params() {
2220
+ llama_context_params result = {
2221
+ /*.n_ctx =*/ 512,
2222
+ /*.n_batch =*/ 2048,
2223
+ /*.n_ubatch =*/ 512,
2224
+ /*.n_seq_max =*/ 1,
2225
+ /*.n_threads =*/ LM_GGML_DEFAULT_N_THREADS, // TODO: better default
2226
+ /*.n_threads_batch =*/ LM_GGML_DEFAULT_N_THREADS,
2227
+ /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
2228
+ /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
2229
+ /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
2230
+ /*.rope_freq_base =*/ 0.0f,
2231
+ /*.rope_freq_scale =*/ 0.0f,
2232
+ /*.yarn_ext_factor =*/ -1.0f,
2233
+ /*.yarn_attn_factor =*/ 1.0f,
2234
+ /*.yarn_beta_fast =*/ 32.0f,
2235
+ /*.yarn_beta_slow =*/ 1.0f,
2236
+ /*.yarn_orig_ctx =*/ 0,
2237
+ /*.defrag_thold =*/ -1.0f,
2238
+ /*.cb_eval =*/ nullptr,
2239
+ /*.cb_eval_user_data =*/ nullptr,
2240
+ /*.type_k =*/ LM_GGML_TYPE_F16,
2241
+ /*.type_v =*/ LM_GGML_TYPE_F16,
2242
+ /*.logits_all =*/ false,
2243
+ /*.embeddings =*/ false,
2244
+ /*.offload_kqv =*/ true,
2245
+ /*.flash_attn =*/ false,
2246
+ /*.no_perf =*/ true,
2247
+ /*.abort_callback =*/ nullptr,
2248
+ /*.abort_callback_data =*/ nullptr,
2249
+ };
2250
+
2251
+ return result;
2252
+ }
2253
+
2254
+ llama_context * llama_init_from_model(
2255
+ llama_model * model,
2256
+ llama_context_params params) {
2257
+ if (!model) {
2258
+ LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__);
2259
+ return nullptr;
2260
+ }
2261
+
2262
+ if (params.n_batch == 0 && params.n_ubatch == 0) {
2263
+ LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__);
2264
+ return nullptr;
2265
+ }
2266
+
2267
+ if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) {
2268
+ LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__);
2269
+ return nullptr;
2270
+ }
2271
+
2272
+ if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
2273
+ LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
2274
+ params.flash_attn = false;
2275
+ }
2276
+
2277
+ if (lm_ggml_is_quantized(params.type_v) && !params.flash_attn) {
2278
+ LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
2279
+ return nullptr;
2280
+ }
2281
+
1754
2282
  try {
1755
- return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count);
2283
+ auto * ctx = new llama_context(*model, params);
2284
+ return ctx;
2285
+ } catch (const std::exception & err) {
2286
+ LLAMA_LOG_ERROR("%s: failed to initialize the context: %s\n", __func__, err.what());
2287
+ }
2288
+
2289
+ return nullptr;
2290
+ }
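`llama_init_from_model` rejects a handful of inconsistent parameter combinations up front: both `n_batch` and `n_ubatch` zero, both `n_ctx` and the model's training context zero, and a quantized V cache without flash attention (flash attention is also force-disabled for Grok). A hedged sketch of a parameter set that passes these checks when using a quantized KV cache (type choices are illustrative; availability depends on the backend):

```cpp
#include "llama.h"

// Sketch: a quantized V cache requires flash_attn, per the check above.
llama_context_params make_quantized_kv_params() {
    llama_context_params p = llama_context_default_params();
    p.n_ctx      = 4096;
    p.flash_attn = true;              // required for quantized type_v
    p.type_k     = LM_GGML_TYPE_Q8_0;
    p.type_v     = LM_GGML_TYPE_Q8_0;
    return p;
}
```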
2291
+
2292
+ // deprecated
2293
+ llama_context * llama_new_context_with_model(
2294
+ llama_model * model,
2295
+ llama_context_params params) {
2296
+ return llama_init_from_model(model, params);
2297
+ }
2298
+
2299
+ void llama_free(llama_context * ctx) {
2300
+ delete ctx;
2301
+ }
2302
+
2303
+ uint32_t llama_n_ctx(const llama_context * ctx) {
2304
+ return ctx->n_ctx();
2305
+ }
2306
+
2307
+ uint32_t llama_n_batch(const llama_context * ctx) {
2308
+ return ctx->n_batch();
2309
+ }
2310
+
2311
+ uint32_t llama_n_ubatch(const llama_context * ctx) {
2312
+ return ctx->n_ubatch();
2313
+ }
2314
+
2315
+ uint32_t llama_n_seq_max(const llama_context * ctx) {
2316
+ return ctx->n_seq_max();
2317
+ }
2318
+
2319
+ const llama_model * llama_get_model(const llama_context * ctx) {
2320
+ return &ctx->get_model();
2321
+ }
2322
+
2323
+ llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
2324
+ return ctx->get_kv_self();
2325
+ }
2326
+
2327
+ void llama_kv_self_update(llama_context * ctx) {
2328
+ ctx->kv_self_update();
2329
+ }
2330
+
2331
+ enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
2332
+ return ctx->pooling_type();
2333
+ }
2334
+
2335
+ void llama_attach_threadpool(
2336
+ llama_context * ctx,
2337
+ lm_ggml_threadpool_t threadpool,
2338
+ lm_ggml_threadpool_t threadpool_batch) {
2339
+ ctx->attach_threadpool(threadpool, threadpool_batch);
2340
+ }
2341
+
2342
+ void llama_detach_threadpool(llama_context * ctx) {
2343
+ ctx->detach_threadpool();
2344
+ }
2345
+
2346
+ void llama_set_n_threads(llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) {
2347
+ ctx->set_n_threads(n_threads, n_threads_batch);
2348
+ }
2349
+
2350
+ int32_t llama_n_threads(llama_context * ctx) {
2351
+ return ctx->n_threads();
2352
+ }
2353
+
2354
+ int32_t llama_n_threads_batch(llama_context * ctx) {
2355
+ return ctx->n_threads_batch();
2356
+ }
2357
+
2358
+ void llama_set_abort_callback(llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
2359
+ ctx->set_abort_callback(abort_callback, abort_callback_data);
2360
+ }
2361
+
2362
+ void llama_set_embeddings(llama_context * ctx, bool embeddings) {
2363
+ ctx->set_embeddings(embeddings);
2364
+ }
2365
+
2366
+ void llama_set_causal_attn(llama_context * ctx, bool causal_attn) {
2367
+ ctx->set_causal_attn(causal_attn);
2368
+ }
2369
+
2370
+ void llama_set_warmup(llama_context * ctx, bool warmup) {
2371
+ ctx->set_warmup(warmup);
2372
+ }
2373
+
2374
+ void llama_synchronize(llama_context * ctx) {
2375
+ ctx->synchronize();
2376
+ }
2377
+
2378
+ float * llama_get_logits(llama_context * ctx) {
2379
+ ctx->synchronize();
2380
+
2381
+ return ctx->get_logits();
2382
+ }
2383
+
2384
+ float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
2385
+ ctx->synchronize();
2386
+
2387
+ return ctx->get_logits_ith(i);
2388
+ }
2389
+
2390
+ float * llama_get_embeddings(llama_context * ctx) {
2391
+ ctx->synchronize();
2392
+
2393
+ return ctx->get_embeddings();
2394
+ }
2395
+
2396
+ float * llama_get_embeddings_ith(llama_context * ctx, int32_t i) {
2397
+ ctx->synchronize();
2398
+
2399
+ return ctx->get_embeddings_ith(i);
2400
+ }
2401
+
2402
+ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
2403
+ ctx->synchronize();
2404
+
2405
+ return ctx->get_embeddings_seq(seq_id);
2406
+ }
2407
+
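A sketch of reading outputs after a successful decode. llama_model_get_vocab, llama_vocab_n_tokens and llama_model_n_embd are companion helpers assumed from llama.h; the getters themselves (and the fact that each one synchronizes the context first) come from this hunk.

    #include "llama.h"
    #include <cstdio>

    // Sketch: inspect logits and embeddings after llama_decode. Each getter calls
    // ctx->synchronize() (see above), so pending backend work finishes before the
    // buffers are read.
    static void inspect_outputs(llama_context * ctx, const llama_model * model, int32_t last_output) {
        const llama_vocab * vocab = llama_model_get_vocab(model);   // assumed helper from llama.h
        const int n_vocab = llama_vocab_n_tokens(vocab);            // assumed helper from llama.h
        const int n_embd  = llama_model_n_embd(model);              // assumed helper from llama.h

        // logits of the i-th output token (i indexes outputs, not batch positions)
        const float * logits = llama_get_logits_ith(ctx, last_output);
        if (logits) {
            printf("logit[0] of output %d: %f (n_vocab = %d)\n", last_output, logits[0], n_vocab);
        }

        // with llama_set_embeddings(ctx, true) and a pooling type other than NONE,
        // per-sequence embeddings are available instead
        const float * emb = llama_get_embeddings_seq(ctx, /*seq_id*/ 0);
        if (emb) {
            printf("embedding dim: %d, first value: %f\n", n_embd, emb[0]);
        }
    }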
2408
+ // llama adapter API
2409
+
2410
+ int32_t llama_set_adapter_lora(
2411
+ llama_context * ctx,
2412
+ llama_adapter_lora * adapter,
2413
+ float scale) {
2414
+ ctx->set_adapter_lora(adapter, scale);
2415
+
2416
+ return 0;
2417
+ }
2418
+
2419
+ int32_t llama_rm_adapter_lora(
2420
+ llama_context * ctx,
2421
+ llama_adapter_lora * adapter) {
2422
+ bool res = ctx->rm_adapter_lora(adapter);
2423
+
2424
+ return res ? 0 : -1;
2425
+ }
2426
+
2427
+ void llama_clear_adapter_lora(llama_context * ctx) {
2428
+ ctx->clear_adapter_lora();
2429
+ }
2430
+
2431
+ int32_t llama_apply_adapter_cvec(
2432
+ llama_context * ctx,
2433
+ const float * data,
2434
+ size_t len,
2435
+ int32_t n_embd,
2436
+ int32_t il_start,
2437
+ int32_t il_end) {
2438
+ bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end);
2439
+
2440
+ return res ? 0 : -1;
2441
+ }
2442
+
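A sketch of the adapter API in use. llama_adapter_lora_init and llama_adapter_lora_free live in the adapter module, not in this hunk, and the adapter path is a placeholder.

    #include "llama.h"

    // Sketch: attach a LoRA adapter to a context, rescale it, then remove it.
    static bool with_lora(llama_model * model, llama_context * ctx) {
        // assumed companion call from llama-adapter; path is illustrative
        llama_adapter_lora * adapter = llama_adapter_lora_init(model, "adapter.gguf");
        if (adapter == nullptr) {
            return false;
        }

        llama_set_adapter_lora(ctx, adapter, /*scale*/ 1.0f);   // add, or update the scale
        // ... run some decodes ...
        llama_set_adapter_lora(ctx, adapter, 0.5f);             // same adapter, new scale
        // ... run more decodes ...

        if (llama_rm_adapter_lora(ctx, adapter) != 0) {
            // -1 means the adapter was not attached to this context
        }
        llama_clear_adapter_lora(ctx);                          // or drop everything at once

        // control vectors follow the same 0 / -1 convention:
        // llama_apply_adapter_cvec(ctx, data, len, n_embd, il_start, il_end);

        llama_adapter_lora_free(adapter);                       // assumed companion call
        return true;
    }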
2443
+ //
2444
+ // kv cache view
2445
+ //
2446
+
2447
+ llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
2448
+ const auto * kv = ctx->get_kv_self();
2449
+ if (kv == nullptr) {
2450
+ LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
2451
+ return {};
2452
+ }
2453
+
2454
+ return llama_kv_cache_view_init(*kv, n_seq_max);
2455
+ }
2456
+
2457
+ void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
2458
+ const auto * kv = ctx->get_kv_self();
2459
+ if (kv == nullptr) {
2460
+ LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
2461
+ return;
2462
+ }
2463
+
2464
+ llama_kv_cache_view_update(view, kv);
2465
+ }
2466
+
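A debug-only sketch of the KV cache view. The struct field names (n_cells, token_count, used_cells, max_contiguous) and llama_kv_cache_view_free are taken from llama.h and are assumptions here; only init and update are defined in this hunk.

    #include "llama.h"
    #include <cstdio>

    // Sketch: snapshot KV cache occupancy for debugging.
    static void dump_kv_view(const llama_context * ctx) {
        llama_kv_cache_view view = llama_kv_cache_view_init(ctx, /*n_seq_max*/ 1);

        llama_kv_cache_view_update(ctx, &view);

        // field names per llama.h's llama_kv_cache_view (assumed)
        printf("cells: %d, tokens: %d, used: %d, max contiguous free run: %d\n",
               view.n_cells, view.token_count, view.used_cells, view.max_contiguous);

        llama_kv_cache_view_free(&view);   // assumed companion call from llama.h
    }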
2467
+ //
2468
+ // kv cache
2469
+ //
2470
+
2471
+ // deprecated
2472
+ int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
2473
+ return llama_kv_self_n_tokens(ctx);
2474
+ }
2475
+
2476
+ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
2477
+ const auto * kv = ctx->get_kv_self();
2478
+ if (!kv) {
2479
+ return 0;
2480
+ }
2481
+
2482
+ return kv->get_n_tokens();
2483
+ }
2484
+
2485
+ // deprecated
2486
+ int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
2487
+ return llama_kv_self_used_cells(ctx);
2488
+ }
2489
+
2490
+ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
2491
+ const auto * kv = ctx->get_kv_self();
2492
+ if (!kv) {
2493
+ return 0;
2494
+ }
2495
+
2496
+ return kv->get_used_cells();
2497
+ }
2498
+
2499
+ // deprecated
2500
+ void llama_kv_cache_clear(llama_context * ctx) {
2501
+ llama_kv_self_clear(ctx);
2502
+ }
2503
+
2504
+ void llama_kv_self_clear(llama_context * ctx) {
2505
+ auto * kv = ctx->get_kv_self();
2506
+ if (!kv) {
2507
+ return;
2508
+ }
2509
+
2510
+ kv->clear();
2511
+ }
2512
+
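A short sketch of the occupancy counters and cache reset, using the non-deprecated llama_kv_self_* names (the llama_kv_cache_* variants above are kept only as wrappers).

    #include "llama.h"
    #include <cstdio>

    // Sketch: check how full the unified KV cache is and wipe it between
    // unrelated prompts.
    static void reset_between_requests(llama_context * ctx) {
        printf("kv cache: %d tokens in %d used cells\n",
               llama_kv_self_n_tokens(ctx),
               llama_kv_self_used_cells(ctx));

        // drop all sequences and positions; the next decode starts from an empty cache
        llama_kv_self_clear(ctx);
    }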
2513
+ // deprecated
2514
+ bool llama_kv_cache_seq_rm(
2515
+ llama_context * ctx,
2516
+ llama_seq_id seq_id,
2517
+ llama_pos p0,
2518
+ llama_pos p1) {
2519
+ return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
2520
+ }
2521
+
2522
+ bool llama_kv_self_seq_rm(
2523
+ llama_context * ctx,
2524
+ llama_seq_id seq_id,
2525
+ llama_pos p0,
2526
+ llama_pos p1) {
2527
+ auto * kv = ctx->get_kv_self();
2528
+ if (!kv) {
2529
+ return true;
2530
+ }
2531
+
2532
+ return kv->seq_rm(seq_id, p0, p1);
2533
+ }
2534
+
2535
+ // deprecated
2536
+ void llama_kv_cache_seq_cp(
2537
+ llama_context * ctx,
2538
+ llama_seq_id seq_id_src,
2539
+ llama_seq_id seq_id_dst,
2540
+ llama_pos p0,
2541
+ llama_pos p1) {
2542
+ return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
2543
+ }
2544
+
2545
+ void llama_kv_self_seq_cp(
2546
+ llama_context * ctx,
2547
+ llama_seq_id seq_id_src,
2548
+ llama_seq_id seq_id_dst,
2549
+ llama_pos p0,
2550
+ llama_pos p1) {
2551
+ auto * kv = ctx->get_kv_self();
2552
+ if (!kv) {
2553
+ return;
2554
+ }
2555
+
2556
+ return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
2557
+ }
2558
+
2559
+ // deprecated
2560
+ void llama_kv_cache_seq_keep(
2561
+ llama_context * ctx,
2562
+ llama_seq_id seq_id) {
2563
+ return llama_kv_self_seq_keep(ctx, seq_id);
2564
+ }
2565
+
2566
+ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
2567
+ auto * kv = ctx->get_kv_self();
2568
+ if (!kv) {
2569
+ return;
2570
+ }
2571
+
2572
+ return kv->seq_keep(seq_id);
2573
+ }
2574
+
2575
+ // deprecated
2576
+ void llama_kv_cache_seq_add(
2577
+ llama_context * ctx,
2578
+ llama_seq_id seq_id,
2579
+ llama_pos p0,
2580
+ llama_pos p1,
2581
+ llama_pos delta) {
2582
+ return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
2583
+ }
2584
+
2585
+ void llama_kv_self_seq_add(
2586
+ llama_context * ctx,
2587
+ llama_seq_id seq_id,
2588
+ llama_pos p0,
2589
+ llama_pos p1,
2590
+ llama_pos delta) {
2591
+ auto * kv = ctx->get_kv_self();
2592
+ if (!kv) {
2593
+ return;
2594
+ }
2595
+
2596
+ return kv->seq_add(seq_id, p0, p1, delta);
2597
+ }
2598
+
2599
+ // deprecated
2600
+ void llama_kv_cache_seq_div(
2601
+ llama_context * ctx,
2602
+ llama_seq_id seq_id,
2603
+ llama_pos p0,
2604
+ llama_pos p1,
2605
+ int d) {
2606
+ return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
2607
+ }
2608
+
2609
+ void llama_kv_self_seq_div(
2610
+ llama_context * ctx,
2611
+ llama_seq_id seq_id,
2612
+ llama_pos p0,
2613
+ llama_pos p1,
2614
+ int d) {
2615
+ auto * kv = ctx->get_kv_self();
2616
+ if (!kv) {
2617
+ return;
2618
+ }
2619
+
2620
+ return kv->seq_div(seq_id, p0, p1, d);
2621
+ }
2622
+
2623
+ // deprecated
2624
+ llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
2625
+ return llama_kv_self_seq_pos_max(ctx, seq_id);
2626
+ }
2627
+
2628
+ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
2629
+ const auto * kv = ctx->get_kv_self();
2630
+ if (!kv) {
2631
+ return 0;
2632
+ }
2633
+
2634
+ return kv->seq_pos_max(seq_id);
2635
+ }
2636
+
2637
+ // deprecated
2638
+ void llama_kv_cache_defrag(llama_context * ctx) {
2639
+ return llama_kv_self_defrag(ctx);
2640
+ }
2641
+
2642
+ void llama_kv_self_defrag(llama_context * ctx) {
2643
+ auto * kv = ctx->get_kv_self();
2644
+ if (!kv) {
2645
+ return;
2646
+ }
2647
+
2648
+ return kv->defrag();
2649
+ }
2650
+
2651
+ // deprecated
2652
+ bool llama_kv_cache_can_shift(const llama_context * ctx) {
2653
+ return llama_kv_self_can_shift(ctx);
2654
+ }
2655
+
2656
+ bool llama_kv_self_can_shift(const llama_context * ctx) {
2657
+ const auto * kv = ctx->get_kv_self();
2658
+ if (!kv) {
2659
+ return false;
2660
+ }
2661
+
2662
+ return kv->get_can_shift();
2663
+ }
2664
+
2665
+ // deprecated
2666
+ void llama_kv_cache_update(llama_context * ctx) {
2667
+ llama_kv_self_update(ctx);
2668
+ }
2669
+
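A sketch of the usual context-shift pattern built on the sequence primitives above: drop the oldest tokens of a sequence, then slide the remaining ones back so new tokens fit. The n_keep/n_past counters are assumed to be maintained by the caller.

    #include "llama.h"

    // Sketch: free room in a full context by discarding and shifting.
    static bool shift_context(llama_context * ctx, llama_seq_id seq, int n_keep, int n_past, int n_discard) {
        if (!llama_kv_self_can_shift(ctx)) {
            return false;                   // shifting is not supported for this cache
        }

        // remove [n_keep, n_keep + n_discard) ...
        llama_kv_self_seq_rm (ctx, seq, n_keep, n_keep + n_discard);
        // ... and shift [n_keep + n_discard, n_past) left by n_discard
        llama_kv_self_seq_add(ctx, seq, n_keep + n_discard, n_past, -n_discard);

        // the shift is applied lazily; llama_kv_self_update (or the next decode)
        // makes it effective, and llama_kv_self_seq_pos_max reflects the new layout
        llama_kv_self_update(ctx);

        return true;
    }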
2670
+ // llama state API
2671
+
2672
+ // deprecated
2673
+ size_t llama_get_state_size(llama_context * ctx) {
2674
+ return llama_state_get_size(ctx);
2675
+ }
2676
+
2677
+ // deprecated
2678
+ size_t llama_copy_state_data(llama_context * ctx, uint8_t * dst) {
2679
+ return llama_state_get_data(ctx, dst, -1);
2680
+ }
2681
+
2682
+ // deprecated
2683
+ size_t llama_set_state_data(llama_context * ctx, const uint8_t * src) {
2684
+ return llama_state_set_data(ctx, src, -1);
2685
+ }
2686
+
2687
+ // deprecated
2688
+ bool llama_load_session_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
2689
+ return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
2690
+ }
2691
+
2692
+ // deprecated
2693
+ bool llama_save_session_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
2694
+ return llama_state_save_file(ctx, path_session, tokens, n_token_count);
2695
+ }
2696
+
2697
+ // Returns the *actual* size of the state.
2698
+ // Intended to be used when saving to state to a buffer.
+ // Intended to be used when saving the state to a buffer.
2699
+ size_t llama_state_get_size(llama_context * ctx) {
2700
+ return ctx->state_get_size();
2701
+ }
2702
+
2703
+ size_t llama_state_get_data(llama_context * ctx, uint8_t * dst, size_t size) {
2704
+ ctx->synchronize();
2705
+
2706
+ return ctx->state_get_data(dst, size);
2707
+ }
2708
+
2709
+ // Sets the state reading from the specified source address
2710
+ size_t llama_state_set_data(llama_context * ctx, const uint8_t * src, size_t size) {
2711
+ ctx->synchronize();
2712
+
2713
+ return ctx->state_set_data(src, size);
2714
+ }
2715
+
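A sketch of snapshotting the full context state into memory and restoring it later, using only the three buffer-based calls above; treating a zero return from llama_state_set_data as failure is a caller-side convention here.

    #include "llama.h"
    #include <vector>

    // Sketch: copy the whole context state into a byte buffer.
    static std::vector<uint8_t> snapshot_state(llama_context * ctx) {
        std::vector<uint8_t> buf(llama_state_get_size(ctx));     // actual size, per the comment above
        const size_t written = llama_state_get_data(ctx, buf.data(), buf.size());
        buf.resize(written);                                      // bytes actually written
        return buf;
    }

    static bool restore_state(llama_context * ctx, const std::vector<uint8_t> & buf) {
        // returns the number of bytes read from the buffer
        return llama_state_set_data(ctx, buf.data(), buf.size()) != 0;
    }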
2716
+ bool llama_state_load_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
2717
+ ctx->synchronize();
2718
+
2719
+ try {
2720
+ return ctx->state_load_file(path_session, tokens_out, n_token_capacity, n_token_count_out);
2721
+ } catch (const std::exception & err) {
2722
+ LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
2723
+ return false;
2724
+ }
2725
+ }
2726
+
2727
+ bool llama_state_save_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
2728
+ ctx->synchronize();
2729
+
2730
+ try {
2731
+ return ctx->state_save_file(path_session, tokens, n_token_count);
2732
+ } catch (const std::exception & err) {
2733
+ LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
2734
+ return false;
2735
+ }
2736
+ }
2737
+
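A sketch of the file-based session calls above: the context state is persisted together with the prompt tokens it was built from, and both come back on load. The file name is a placeholder.

    #include "llama.h"
    #include <vector>

    // Sketch: save and restore a session file.
    static bool save_session(llama_context * ctx, const std::vector<llama_token> & tokens) {
        return llama_state_save_file(ctx, "session.bin", tokens.data(), tokens.size());
    }

    static bool load_session(llama_context * ctx, std::vector<llama_token> & tokens_out) {
        tokens_out.resize(llama_n_ctx(ctx));      // capacity for the stored prompt
        size_t n_loaded = 0;
        if (!llama_state_load_file(ctx, "session.bin", tokens_out.data(), tokens_out.size(), &n_loaded)) {
            return false;                         // errors are also logged by the wrapper above
        }
        tokens_out.resize(n_loaded);
        return true;
    }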
2738
+ size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
2739
+ return ctx->state_seq_get_size(seq_id);
2740
+ }
2741
+
2742
+ size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
2743
+ ctx->synchronize();
2744
+
2745
+ return ctx->state_seq_get_data(seq_id, dst, size);
2746
+ }
2747
+
2748
+ size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
2749
+ ctx->synchronize();
2750
+
2751
+ return ctx->state_seq_set_data(seq_id, src, size);
2752
+ }
2753
+
2754
+ size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
2755
+ ctx->synchronize();
2756
+
2757
+ try {
2758
+ return ctx->state_seq_save_file(seq_id, filepath, tokens, n_token_count);
1756
2759
  } catch (const std::exception & err) {
1757
2760
  LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what());
1758
2761
  return 0;
1759
2762
  }
1760
2763
  }
1761
2764
 
1762
- size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
2765
+ size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
2766
+ ctx->synchronize();
2767
+
1763
2768
  try {
1764
- return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out);
2769
+ return ctx->state_seq_load_file(dest_seq_id, filepath, tokens_out, n_token_capacity, n_token_count_out);
1765
2770
  } catch (const std::exception & err) {
1766
2771
  LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what());
1767
2772
  return 0;
1768
2773
  }
1769
2774
  }
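A sketch of the per-sequence variants just above: one sequence's cached state is written to disk and later re-loaded into a (possibly different) sequence id, e.g. to move a cached prompt between slots. The file name is a placeholder.

    #include "llama.h"
    #include <vector>

    // Sketch: save one sequence's state and reload it into dest_seq.
    static size_t save_seq(llama_context * ctx, llama_seq_id seq, const std::vector<llama_token> & tokens) {
        // returns the number of bytes written, or 0 on error (logged above)
        return llama_state_seq_save_file(ctx, "seq0.bin", seq, tokens.data(), tokens.size());
    }

    static size_t load_seq(llama_context * ctx, llama_seq_id dest_seq, std::vector<llama_token> & tokens_out) {
        tokens_out.resize(llama_n_ctx(ctx));
        size_t n_loaded = 0;
        const size_t n_read = llama_state_seq_load_file(
            ctx, "seq0.bin", dest_seq, tokens_out.data(), tokens_out.size(), &n_loaded);
        tokens_out.resize(n_read == 0 ? 0 : n_loaded);
        return n_read;
    }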
1770
2775
 
1771
- const std::vector<std::pair<std::string, struct lm_ggml_tensor *>> & llama_internal_get_tensor_map(
1772
- struct llama_context * ctx
1773
- ) {
1774
- return ctx->model.tensors_by_name;
2776
+ ///
2777
+
2778
+ int32_t llama_encode(
2779
+ llama_context * ctx,
2780
+ llama_batch batch) {
2781
+ const int ret = ctx->encode(batch);
2782
+ if (ret != 0) {
2783
+ LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
2784
+ }
2785
+
2786
+ return ret;
2787
+ }
2788
+
2789
+ int32_t llama_decode(
2790
+ llama_context * ctx,
2791
+ llama_batch batch) {
2792
+ const int ret = ctx->decode(batch);
2793
+ if (ret != 0) {
2794
+ LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
2795
+ }
2796
+
2797
+ return ret;
2798
+ }
2799
+
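A sketch of the decode path these wrappers expose. llama_batch_get_one and the return-code meanings (0 = success, 1 = no KV cache slot for the batch, negative = hard error) follow llama.h and are assumptions here; only llama_encode/llama_decode come from this hunk. For encoder-decoder or embedding models, llama_encode is called the same way for the encoder pass.

    #include "llama.h"
    #include <vector>

    // Sketch: feed a tokenized prompt and fetch the last token's logits.
    static bool feed_prompt(llama_context * ctx, std::vector<llama_token> & prompt) {
        // assumed helper from llama.h: wraps the tokens into a single-sequence batch
        llama_batch batch = llama_batch_get_one(prompt.data(), (int32_t) prompt.size());

        const int32_t ret = llama_decode(ctx, batch);
        if (ret == 1) {
            // no KV cache slot found: shrink the batch or free cache space and retry
            return false;
        }
        if (ret < 0) {
            return false;                    // the wrapper has already logged the error
        }

        // logits for the last output of the batch are now available
        const float * logits = llama_get_logits_ith(ctx, -1);
        return logits != nullptr;
    }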
2800
+ //
2801
+ // perf
2802
+ //
2803
+
2804
+ llama_perf_context_data llama_perf_context(const llama_context * ctx) {
2805
+ llama_perf_context_data data = {};
2806
+
2807
+ if (ctx == nullptr) {
2808
+ return data;
2809
+ }
2810
+
2811
+ data = ctx->perf_get_data();
2812
+
2813
+ return data;
2814
+ }
2815
+
2816
+ void llama_perf_context_print(const llama_context * ctx) {
2817
+ const auto data = llama_perf_context(ctx);
2818
+
2819
+ const double t_end_ms = 1e-3 * lm_ggml_time_us();
2820
+
2821
+ LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, data.t_load_ms);
2822
+ LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
2823
+ __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
2824
+ LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
2825
+ __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
2826
+ LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
2827
+ }
2828
+
2829
+ void llama_perf_context_reset(llama_context * ctx) {
2830
+ ctx->perf_reset();
1775
2831
  }
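Finally, a sketch of the perf API in use: sample the counters around a generation, print the standard summary, then reset so the next request starts from zero. This assumes the no_perf flag in llama_context_params was left at its default so timings are collected.

    #include "llama.h"
    #include <cstdio>

    // Sketch: report and reset per-context timing counters.
    static void report_perf(llama_context * ctx) {
        const llama_perf_context_data d = llama_perf_context(ctx);
        printf("prompt: %d tokens in %.2f ms, generated: %d tokens in %.2f ms\n",
               d.n_p_eval, d.t_p_eval_ms, d.n_eval, d.t_eval_ms);

        llama_perf_context_print(ctx);       // the formatted log lines shown above
        llama_perf_context_reset(ctx);
    }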