cui-llama.rn 1.6.0 → 1.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (285)
  1. package/README.md +35 -7
  2. package/android/src/main/CMakeLists.txt +22 -11
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +42 -6
  4. package/android/src/main/java/com/rnllama/RNLlama.java +139 -4
  5. package/android/src/main/jni.cpp +173 -18
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  14. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
  15. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
  16. package/cpp/LICENSE +21 -0
  17. package/cpp/chat.cpp +129 -107
  18. package/cpp/chat.h +2 -0
  19. package/cpp/common.cpp +58 -78
  20. package/cpp/common.h +29 -21
  21. package/cpp/ggml-alloc.c +4 -1
  22. package/cpp/ggml-backend.cpp +9 -5
  23. package/cpp/ggml-backend.h +4 -4
  24. package/cpp/ggml-cpp.h +1 -1
  25. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  26. package/cpp/ggml-cpu/amx/amx.h +8 -0
  27. package/cpp/ggml-cpu/amx/common.h +91 -0
  28. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  29. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  30. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/binary-ops.h +1 -1
  31. package/cpp/ggml-cpu/common.h +72 -0
  32. package/cpp/{ggml-cpu-aarch64.cpp → ggml-cpu/ggml-cpu-aarch64.cpp} +809 -103
  33. package/cpp/{ggml-cpu-quants.c → ggml-cpu/ggml-cpu-quants.c} +306 -6
  34. package/cpp/{ggml-cpu.c → ggml-cpu/ggml-cpu.c} +114 -55
  35. package/cpp/{ggml-cpu.cpp → ggml-cpu/ggml-cpu.cpp} +32 -16
  36. package/cpp/{ops.cpp → ggml-cpu/ops.cpp} +353 -173
  37. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/ops.h +2 -20
  38. package/cpp/{sgemm.cpp → ggml-cpu/sgemm.cpp} +501 -0
  39. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/simd-mappings.h +7 -3
  40. package/{ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/unary-ops.h +1 -1
  41. package/cpp/{vec.cpp → ggml-cpu/vec.cpp} +0 -6
  42. package/{ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/vec.h +16 -0
  43. package/cpp/ggml-cpu.h +5 -0
  44. package/cpp/ggml-impl.h +16 -9
  45. package/cpp/ggml-llama-sim.metallib +0 -0
  46. package/cpp/ggml-llama.metallib +0 -0
  47. package/cpp/ggml-metal-impl.h +36 -11
  48. package/cpp/ggml-metal.m +810 -176
  49. package/cpp/ggml-opt.cpp +373 -190
  50. package/cpp/ggml-opt.h +49 -28
  51. package/cpp/ggml-quants.c +0 -6
  52. package/cpp/ggml.c +227 -282
  53. package/cpp/ggml.h +82 -101
  54. package/cpp/gguf.cpp +33 -33
  55. package/cpp/json-schema-to-grammar.cpp +3 -0
  56. package/cpp/llama-adapter.cpp +6 -0
  57. package/cpp/llama-arch.cpp +49 -17
  58. package/cpp/llama-arch.h +9 -0
  59. package/cpp/llama-batch.cpp +8 -2
  60. package/cpp/llama-batch.h +2 -1
  61. package/cpp/llama-chat.cpp +39 -16
  62. package/cpp/llama-chat.h +4 -2
  63. package/cpp/llama-context.cpp +440 -611
  64. package/cpp/llama-context.h +44 -33
  65. package/cpp/llama-cparams.h +1 -0
  66. package/cpp/llama-graph.cpp +214 -291
  67. package/cpp/llama-graph.h +69 -21
  68. package/cpp/llama-hparams.cpp +17 -1
  69. package/cpp/llama-hparams.h +39 -5
  70. package/cpp/llama-kv-cache.cpp +2067 -620
  71. package/cpp/llama-kv-cache.h +410 -108
  72. package/cpp/llama-memory.h +12 -1
  73. package/cpp/llama-model-loader.cpp +24 -15
  74. package/cpp/llama-model-saver.cpp +281 -0
  75. package/cpp/llama-model-saver.h +37 -0
  76. package/cpp/llama-model.cpp +1089 -359
  77. package/cpp/llama-model.h +19 -3
  78. package/cpp/llama-sampling.cpp +20 -7
  79. package/cpp/llama-vocab.cpp +54 -9
  80. package/cpp/llama-vocab.h +6 -0
  81. package/cpp/llama.cpp +14 -0
  82. package/cpp/llama.h +86 -142
  83. package/cpp/minja/chat-template.hpp +9 -5
  84. package/cpp/minja/minja.hpp +69 -36
  85. package/cpp/rn-llama.cpp +602 -190
  86. package/cpp/rn-llama.h +34 -8
  87. package/cpp/sampling.cpp +57 -50
  88. package/cpp/tools/mtmd/clip-impl.h +462 -0
  89. package/cpp/tools/mtmd/clip.cpp +4024 -0
  90. package/cpp/tools/mtmd/clip.h +101 -0
  91. package/cpp/tools/mtmd/miniaudio.h +93468 -0
  92. package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
  93. package/cpp/tools/mtmd/mtmd-audio.h +62 -0
  94. package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
  95. package/cpp/tools/mtmd/mtmd.cpp +942 -0
  96. package/cpp/tools/mtmd/mtmd.h +362 -0
  97. package/cpp/tools/mtmd/stb_image.h +7988 -0
  98. package/ios/CMakeLists.txt +20 -10
  99. package/ios/RNLlama.h +6 -0
  100. package/ios/RNLlama.mm +82 -3
  101. package/ios/RNLlamaContext.h +5 -1
  102. package/ios/RNLlamaContext.mm +131 -38
  103. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
  104. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +29 -21
  105. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  106. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  107. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  108. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  109. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  110. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  111. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +82 -101
  112. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  113. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  114. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +4 -2
  115. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +44 -33
  116. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  117. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +69 -21
  118. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +39 -5
  119. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +410 -108
  120. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +12 -1
  121. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  122. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +19 -3
  123. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  124. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +86 -142
  125. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  126. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  127. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +34 -8
  128. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
  129. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  130. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  131. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  132. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +29 -21
  133. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  134. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  135. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  136. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  137. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  138. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  139. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +82 -101
  140. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  141. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  142. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +4 -2
  143. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +44 -33
  144. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  145. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +69 -21
  146. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +39 -5
  147. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +410 -108
  148. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +12 -1
  149. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  150. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +19 -3
  151. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  152. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +86 -142
  153. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  154. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  155. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +34 -8
  156. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  157. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  160. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
  161. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +29 -21
  162. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
  163. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  164. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  165. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  166. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  167. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
  168. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +82 -101
  169. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  170. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  171. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +4 -2
  172. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +44 -33
  173. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
  174. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +69 -21
  175. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +39 -5
  176. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +410 -108
  177. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +12 -1
  178. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
  179. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +19 -3
  180. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
  181. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +86 -142
  182. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  183. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
  184. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +34 -8
  185. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
  186. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  187. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  188. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
  189. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +29 -21
  190. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
  191. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  192. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  193. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  194. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
  195. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
  196. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +82 -101
  197. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  198. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  199. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +4 -2
  200. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +44 -33
  201. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
  202. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +69 -21
  203. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +39 -5
  204. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +410 -108
  205. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +12 -1
  206. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
  207. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +19 -3
  208. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
  209. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +86 -142
  210. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
  211. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
  212. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +34 -8
  213. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
  214. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
  215. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  216. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  217. package/jest/mock.js +33 -7
  218. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  219. package/lib/commonjs/index.js +153 -21
  220. package/lib/commonjs/index.js.map +1 -1
  221. package/lib/module/NativeRNLlama.js.map +1 -1
  222. package/lib/module/index.js +152 -20
  223. package/lib/module/index.js.map +1 -1
  224. package/lib/typescript/NativeRNLlama.d.ts +54 -4
  225. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  226. package/lib/typescript/index.d.ts +72 -6
  227. package/lib/typescript/index.d.ts.map +1 -1
  228. package/package.json +1 -1
  229. package/src/NativeRNLlama.ts +72 -4
  230. package/src/index.ts +212 -38
  231. package/cpp/binary-ops.h +0 -16
  232. package/cpp/ops.h +0 -128
  233. package/cpp/simd-mappings.h +0 -888
  234. package/cpp/unary-ops.h +0 -28
  235. package/cpp/vec.h +0 -802
  236. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  237. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  238. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  239. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  240. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  241. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +0 -128
  242. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  243. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  244. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +0 -802
  245. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  246. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  247. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  248. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  249. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  250. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  251. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  252. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  253. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  254. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  255. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  256. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  257. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +0 -128
  258. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  259. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  260. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +0 -28
  261. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +0 -16
  262. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  263. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  264. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  265. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  266. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +0 -128
  267. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  268. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +0 -888
  269. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  270. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  271. package/lib/commonjs/chat.js +0 -37
  272. package/lib/commonjs/chat.js.map +0 -1
  273. package/lib/module/chat.js +0 -33
  274. package/lib/module/chat.js.map +0 -1
  275. package/lib/typescript/chat.d.ts +0 -10
  276. package/lib/typescript/chat.d.ts.map +0 -1
  277. package/src/chat.ts +0 -44
  278. /package/cpp/{binary-ops.cpp → ggml-cpu/binary-ops.cpp} +0 -0
  279. /package/cpp/{ggml-cpu-aarch64.h → ggml-cpu/ggml-cpu-aarch64.h} +0 -0
  280. /package/cpp/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -0
  281. /package/cpp/{ggml-cpu-quants.h → ggml-cpu/ggml-cpu-quants.h} +0 -0
  282. /package/cpp/{ggml-cpu-traits.cpp → ggml-cpu/ggml-cpu-traits.cpp} +0 -0
  283. /package/cpp/{ggml-cpu-traits.h → ggml-cpu/ggml-cpu-traits.h} +0 -0
  284. /package/cpp/{sgemm.h → ggml-cpu/sgemm.h} +0 -0
  285. /package/cpp/{unary-ops.cpp → ggml-cpu/unary-ops.cpp} +0 -0
@@ -6,7 +6,6 @@
  #include "llama-model.h"
  #include "llama-kv-cache.h"
 
- #include <cassert>
  #include <cstring>
  #include <stdexcept>
  #include <cinttypes>
@@ -95,6 +94,8 @@ llama_context::llama_context(
 
  cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
+ cparams.op_offload = params.op_offload;
+
  const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
  LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
@@ -113,12 +114,10 @@ llama_context::llama_context(
  }
 
  if (n_ctx_per_seq > hparams.n_ctx_train) {
- LLAMA_LOG_WARN("%s: n_ctx_pre_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
+ LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
  __func__, n_ctx_per_seq, hparams.n_ctx_train);
  }
 
- logits_all = params.logits_all;
-
  if (!hparams.vocab_only) {
  // GPU backends
  for (auto * dev : model.devices) {
@@ -176,44 +175,14 @@ llama_context::llama_context(
  }
 
  // init the memory module
- // TODO: for now, always create a unified KV cache
  if (!hparams.vocab_only) {
- kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
-
- LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-
- cparams.n_ctx = LM_GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams));
-
- LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
- uint32_t kv_size = cparams.n_ctx;
- lm_ggml_type type_k = params.type_k;
- lm_ggml_type type_v = params.type_v;
-
- if (llama_model_is_recurrent(&model)) {
- // Mamba needs at least as many KV cells as there are sequences kept at any time
- kv_size = std::max((uint32_t) 1, params.n_seq_max);
- // it's probably best to keep as much precision as possible for the states
- type_k = LM_GGML_TYPE_F32; // required by lm_ggml_ssm_conv for Mamba's conv_states
- type_v = LM_GGML_TYPE_F32; // required by lm_ggml_ssm_scan for Mamba's ssm_states
- }
-
- LM_GGML_ASSERT(hparams.n_embd_head_k % lm_ggml_blck_size(type_k) == 0);
- LM_GGML_ASSERT(hparams.n_embd_head_v % lm_ggml_blck_size(type_v) == 0);
-
- if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
- throw std::runtime_error("failed to initialize self-attention cache");
- }
-
- {
- const size_t memory_size_k = kv_self->size_k_bytes();
- const size_t memory_size_v = kv_self->size_v_bytes();
+ llama_memory_params params_mem = {
+ /*.type_k =*/ params.type_k,
+ /*.type_v =*/ params.type_v,
+ /*.swa_full =*/ params.swa_full,
+ };
 
- LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
- (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
- lm_ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
- lm_ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
- }
+ memory.reset(model.create_memory(params_mem, cparams));
  }
 
  // init backends
@@ -277,7 +246,7 @@ llama_context::llama_context(
  }
  }
 
- sched.reset(lm_ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
+ sched.reset(lm_ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload));
 
  if (pipeline_parallel) {
  LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, lm_ggml_backend_sched_get_n_copies(sched.get()));
@@ -285,7 +254,7 @@ llama_context::llama_context(
  }
 
  // reserve worst-case graph
- if (!hparams.vocab_only) {
+ if (!hparams.vocab_only && memory) {
  const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
  const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
@@ -304,7 +273,9 @@ llama_context::llama_context(
  int n_nodes_tg = -1;
 
  // simulate full KV cache
- kv_self->n = kv_self->size;
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+ kv_self->set_full();
 
  cross.v_embd.clear();
 
@@ -390,7 +361,9 @@ llama_context::llama_context(
  }
  }
 
- llama_context::~llama_context() = default;
+ llama_context::~llama_context() {
+ lm_ggml_opt_free(opt_ctx);
+ }
 
  void llama_context::synchronize() {
  lm_ggml_backend_sched_synchronize(sched.get());
@@ -426,6 +399,18 @@ const llama_model & llama_context::get_model() const {
  return model;
  }
 
+ const llama_cparams & llama_context::get_cparams() const {
+ return cparams;
+ }
+
+ lm_ggml_backend_sched_t llama_context::get_sched() const {
+ return sched.get();
+ }
+
+ lm_ggml_context * llama_context::get_ctx_compute() const {
+ return ctx_compute.get();
+ }
+
  uint32_t llama_context::n_ctx() const {
  return cparams.n_ctx;
  }
@@ -455,345 +440,21 @@ uint32_t llama_context::n_threads_batch() const {
  }
 
  llama_kv_cache * llama_context::get_kv_self() {
- return kv_self.get();
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+ return kv_self;
  }
 
  const llama_kv_cache * llama_context::get_kv_self() const {
- return kv_self.get();
- }
-
- lm_ggml_tensor * llama_context::build_rope_shift(
- lm_ggml_context * ctx0,
- lm_ggml_tensor * cur,
- lm_ggml_tensor * shift,
- lm_ggml_tensor * factors,
- float freq_base,
- float freq_scale,
- lm_ggml_backend_buffer * bbuf) const {
- const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
-
- const auto & yarn_ext_factor = cparams.yarn_ext_factor;
- const auto & yarn_attn_factor = cparams.yarn_attn_factor;
- const auto & yarn_beta_fast = cparams.yarn_beta_fast;
- const auto & yarn_beta_slow = cparams.yarn_beta_slow;
-
- const auto & hparams = model.hparams;
-
- const auto & n_rot = hparams.n_rot;
- const auto & rope_type = hparams.rope_type;
-
- lm_ggml_tensor * tmp;
-
- if (lm_ggml_is_quantized(cur->type)) {
- // dequantize to f32 -> RoPE -> quantize back
- tmp = lm_ggml_cast(ctx0, cur, LM_GGML_TYPE_F32);
-
- if (bbuf) {
- for (const auto & backend : backends) {
- // Figure out which backend KV cache belongs to
- if (lm_ggml_backend_supports_buft(backend.get(), lm_ggml_backend_buffer_get_type(bbuf))) {
- lm_ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get());
- break;
- }
- }
- }
-
- tmp = lm_ggml_rope_ext_inplace(ctx0, tmp,
- shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
-
- tmp = lm_ggml_cpy(ctx0, tmp, cur);
- } else {
- // we rotate only the first n_rot dimensions
- tmp = lm_ggml_rope_ext_inplace(ctx0, cur,
- shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
- }
-
- return tmp;
- }
-
- class llm_graph_input_k_shift : public llm_graph_input_i {
- public:
- llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
- virtual ~llm_graph_input_k_shift() = default;
-
- void set_input(const llama_ubatch * ubatch) override;
-
- lm_ggml_tensor * k_shift; // I32 [kv_size]
-
- const llama_kv_cache_unified * kv_self;
- };
-
- void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
- LM_GGML_UNUSED(ubatch);
-
- if (k_shift) {
- assert(lm_ggml_backend_buffer_is_host(k_shift->buffer));
-
- int32_t * data = (int32_t *) k_shift->data;
-
- for (uint32_t i = 0; i < kv_self->size; ++i) {
- data[i] = kv_self->cells[i].delta;
- }
- }
- }
-
- llm_graph_result_ptr llama_context::build_kv_self_shift(
- lm_ggml_context * ctx0,
- lm_ggml_cgraph * gf) const {
- auto res = std::make_unique<llm_graph_result>();
-
- const auto & hparams = model.hparams;
-
- const auto & n_layer = hparams.n_layer;
-
- const auto & n_embd_head_k = hparams.n_embd_head_k;
- //const auto & n_embd_head_v = hparams.n_embd_head_v;
-
- //LM_GGML_ASSERT(kv_self->size == n_ctx);
-
- auto inp = std::make_unique<llm_graph_input_k_shift>(kv_self.get());
-
- inp->k_shift = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, cparams.n_ctx);
- lm_ggml_set_input(inp->k_shift);
-
- for (uint32_t il = 0; il < n_layer; ++il) {
- const int64_t n_head_kv = hparams.n_head_kv(il);
- const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-
- const bool is_swa = hparams.is_swa(il);
-
- // note: the swa rope params could become part of the cparams in the future
- // if we decide to make them configurable, like the non-sliding ones
- const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
- const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
-
- lm_ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
-
- lm_ggml_tensor * k =
- lm_ggml_view_3d(ctx0, kv_self->k_l[il],
- n_embd_head_k, n_head_kv, kv_self->size,
- lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
- lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
- 0);
-
- lm_ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
-
- lm_ggml_build_forward_expand(gf, cur);
- }
-
- res->add_input(std::move(inp));
-
- return res;
- }
-
- llm_graph_result_ptr llama_context::build_kv_self_defrag(
- lm_ggml_context * ctx0,
- lm_ggml_cgraph * gf) const {
- auto res = std::make_unique<llm_graph_result>();
-
- const auto & hparams = model.hparams;
-
- const auto & ids = kv_self->defrag_info.ids;
-
- #if 0
- // CPU defrag
- //
- // TODO: optimizations are possible:
- // - multiple threads
- // - avoid copying to the host memory when already there
- //
- // likely not worth the effort, as we have lm_ggml_graph based defrag
- //
-
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-
- const uint32_t kv_size = size;
-
- std::vector<uint8_t> buf_k;
- std::vector<uint8_t> buf_v;
-
- for (uint32_t il = 0; il < n_layer; ++il) {
- const size_t k_size_row = lm_ggml_row_size(k_l[il]->type, n_embd_k_gqa);
- const size_t k_size = lm_ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
-
- const size_t v_size_el = lm_ggml_type_size(v_l[il]->type);
- const size_t v_size = lm_ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
-
- buf_k.resize(k_size);
- buf_v.resize(v_size);
-
- lm_ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
- lm_ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
-
- // batch move [i, i+nm) to [id, id+nm)
- // note: cells can move only to a lower index
- for (uint32_t i = 0; i < n_kv; ++i) {
- const uint32_t id = ids[i];
-
- if (i == id || id == n_kv) {
- continue;
- }
-
- uint32_t nm = 1;
-
- while (i + nm < n_kv && ids[i + nm] == id + nm) {
- nm++;
- }
-
- // move keys
- {
- const int64_t os = i*k_size_row;
- const int64_t od = id*k_size_row;
-
- memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
- }
-
- // move values (note: they are transposed)
- {
- const int64_t os = i;
- const int64_t od = id;
-
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
- memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
- }
- }
-
- i += nm - 1;
- }
-
- lm_ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
- lm_ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
- }
- #else
- for (uint32_t i = 0; i < ids.size(); ++i) {
- const uint32_t id = ids[i];
-
- if (i == id || id == ids.size()) {
- continue;
- }
-
- uint32_t nm = 1;
-
- while (i + nm < ids.size() && ids[i + nm] == id + nm) {
- nm++;
- }
-
- for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
- const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
- const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
- lm_ggml_tensor * view_k_src = lm_ggml_view_2d(ctx0, kv_self->k_l[il],
- n_embd_k_gqa, nm,
- lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
- lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
-
- lm_ggml_tensor * view_k_dst = lm_ggml_view_2d(ctx0, kv_self->k_l[il],
- n_embd_k_gqa, nm,
- lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
- lm_ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
-
- lm_ggml_tensor * view_v_src;
- lm_ggml_tensor * view_v_dst;
-
- if (cparams.flash_attn) {
- // NOTE: the V cache is not transposed when using flash attention
- view_v_src = lm_ggml_view_2d(ctx0, kv_self->v_l[il],
- n_embd_v_gqa, nm,
- lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
- lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
-
- view_v_dst = lm_ggml_view_2d(ctx0, kv_self->v_l[il],
- n_embd_v_gqa, nm,
- lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
- lm_ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
- } else {
- view_v_src = lm_ggml_view_2d(ctx0, kv_self->v_l[il],
- nm, n_embd_v_gqa,
- lm_ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
- lm_ggml_row_size(kv_self->v_l[il]->type, i));
-
- view_v_dst = lm_ggml_view_2d(ctx0, kv_self->v_l[il],
- nm, n_embd_v_gqa,
- lm_ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
- lm_ggml_row_size(kv_self->v_l[il]->type, id));
- }
-
- lm_ggml_build_forward_expand(gf, lm_ggml_cpy(ctx0, view_k_src, view_k_dst));
- lm_ggml_build_forward_expand(gf, lm_ggml_cpy(ctx0, view_v_src, view_v_dst));
- }
-
- i += nm - 1;
- }
-
- //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
- #endif
-
- return res;
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+ return kv_self;
  }
 
  void llama_context::kv_self_update() {
- auto & kv = kv_self;
-
  bool need_reserve = false;
 
- if (kv->has_shift) {
- if (!kv->get_can_shift()) {
- LM_GGML_ABORT("The current context does not support K-shift");
- }
-
- LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
-
- // apply K-shift if needed
- if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
- lm_ggml_backend_sched_reset(sched.get());
-
- auto * gf = graph_init();
-
- auto res = build_kv_self_shift(ctx_compute.get(), gf);
-
- lm_ggml_backend_sched_alloc_graph(sched.get(), gf);
-
- res->set_inputs(nullptr);
-
- graph_compute(gf, false);
-
- need_reserve = true;
- }
-
- {
- kv->has_shift = false;
-
- for (uint32_t i = 0; i < kv->size; ++i) {
- kv->cells[i].delta = 0;
- }
- }
- }
-
- // defragment the KV cache if needed
- if (kv->do_defrag) {
- LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
-
- if (kv->defrag_prepare(graph_max_nodes())) {
- lm_ggml_backend_sched_reset(sched.get());
-
- auto * gf = graph_init();
-
- auto res = build_kv_self_defrag(ctx_compute.get(), gf);
-
- lm_ggml_backend_sched_alloc_graph(sched.get(), gf);
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
- res->set_inputs(nullptr);
-
- graph_compute(gf, false);
-
- need_reserve = true;
- }
-
- kv->do_defrag = false;
- }
+ need_reserve = kv_self->update(*this);
 
  // reserve a worst case graph if needed
  if (need_reserve) {
@@ -804,7 +465,7 @@ void llama_context::kv_self_update() {
  uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
  // simulate full KV cache
- kv_self->n = kv_self->size;
+ kv_self->set_full();
 
  llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
  llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
@@ -825,9 +486,6 @@ enum llama_pooling_type llama_context::pooling_type() const {
  }
 
  float * llama_context::get_logits() {
- // reorder logits for backward compatibility
- output_reorder();
-
  return logits;
  }
 
@@ -870,9 +528,6 @@ float * llama_context::get_logits_ith(int32_t i) {
  }
 
  float * llama_context::get_embeddings() {
- // reorder embeddings for backward compatibility
- output_reorder();
-
  return embd;
  }
 
@@ -1024,8 +679,8 @@ int llama_context::encode(llama_batch & inp_batch) {
  }
 
  // temporary allocate memory for the input batch if needed
- // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
+ // note: during encode, we always pass the full sequence starting from pos = 0
+ llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
 
  const llama_batch & batch = batch_allocr.batch;
  const int32_t n_tokens = batch.n_tokens;
@@ -1050,11 +705,13 @@ int llama_context::encode(llama_batch & inp_batch) {
  t_compute_start_us = lm_ggml_time_us();
  }
 
+ embd_seq.clear();
+
  n_queued_tokens += n_tokens;
 
  const int64_t n_embd = hparams.n_embd;
 
- sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+ llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
 
  const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
 
@@ -1111,12 +768,12 @@ int llama_context::encode(llama_batch & inp_batch) {
  lm_ggml_backend_t backend_embd = lm_ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
  LM_GGML_ASSERT(backend_embd != nullptr);
 
- LM_GGML_ASSERT(embd != nullptr);
-
  switch (cparams.pooling_type) {
  case LLAMA_POOLING_TYPE_NONE:
  {
  // extract token embeddings
+ LM_GGML_ASSERT(embd != nullptr);
+
  LM_GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
  lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
  } break;
@@ -1141,11 +798,18 @@ int llama_context::encode(llama_batch & inp_batch) {
  } break;
  case LLAMA_POOLING_TYPE_RANK:
  {
- // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
- // wait for an encoder model that requires this pooling type in order to test it
- // https://github.com/ggerganov/llama.cpp/pull/9510
- LM_GGML_ABORT("RANK pooling not implemented yet");
- }
+ // extract the rerank score - a single float per sequence
+ auto & embd_seq_out = embd_seq;
+
+ for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+ const llama_seq_id seq_id = ubatch.seq_id[s][0];
+ if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+ continue;
+ }
+ embd_seq_out[seq_id].resize(1);
+ lm_ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+ }
+ } break;
  case LLAMA_POOLING_TYPE_UNSPECIFIED:
  {
  LM_GGML_ABORT("unknown pooling type");
@@ -1183,14 +847,27 @@ int llama_context::encode(llama_batch & inp_batch) {
  }
 
  int llama_context::decode(llama_batch & inp_batch) {
+ if (!memory) {
+ LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
+ return encode(inp_batch);
+ }
+
  if (inp_batch.n_tokens == 0) {
  LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
  return -1;
  }
 
+ if (!inp_batch.pos) {
+ if (inp_batch.seq_id) {
+ LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
+ return -1;
+ }
+ }
+
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
  // temporary allocate memory for the input batch if needed
- // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
+ llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
 
  const llama_batch & batch = batch_allocr.batch;
 
@@ -1202,7 +879,7 @@ int llama_context::decode(llama_batch & inp_batch) {
  const int64_t n_tokens_all = batch.n_tokens;
  const int64_t n_embd = hparams.n_embd;
 
- llama_kv_cache_guard kv_guard(kv_self.get());
+ llama_kv_cache_guard kv_guard(kv_self);
 
  LM_GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
@@ -1236,18 +913,14 @@ int llama_context::decode(llama_batch & inp_batch) {
  for (uint32_t i = 0; i < n_tokens_all; ++i) {
  n_outputs_all += batch.logits[i] != 0;
  }
- } else if (logits_all || embd_pooled) {
+ } else if (embd_pooled) {
  n_outputs_all = n_tokens_all;
  } else {
  // keep last output only
  n_outputs_all = 1;
  }
 
- const bool logits_all = n_outputs_all == n_tokens_all;
-
- sbatch.from_batch(batch, n_embd,
- /* simple_split */ !kv_self->recurrent,
- /* logits_all */ logits_all);
+ llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
 
  // reserve output buffer
  if (output_reserve(n_outputs_all) < n_outputs_all) {
@@ -1261,22 +934,7 @@ int llama_context::decode(llama_batch & inp_batch) {
  int64_t n_outputs_prev = 0;
 
  while (sbatch.n_tokens > 0) {
- llama_ubatch ubatch = llama_ubatch();
-
- const auto & n_ubatch = cparams.n_ubatch;
-
- if (kv_self->recurrent) {
- if (embd_pooled) {
- // Pooled embeddings cannot be split across ubatches (yet)
- ubatch = sbatch.split_seq(cparams.n_ubatch);
- } else {
- // recurrent model architectures are easier to implement
- // with equal-length sequences
- ubatch = sbatch.split_equal(cparams.n_ubatch);
- }
- } else {
- ubatch = sbatch.split_simple(n_ubatch);
- }
+ llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
 
  // count the outputs in this u_batch
  {
@@ -1296,24 +954,10 @@ int llama_context::decode(llama_batch & inp_batch) {
  }
 
  // find KV slot
- {
- if (!kv_self->find_slot(ubatch)) {
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
- return 1;
- }
-
- if (!kv_self->recurrent) {
- // a heuristic, to avoid attending the full cache if it is not yet utilized
- // after enough generations, the benefit from this heuristic disappears
- // if we start defragmenting the cache, the benefit from this will be more important
- const uint32_t pad = kv_self->get_padding(cparams);
- kv_self->n = std::min(kv_self->size, std::max(pad, LM_GGML_PAD(kv_self->cell_max(), pad)));
- }
+ if (!kv_self->find_slot(ubatch)) {
+ return 1;
  }
 
- //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
-
  lm_ggml_backend_sched_reset(sched.get());
  lm_ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
@@ -1427,43 +1071,68 @@ int llama_context::decode(llama_batch & inp_batch) {
  // finalize the batch processing
  kv_guard.commit();
 
+ // set to total number of outputs in the batch, for use in llama_get_logits_ith
+ n_outputs = n_outputs_all;
+
  // set output mappings
  {
  bool sorted_output = true;
 
- LM_GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
+ auto & out_ids = sbatch.out_ids;
+
+ LM_GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
 
  for (int64_t i = 0; i < n_outputs_all; ++i) {
- int64_t out_id = sbatch.out_ids[i];
+ int64_t out_id = out_ids[i];
  output_ids[out_id] = i;
  if (out_id != i) {
  sorted_output = false;
  }
  }
 
- if (sorted_output) {
- sbatch.out_ids.clear();
+ // make the outputs have the same order they had in the user-provided batch
+ // note: this is mostly relevant for recurrent models atm
+ if (!sorted_output) {
+ const uint32_t n_vocab = model.vocab.n_tokens();
+ const uint32_t n_embd = model.hparams.n_embd;
+
+ LM_GGML_ASSERT((size_t) n_outputs == out_ids.size());
+
+ // TODO: is there something more efficient which also minimizes swaps?
+ // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
+ for (int32_t i = 0; i < n_outputs - 1; ++i) {
+ int32_t j_min = i;
+ for (int32_t j = i + 1; j < n_outputs; ++j) {
+ if (out_ids[j] < out_ids[j_min]) {
+ j_min = j;
+ }
+ }
+ if (j_min == i) { continue; }
+ std::swap(out_ids[i], out_ids[j_min]);
+ if (logits_size > 0) {
+ for (uint32_t k = 0; k < n_vocab; k++) {
+ std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
+ }
+ }
+ if (embd_size > 0) {
+ for (uint32_t k = 0; k < n_embd; k++) {
+ std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
+ }
+ }
+ }
+ std::fill(output_ids.begin(), output_ids.end(), -1);
+ for (int32_t i = 0; i < n_outputs; ++i) {
+ output_ids[out_ids[i]] = i;
+ }
  }
  }
 
- // set to total number of outputs in the batch, for use in llama_get_logits_ith
- n_outputs = n_outputs_all;
-
  // wait for the computation to finish (automatically done when obtaining the model output)
  //synchronize();
 
  // decide if we need to defrag the kv cache
- if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
- // - do not defrag small contexts (i.e. < 2048 tokens)
- // - count the padding towards the number of used tokens
- const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
-
- // queue defragmentation for next llama_kv_cache_update
- if (fragmentation > cparams.defrag_thold) {
- LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
-
- kv_self->defrag();
- }
+ if (cparams.defrag_thold > 0.0f) {
+ kv_self->defrag_sched(cparams.defrag_thold);
  }
 
  // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
@@ -1543,52 +1212,12 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
  // set all ids as invalid (negative)
  std::fill(output_ids.begin(), output_ids.end(), -1);
 
- lm_ggml_backend_buffer_clear(buf_output.get(), 0);
-
  this->n_outputs = 0;
  this->n_outputs_max = n_outputs_max;
 
  return n_outputs_max;
  }
 
- void llama_context::output_reorder() {
- auto & out_ids = sbatch.out_ids;
- if (!out_ids.empty()) {
- const uint32_t n_vocab = model.vocab.n_tokens();
- const uint32_t n_embd = model.hparams.n_embd;
-
- LM_GGML_ASSERT((size_t) n_outputs == out_ids.size());
-
- // TODO: is there something more efficient which also minimizes swaps?
- // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
- for (int32_t i = 0; i < n_outputs - 1; ++i) {
- int32_t j_min = i;
- for (int32_t j = i + 1; j < n_outputs; ++j) {
- if (out_ids[j] < out_ids[j_min]) {
- j_min = j;
- }
- }
- if (j_min == i) { continue; }
- std::swap(out_ids[i], out_ids[j_min]);
- if (logits_size > 0) {
- for (uint32_t k = 0; k < n_vocab; k++) {
- std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
- }
- }
- if (embd_size > 0) {
- for (uint32_t k = 0; k < n_embd; k++) {
- std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
- }
- }
- }
- std::fill(output_ids.begin(), output_ids.end(), -1);
- for (int32_t i = 0; i < n_outputs; ++i) {
- output_ids[out_ids[i]] = i;
- }
- out_ids.clear();
- }
- }
-
  //
  // graph
  //
@@ -1625,7 +1254,7 @@ llm_graph_result_ptr llama_context::graph_build(
  /*.backend_cpu =*/ backend_cpu,
  /*.cvec =*/ &cvec,
  /*.loras =*/ &loras,
- /*.memory =*/ kv_self.get(),
+ /*.memory =*/ memory.get(),
  /*.cross =*/ &cross,
  /*.n_outputs =*/ n_outputs,
  /*.cb =*/ graph_get_cb(),
@@ -2029,8 +1658,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
  {
  LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
 
- output_reorder();
-
  const auto n_outputs = this->n_outputs;
  const auto & output_ids = this->output_ids;
 
@@ -2083,8 +1710,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
  }
  }
 
- LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
- kv_self->state_write(io);
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+ if (kv_self != nullptr) {
+ LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
+ kv_self->state_write(io);
+ }
 
  return io.n_bytes();
  }
@@ -2167,8 +1798,13 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
  }
  }
 
- LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
- kv_self->state_read(io);
+ if (memory) {
+ LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
+
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+ kv_self->state_read(io);
+ }
 
  return io.n_bytes();
  }
@@ -2176,7 +1812,11 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
  size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
  LM_GGML_UNUSED(seq_id);
 
- kv_self->state_write(io, seq_id);
+ if (memory) {
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+ kv_self->state_write(io, seq_id);
+ }
 
  return io.n_bytes();
  }
@@ -2184,7 +1824,11 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
  size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
  LM_GGML_UNUSED(seq_id);
 
- kv_self->state_read(io, seq_id);
+ if (memory) {
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+ kv_self->state_read(io, seq_id);
+ }
 
  return io.n_bytes();
  }
@@ -2212,6 +1856,215 @@ void llama_context::perf_reset() {
  t_p_eval_us = n_p_eval = 0;
  }

+ //
+ // training
+ //
+
+ static void llama_set_param(struct lm_ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) {
+ if (!tensor || tensor->type != LM_GGML_TYPE_F32) {
+ return;
+ }
+ if (!param_filter(tensor, userdata)) {
+ return;
+ }
+ if (strcmp(tensor->name, "token_embd.weight") == 0) {
+ return; // FIXME
+ }
+ if (strcmp(tensor->name, "rope_freqs.weight") == 0) {
+ return; // FIXME
+ }
+ lm_ggml_set_param(tensor);
+ }
+
+ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) {
+ LM_GGML_ASSERT(!opt_ctx);
+ model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx();
+ const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train);
+ const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
+ LM_GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0);
+ LM_GGML_ASSERT(n_batch % n_ubatch == 0);
+
+ lm_ggml_opt_params opt_params = lm_ggml_opt_default_params(sched.get(), LM_GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
+ opt_params.opt_period = n_batch / n_ubatch;
+ opt_params.get_opt_pars = lopt_params.get_opt_pars;
+ opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
+
+ opt_ctx = lm_ggml_opt_init(opt_params);
+
+ llama_opt_param_filter param_filter = lopt_params.param_filter;
+ void * param_filter_ud = lopt_params.param_filter_ud;
+
+ //llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME
+ llama_set_param(model->type_embd, param_filter, param_filter_ud);
+ llama_set_param(model->pos_embd, param_filter, param_filter_ud);
+ llama_set_param(model->tok_norm, param_filter, param_filter_ud);
+ llama_set_param(model->tok_norm_b, param_filter, param_filter_ud);
+ llama_set_param(model->output_norm, param_filter, param_filter_ud);
+ llama_set_param(model->output_norm_b, param_filter, param_filter_ud);
+ llama_set_param(model->output, param_filter, param_filter_ud);
+ llama_set_param(model->output_b, param_filter, param_filter_ud);
+ llama_set_param(model->output_norm_enc, param_filter, param_filter_ud);
+ llama_set_param(model->cls, param_filter, param_filter_ud);
+ llama_set_param(model->cls_b, param_filter, param_filter_ud);
+ llama_set_param(model->cls_out, param_filter, param_filter_ud);
+ llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
+
+ for (struct llama_layer & layer : model->layers) {
+ for (size_t i = 0; i < sizeof(layer)/sizeof(struct lm_ggml_tensor *); ++i) {
+ llama_set_param(reinterpret_cast<struct lm_ggml_tensor **>(&layer)[i], param_filter, param_filter_ud);
+ }
+ }
+ }
+
+ void llama_context::opt_epoch_iter(
+ lm_ggml_opt_dataset_t dataset,
+ lm_ggml_opt_result_t result,
+ const std::vector<llama_token> & tokens,
+ const std::vector<llama_token> & labels_sparse,
+ llama_batch & batch,
+ lm_ggml_opt_epoch_callback callback,
+ bool train,
+ int64_t idata_in_loop,
+ int64_t ndata_in_loop,
+ int64_t t_loop_start) {
+ LM_GGML_ASSERT(opt_ctx);
+ const uint32_t n_ctx = llama_model_n_ctx_train(&model);
+ const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
+ const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
+
+ llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+ kv_self->clear();
+ llama_kv_cache_guard kv_guard(kv_self);
+
+ for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
+ batch.n_tokens = n_batch;
+ for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) {
+ batch.token [pos_batch] = tokens[pos_ctx + pos_batch];
+ batch.pos [pos_batch] = pos_ctx + pos_batch;
+ batch.n_seq_id[pos_batch] = 1;
+ batch.seq_id [pos_batch][0] = 0;
+ batch.logits [pos_batch] = true;
+ }
+
+ const auto n_tokens_all = batch.n_tokens;
+
+ n_queued_tokens += n_tokens_all;
+
+ // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+ const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+ embd_seq.clear();
+
+ int64_t n_outputs_all = n_tokens_all;
+
+ llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
+
+ // reserve output buffer
+ if (output_reserve(n_outputs_all) < n_outputs_all) {
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
+ LM_GGML_ABORT("TODO: handle this error");
+ };
+
+ for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
+ llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+
+ n_outputs = ubatch.n_tokens;
+
+ // TODO: not sure if this is needed
+ if (!kv_self->find_slot(ubatch)) {
+ LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+
+ LM_GGML_ABORT("TODO: handle this error");
+ }
+
+ auto * gf = graph_init();
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+
+ struct lm_ggml_context * ctx_compute_opt;
+ {
+ const size_t size_gf = lm_ggml_graph_size(gf);
+ const size_t size_meta = 4*size_gf*lm_ggml_tensor_overhead() + 2*lm_ggml_graph_overhead_custom(size_gf, /*grads = */ true);
+ struct lm_ggml_init_params params = {
+ /*.mem_size =*/ size_meta,
+ /*.mem_buffer =*/ nullptr,
+ /*.no_alloc =*/ true,
+ };
+ ctx_compute_opt = lm_ggml_init(params);
+ }
+ lm_ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
+ lm_ggml_opt_alloc(opt_ctx, train);
+ res->set_inputs(&ubatch);
+ {
+ struct lm_ggml_tensor * labels = lm_ggml_opt_labels(opt_ctx);
+ LM_GGML_ASSERT(labels->ne[1] == n_ubatch);
+ lm_ggml_set_zero(labels);
+ const float onef = 1.0f;
+ for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) {
+ const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch;
+ LM_GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]);
+ lm_ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float));
+ }
+ }
+ lm_ggml_opt_eval(opt_ctx, result);
+ if (callback) {
+ callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
+ }
+ lm_ggml_free(ctx_compute_opt);
+ }
+ }
+
+ kv_guard.commit();
+ }
+
+ void llama_context::opt_epoch(
+ lm_ggml_opt_dataset_t dataset,
+ lm_ggml_opt_result_t result_train,
+ lm_ggml_opt_result_t result_eval,
+ int64_t idata_split,
+ lm_ggml_opt_epoch_callback callback_train,
+ lm_ggml_opt_epoch_callback callback_eval) {
+ const uint32_t n_ctx = this->n_ctx();
+ const uint32_t n_batch = std::min(cparams.n_batch, n_ctx);
+ const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch);
+ const int64_t ndata = lm_ggml_opt_dataset_ndata(dataset);
+
+ LM_GGML_ASSERT(idata_split >= 0);
+ LM_GGML_ASSERT(idata_split <= ndata);
+
+ const uint32_t ubatch_per_ctx = n_ctx / n_ubatch;
+
+ struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
+ std::vector<llama_token> tokens(n_ctx);
+ std::vector<llama_token> labels_sparse(n_ctx);
+
+ int64_t idata = 0;
+
+ int64_t t_loop_start = lm_ggml_time_us();
+ int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
+ for (; idata < idata_split; ++idata) {
+ constexpr bool train = true;
+ const int64_t idata_in_loop = idata*ubatch_per_ctx;
+
+ lm_ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
+ opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch,
+ callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
+ }
+
+ t_loop_start = lm_ggml_time_us();
+ ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
+ for (; idata < ndata; ++idata) {
+ constexpr bool train = false;
+ const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
+
+ lm_ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
+ opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch,
+ callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
+ }
+
+ llama_batch_free(batch);
+ }
+
  //
  // interface implementation
  //
@@ -2239,13 +2092,14 @@ llama_context_params llama_context_default_params() {
  /*.cb_eval_user_data =*/ nullptr,
  /*.type_k =*/ LM_GGML_TYPE_F16,
  /*.type_v =*/ LM_GGML_TYPE_F16,
- /*.logits_all =*/ false,
+ /*.abort_callback =*/ nullptr,
+ /*.abort_callback_data =*/ nullptr,
  /*.embeddings =*/ false,
  /*.offload_kqv =*/ true,
  /*.flash_attn =*/ false,
  /*.no_perf =*/ true,
- /*.abort_callback =*/ nullptr,
- /*.abort_callback_data =*/ nullptr,
+ /*.op_offload =*/ true,
+ /*.swa_full =*/ true,
  };

  return result;
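The defaults above drop the old logits_all flag, move the abort-callback fields, and introduce op_offload and swa_full. A hedged sketch of consuming them from application code; the field names mirror the initializer comments in this hunk, while the exact public struct layout in the package's llama.h is assumed:

    // Sketch: start from library defaults, then override selected fields.
    #include "llama.h"

    static llama_context_params make_ctx_params(uint32_t n_ctx) {
        llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx      = n_ctx;  // context length for this session
        cparams.op_offload = true;   // new field, defaults to true per the hunk above
        cparams.swa_full   = true;   // new field, defaults to true per the hunk above
        // logits_all is gone: request logits per token through llama_batch instead
        return cparams;
    }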
@@ -2440,65 +2294,51 @@ int32_t llama_apply_adapter_cvec(
  return res ? 0 : -1;
  }

- //
- // kv cache view
- //
-
- llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
- const auto * kv = ctx->get_kv_self();
- if (kv == nullptr) {
- LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
- return {};
- }
-
- return llama_kv_cache_view_init(*kv, n_seq_max);
- }
-
- void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
- const auto * kv = ctx->get_kv_self();
- if (kv == nullptr) {
- LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
- return;
- }
-
- llama_kv_cache_view_update(view, kv);
- }
-
  //
  // kv cache
  //

  // deprecated
- int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
- return llama_kv_self_n_tokens(ctx);
- }
-
  int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
  const auto * kv = ctx->get_kv_self();
  if (!kv) {
  return 0;
  }

- return kv->get_n_tokens();
- }
+ int32_t res = 0;

- // deprecated
- int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
- return llama_kv_self_used_cells(ctx);
+ for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
+ const llama_pos p0 = kv->seq_pos_min(s);
+ const llama_pos p1 = kv->seq_pos_max(s);
+
+ if (p0 >= 0) {
+ res += (p1 - p0) + 1;
+ }
+ }
+
+ return res;
  }

+ // deprecated
+ // note: this is the same as above - will be removed anyway, so it's ok
  int32_t llama_kv_self_used_cells(const llama_context * ctx) {
  const auto * kv = ctx->get_kv_self();
  if (!kv) {
  return 0;
  }

- return kv->get_used_cells();
- }
+ int32_t res = 0;

- // deprecated
- void llama_kv_cache_clear(llama_context * ctx) {
- llama_kv_self_clear(ctx);
+ for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
+ const llama_pos p0 = kv->seq_pos_min(s);
+ const llama_pos p1 = kv->seq_pos_max(s);
+
+ if (p0 >= 0) {
+ res += (p1 - p0) + 1;
+ }
+ }
+
+ return res;
  }

  void llama_kv_self_clear(llama_context * ctx) {
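Both deprecated counters are now derived from per-sequence position ranges rather than internal cell counts. The same figure can be computed against the public API, which keeps working once the deprecated helpers are removed; n_seq_max below is a caller-supplied bound and the helper name is illustrative:

    // Sketch: count cached tokens from per-sequence position ranges (public API only).
    #include "llama.h"

    static int32_t cached_token_count(llama_context * ctx, uint32_t n_seq_max) {
        int32_t res = 0;
        for (uint32_t s = 0; s < n_seq_max; ++s) {
            const llama_pos p0 = llama_kv_self_seq_pos_min(ctx, (llama_seq_id) s); // -1 if sequence s is empty
            const llama_pos p1 = llama_kv_self_seq_pos_max(ctx, (llama_seq_id) s);
            if (p0 >= 0) {
                res += (p1 - p0) + 1;
            }
        }
        return res;
    }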
@@ -2510,15 +2350,6 @@ void llama_kv_self_clear(llama_context * ctx) {
  kv->clear();
  }

- // deprecated
- bool llama_kv_cache_seq_rm(
- llama_context * ctx,
- llama_seq_id seq_id,
- llama_pos p0,
- llama_pos p1) {
- return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
- }
-
  bool llama_kv_self_seq_rm(
  llama_context * ctx,
  llama_seq_id seq_id,
@@ -2532,16 +2363,6 @@ bool llama_kv_self_seq_rm(
  return kv->seq_rm(seq_id, p0, p1);
  }

- // deprecated
- void llama_kv_cache_seq_cp(
- llama_context * ctx,
- llama_seq_id seq_id_src,
- llama_seq_id seq_id_dst,
- llama_pos p0,
- llama_pos p1) {
- return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
- }
-
  void llama_kv_self_seq_cp(
  llama_context * ctx,
  llama_seq_id seq_id_src,
@@ -2553,14 +2374,7 @@ void llama_kv_self_seq_cp(
  return;
  }

- return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
- }
-
- // deprecated
- void llama_kv_cache_seq_keep(
- llama_context * ctx,
- llama_seq_id seq_id) {
- return llama_kv_self_seq_keep(ctx, seq_id);
+ kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
  }

  void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
@@ -2569,17 +2383,7 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
  return;
  }

- return kv->seq_keep(seq_id);
- }
-
- // deprecated
- void llama_kv_cache_seq_add(
- llama_context * ctx,
- llama_seq_id seq_id,
- llama_pos p0,
- llama_pos p1,
- llama_pos delta) {
- return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
+ kv->seq_keep(seq_id);
  }

  void llama_kv_self_seq_add(
@@ -2593,17 +2397,7 @@ void llama_kv_self_seq_add(
  return;
  }

- return kv->seq_add(seq_id, p0, p1, delta);
- }
-
- // deprecated
- void llama_kv_cache_seq_div(
- llama_context * ctx,
- llama_seq_id seq_id,
- llama_pos p0,
- llama_pos p1,
- int d) {
- return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
+ kv->seq_add(seq_id, p0, p1, delta);
  }

  void llama_kv_self_seq_div(
@@ -2617,40 +2411,35 @@ void llama_kv_self_seq_div(
  return;
  }

- return kv->seq_div(seq_id, p0, p1, d);
+ kv->seq_div(seq_id, p0, p1, d);
  }

- // deprecated
- llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
- return llama_kv_self_seq_pos_max(ctx, seq_id);
+ llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
+ const auto * kv = ctx->get_kv_self();
+ if (!kv) {
+ return -1;
+ }
+
+ return kv->seq_pos_min(seq_id);
  }

  llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
  const auto * kv = ctx->get_kv_self();
  if (!kv) {
- return 0;
+ return -1;
  }

  return kv->seq_pos_max(seq_id);
  }

- // deprecated
- void llama_kv_cache_defrag(llama_context * ctx) {
- return llama_kv_self_defrag(ctx);
- }
-
  void llama_kv_self_defrag(llama_context * ctx) {
  auto * kv = ctx->get_kv_self();
  if (!kv) {
  return;
  }

- return kv->defrag();
- }
-
- // deprecated
- bool llama_kv_cache_can_shift(const llama_context * ctx) {
- return llama_kv_self_can_shift(ctx);
+ // force defrag
+ kv->defrag_sched(-1.0f);
  }

  bool llama_kv_self_can_shift(const llama_context * ctx) {
@@ -2662,11 +2451,6 @@ bool llama_kv_self_can_shift(const llama_context * ctx) {
  return kv->get_can_shift();
  }

- // deprecated
- void llama_kv_cache_update(llama_context * ctx) {
- llama_kv_self_update(ctx);
- }
-
  // llama state API

  // deprecated
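The deprecated llama_kv_cache_* wrappers removed above were thin forwards, so migrating callers is a mechanical rename to the surviving llama_kv_self_* functions. A sketch of the before/after, assuming a context with at least two sequences:

    // Sketch: 1.6.x -> 1.7.0 rename of the KV-cache helpers.
    #include "llama.h"

    static void trim_and_fork(llama_context * ctx, llama_pos keep_until) {
        // was: llama_kv_cache_seq_rm(ctx, 0, keep_until, -1);
        llama_kv_self_seq_rm(ctx, 0, keep_until, -1);    // drop everything past keep_until in seq 0
        // was: llama_kv_cache_seq_cp(ctx, 0, 1, 0, keep_until);
        llama_kv_self_seq_cp(ctx, 0, 1, 0, keep_until);  // fork the kept prefix into seq 1
        // was: llama_kv_cache_defrag(ctx); llama_kv_cache_update(ctx);
        llama_kv_self_defrag(ctx);                       // now only schedules a forced defrag
        llama_kv_self_update(ctx);                       // applies pending shift/defrag work
    }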
@@ -2789,7 +2573,21 @@ int32_t llama_encode(
  int32_t llama_decode(
  llama_context * ctx,
  llama_batch batch) {
- const int ret = ctx->decode(batch);
+ int ret = ctx->decode(batch);
+
+ // defrag and try again
+ // TODO: distinguish return code when we are sure that even after defrag there is no space available
+ if (ret == 1) {
+ llama_kv_self_defrag(ctx);
+ ret = ctx->decode(batch);
+
+ if (ret == 1) {
+ LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+ return ret;
+ }
+ }
+
  if (ret != 0) {
  LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
  }
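With the retry logic above, a return value of 1 now means the KV cache had no free slot even after an internal defrag pass, while negative values remain fatal. A hedged sketch of the resulting calling convention:

    // Sketch: interpreting llama_decode() return codes after the built-in defrag retry.
    #include "llama.h"

    #include <cstdio>

    static bool decode_or_report(llama_context * ctx, llama_batch batch) {
        const int32_t ret = llama_decode(ctx, batch);
        if (ret == 0) {
            return true;  // batch accepted
        }
        if (ret == 1) {
            // no KV slot even after the internal defrag + retry:
            // free sequences or split the batch, then try again
            std::fprintf(stderr, "decode: KV cache full for %d tokens\n", batch.n_tokens);
        } else {
            std::fprintf(stderr, "decode: fatal error %d\n", ret);
        }
        return false;
    }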
@@ -2829,3 +2627,34 @@ void llama_perf_context_print(const llama_context * ctx) {
  void llama_perf_context_reset(llama_context * ctx) {
  ctx->perf_reset();
  }
+
+ //
+ // training
+ //
+
+ bool llama_opt_param_filter_all(const struct lm_ggml_tensor * tensor, void * userdata) {
+ LM_GGML_UNUSED(tensor);
+ LM_GGML_UNUSED(userdata);
+ return true;
+ }
+
+ void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) {
+ ctx->opt_init(model, lopt_params);
+ }
+
+ void llama_opt_epoch(
+ struct llama_context * ctx,
+ lm_ggml_opt_dataset_t dataset,
+ lm_ggml_opt_result_t result_train,
+ lm_ggml_opt_result_t result_eval,
+ int64_t idata_split,
+ lm_ggml_opt_epoch_callback callback_train,
+ lm_ggml_opt_epoch_callback callback_eval) {
+ ctx->opt_epoch(
+ dataset,
+ result_train,
+ result_eval,
+ idata_split,
+ callback_train,
+ callback_eval);
+ }
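Together with the llama_context::opt_* members added earlier in this diff, these wrappers form the new public training entry points. A minimal sketch of one epoch over a prepared dataset; constructing the dataset is out of scope here, and both lm_ggml_opt_get_default_optimizer_params and the "ggml-opt.h" include path are assumptions about what this package exposes:

    // Sketch: one training epoch with a 90/10 train/eval split over an existing dataset.
    #include "llama.h"
    #include "ggml-opt.h"

    static void train_one_epoch(llama_context * ctx, llama_model * model, lm_ggml_opt_dataset_t dataset) {
        llama_opt_params lopt_params = {};
        lopt_params.n_ctx_train     = 0;                          // 0 -> use the context's n_ctx
        lopt_params.param_filter    = llama_opt_param_filter_all; // train every F32 tensor
        lopt_params.param_filter_ud = nullptr;
        lopt_params.get_opt_pars    = lm_ggml_opt_get_default_optimizer_params; // assumed stock getter
        lopt_params.get_opt_pars_ud = nullptr;
        llama_opt_init(ctx, model, lopt_params);

        const int64_t ndata       = lm_ggml_opt_dataset_ndata(dataset);
        const int64_t idata_split = ndata * 9 / 10;               // last 10% evaluated, not trained

        lm_ggml_opt_result_t result_train = lm_ggml_opt_result_init();
        lm_ggml_opt_result_t result_eval  = lm_ggml_opt_result_init();

        // callbacks may be null; opt_epoch_iter() only invokes them when set
        llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
                        /*callback_train =*/ nullptr, /*callback_eval =*/ nullptr);

        lm_ggml_opt_result_free(result_train);
        lm_ggml_opt_result_free(result_eval);
    }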