@novastera-oss/llamarn 0.2.5 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (225)
  1. package/RNLlamaCpp.podspec +3 -2
  2. package/android/CMakeLists.txt +6 -3
  3. package/android/src/main/cpp/include/llama.h +140 -38
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  12. package/cpp/LlamaCppModel.cpp +48 -67
  13. package/cpp/LlamaCppModel.h +8 -3
  14. package/cpp/PureCppImpl.cpp +1 -1
  15. package/cpp/PureCppImpl.h +2 -2
  16. package/cpp/build-info.cpp +2 -2
  17. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  18. package/cpp/llama.cpp/Makefile +2 -2
  19. package/cpp/llama.cpp/README.md +33 -13
  20. package/cpp/llama.cpp/common/CMakeLists.txt +15 -28
  21. package/cpp/llama.cpp/common/arg.cpp +38 -12
  22. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  23. package/cpp/llama.cpp/common/chat-parser.cpp +9 -3
  24. package/cpp/llama.cpp/common/chat-parser.h +4 -1
  25. package/cpp/llama.cpp/common/chat.cpp +16 -13
  26. package/cpp/llama.cpp/common/chat.h +1 -1
  27. package/cpp/llama.cpp/common/common.cpp +52 -40
  28. package/cpp/llama.cpp/common/common.h +5 -2
  29. package/cpp/llama.cpp/common/json-partial.cpp +5 -4
  30. package/cpp/llama.cpp/common/json-partial.h +2 -1
  31. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
  32. package/cpp/llama.cpp/common/json-schema-to-grammar.h +4 -4
  33. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  34. package/cpp/llama.cpp/convert_hf_to_gguf.py +128 -84
  35. package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
  36. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  37. package/cpp/llama.cpp/ggml/include/ggml.h +1 -3
  38. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +49 -13
  39. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +10 -5
  41. package/cpp/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  42. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  44. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +33 -2
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
  74. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +6 -8
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +25 -16
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  79. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-impl.h +2 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  82. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
  83. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
  86. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  88. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  93. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  94. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  96. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
  97. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
  98. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
  99. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
  102. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -46
  103. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  104. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  105. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +118 -11
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  107. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
  108. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -248
  109. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
  110. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  112. package/cpp/llama.cpp/ggml/src/ggml.c +9 -8
  113. package/cpp/llama.cpp/ggml/src/ggml.cpp +26 -0
  114. package/cpp/llama.cpp/ggml/src/gguf.cpp +19 -2
  115. package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
  116. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
  117. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
  118. package/cpp/llama.cpp/include/llama.h +140 -38
  119. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  120. package/cpp/llama.cpp/src/CMakeLists.txt +4 -1
  121. package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
  122. package/cpp/llama.cpp/src/llama-arch.h +7 -1
  123. package/cpp/llama.cpp/src/llama-batch.cpp +289 -31
  124. package/cpp/llama.cpp/src/llama-batch.h +47 -17
  125. package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
  126. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  127. package/cpp/llama.cpp/src/llama-context.cpp +488 -313
  128. package/cpp/llama.cpp/src/llama-context.h +38 -17
  129. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  130. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  131. package/cpp/llama.cpp/src/llama-graph.cpp +275 -152
  132. package/cpp/llama.cpp/src/llama-graph.h +109 -52
  133. package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
  134. package/cpp/llama.cpp/src/llama-hparams.h +8 -2
  135. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +281 -0
  136. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +133 -0
  137. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +1835 -0
  138. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +308 -0
  139. package/cpp/llama.cpp/src/llama-kv-cells.h +53 -17
  140. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
  141. package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
  142. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +1116 -0
  143. package/cpp/llama.cpp/src/llama-memory-recurrent.h +188 -0
  144. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  145. package/cpp/llama.cpp/src/llama-memory.h +89 -4
  146. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  147. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  148. package/cpp/llama.cpp/src/llama-model.cpp +735 -143
  149. package/cpp/llama.cpp/src/llama-model.h +4 -0
  150. package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
  151. package/cpp/llama.cpp/src/llama-vocab.cpp +39 -25
  152. package/cpp/llama.cpp/src/llama.cpp +11 -7
  153. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  154. package/cpp/llama.cpp/vendor/cpp-httplib/httplib.h +10518 -0
  155. package/cpp/llama.cpp/vendor/miniaudio/miniaudio.h +93468 -0
  156. package/cpp/llama.cpp/{common → vendor}/minja/chat-template.hpp +1 -1
  157. package/cpp/llama.cpp/{common → vendor}/minja/minja.hpp +1 -1
  158. package/cpp/llama.cpp/{common → vendor/nlohmann}/json.hpp +3027 -2267
  159. package/cpp/llama.cpp/vendor/nlohmann/json_fwd.hpp +187 -0
  160. package/cpp/llama.cpp/vendor/stb/stb_image.h +7988 -0
  161. package/cpp/rn-completion.cpp +65 -10
  162. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  163. package/cpp/{rn-utils.hpp → rn-utils.h} +8 -1
  164. package/ios/include/chat.h +1 -1
  165. package/ios/include/common/minja/chat-template.hpp +1 -1
  166. package/ios/include/common/minja/minja.hpp +1 -1
  167. package/ios/include/common.h +5 -2
  168. package/ios/include/json-schema-to-grammar.h +4 -4
  169. package/ios/include/llama.h +140 -38
  170. package/ios/include/{common → nlohmann}/json.hpp +3027 -2267
  171. package/ios/libs/llama.xcframework/Info.plist +20 -20
  172. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  173. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4617
  174. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +1 -3
  175. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +140 -38
  176. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  177. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  178. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4638
  179. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3557
  180. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  181. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
  182. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  183. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  184. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4638
  185. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3559
  186. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +1 -3
  187. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +140 -38
  188. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +1 -3
  189. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +140 -38
  190. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  191. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +1 -3
  192. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +140 -38
  193. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  194. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  195. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  196. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4616
  197. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +1 -3
  198. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +140 -38
  199. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  200. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  201. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4637
  202. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3556
  203. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  204. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
  205. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  206. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  207. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4653
  208. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +1 -3
  209. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +140 -38
  210. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  211. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  212. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4674
  213. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3587
  214. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  215. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
  216. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  217. package/package.json +1 -2
  218. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  219. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  220. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  221. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -2747
  222. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -502
  223. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  224. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  225. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
@@ -1,8 +1,14 @@
1
1
  #include "llama-batch.h"
2
2
 
3
+ #include "llama-impl.h"
4
+ #include "llama-cparams.h"
5
+ #include "llama-vocab.h"
6
+ #include "llama-memory.h"
7
+
3
8
  #include <cassert>
4
9
  #include <cstring>
5
10
  #include <algorithm>
11
+ #include <sstream>
6
12
 
7
13
  llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
8
14
  // clear empty sequences
@@ -15,24 +21,31 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
15
21
  break;
16
22
  }
17
23
  }
18
- ubatch_token.resize(!has_embd ? n_ubatch : 0);
19
- ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
20
- ubatch_pos.resize(n_ubatch);
21
- ubatch_n_seq_id.resize(n_ubatch);
22
- ubatch_seq_id.resize(n_ubatch);
23
- ubatch_output.resize(n_ubatch);
24
+
25
+ udatas.push_back({});
26
+
27
+ auto & udata = udatas.back();
28
+
29
+ udata.token.resize(!has_embd ? n_ubatch : 0);
30
+ udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
31
+ udata.pos.resize(n_ubatch);
32
+ udata.n_seq_id.resize(n_ubatch);
33
+ udata.seq_id.resize(n_ubatch);
34
+ udata.output.resize(n_ubatch);
35
+
24
36
  llama_ubatch ubatch = {
25
37
  /*equal_seqs =*/ true,
26
38
  /*n_tokens =*/ 0,
27
39
  /*n_seq_tokens =*/ 0,
28
40
  /*n_seqs =*/ 0,
29
- /*token =*/ !has_embd ? ubatch_token.data() : nullptr,
30
- /*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
31
- /*pos =*/ ubatch_pos.data(),
32
- /*n_seq_id =*/ ubatch_n_seq_id.data(),
33
- /*seq_id =*/ ubatch_seq_id.data(),
34
- /*output =*/ ubatch_output.data(),
41
+ /*token =*/ !has_embd ? udata.token.data() : nullptr,
42
+ /*embd =*/ has_embd ? udata.embd.data() : nullptr,
43
+ /*pos =*/ udata.pos.data(),
44
+ /*n_seq_id =*/ udata.n_seq_id.data(),
45
+ /*seq_id =*/ udata.seq_id.data(),
46
+ /*output =*/ udata.output.data(),
35
47
  };
48
+
36
49
  return ubatch;
37
50
  }
38
51
 
@@ -98,12 +111,7 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
98
111
  ubatch.seq_id = batch->seq_id + seq.offset;
99
112
  }
100
113
  }
101
- if (logits_all) {
102
- for (size_t i = 0; i < length; ++i) {
103
- ubatch.output[ubatch.n_tokens + i] = 1;
104
- out_ids.push_back(ids[seq.offset + i]);
105
- }
106
- } else if (batch->logits) {
114
+ if (batch->logits) {
107
115
  if (ubatch.equal_seqs) {
108
116
  for (size_t i = 0; i < length; ++i) {
109
117
  size_t id = ids[seq.offset + i];
@@ -190,11 +198,10 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
190
198
  return ubatch;
191
199
  }
192
200
 
193
- llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
201
+ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
194
202
  GGML_ASSERT(batch.n_tokens >= 0);
195
203
  this->batch = &batch;
196
204
  this->n_embd = n_embd;
197
- this->logits_all = logits_all;
198
205
 
199
206
  n_tokens = batch.n_tokens;
200
207
  ids.resize(n_tokens);
@@ -278,17 +285,56 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
278
285
  );
279
286
  }
280
287
 
281
- llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
282
- batch = in_batch;
288
+ llama_batch_allocr::llama_batch_allocr() {
289
+ const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
290
+ debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
291
+
292
+ seq_pos.resize(LLAMA_MAX_SEQ);
293
+ seq_cpl.resize(LLAMA_MAX_SEQ);
294
+ for (auto & cur : seq_cpl) {
295
+ cur.resize(LLAMA_MAX_SEQ);
296
+ }
297
+ }
298
+
299
+ bool llama_batch_allocr::init(
300
+ const llama_batch & batch_inp,
301
+ const llama_vocab & vocab,
302
+ const llama_memory_i * memory,
303
+ bool embd_all) {
304
+ clear();
305
+
306
+ batch = batch_inp;
307
+
283
308
  GGML_ASSERT(batch.n_tokens > 0);
284
- if (!batch.pos) {
285
- assert(p0 >= 0);
286
- pos.resize(batch.n_tokens);
287
- for (int32_t i = 0; i < batch.n_tokens; i++) {
288
- pos[i] = p0 + i;
309
+
310
+ //
311
+ // validate input batch
312
+ //
313
+
314
+ if (batch.token) {
315
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
316
+ if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
317
+ LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
318
+ return false;
319
+ }
289
320
  }
290
- batch.pos = pos.data();
291
321
  }
322
+
323
+ if (batch.seq_id) {
324
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
325
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
326
+ if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
327
+ LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
328
+ return false;
329
+ }
330
+ }
331
+ }
332
+ }
333
+
334
+ //
335
+ // auto-generate missing fields
336
+ //
337
+
292
338
  if (!batch.n_seq_id) {
293
339
  n_seq_id.resize(batch.n_tokens);
294
340
  for (int32_t i = 0; i < batch.n_tokens; i++) {
@@ -296,6 +342,7 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
296
342
  }
297
343
  batch.n_seq_id = n_seq_id.data();
298
344
  }
345
+
299
346
  if (!batch.seq_id) {
300
347
  seq_id.resize(batch.n_tokens + 1);
301
348
  seq_id[batch.n_tokens] = NULL;
@@ -304,10 +351,221 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
304
351
  }
305
352
  batch.seq_id = seq_id.data();
306
353
  }
354
+
355
+ if (!batch.pos) {
356
+ pos.resize(batch.n_tokens);
357
+
358
+ // initialize the starting position for each sequence based on the positions in the memory
359
+ llama_pos p0[LLAMA_MAX_SEQ];
360
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
361
+ if (!memory) {
362
+ p0[s] = 0;
363
+ } else {
364
+ p0[s] = memory->seq_pos_max(s) + 1;
365
+ }
366
+ }
367
+
368
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
369
+ const llama_seq_id seq_id = batch.seq_id[i][0];
370
+
371
+ pos[i] = p0[seq_id];
372
+
373
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
374
+ p0[batch.seq_id[i][s]] = pos[i] + 1;
375
+ }
376
+ }
377
+
378
+ batch.pos = pos.data();
379
+ }
380
+
307
381
  if (!batch.logits) {
308
- logits.resize(batch.n_tokens);
309
- logits[logits.size() - 1] = true;
310
- batch.logits = logits.data();
382
+ if (embd_all) {
383
+ // return the output for all tokens
384
+ output.resize(batch.n_tokens, true);
385
+ } else {
386
+ // return the output only for the last token
387
+ output.resize(batch.n_tokens, false);
388
+ output[output.size() - 1] = true;
389
+ }
390
+
391
+ batch.logits = output.data();
392
+ } else if (embd_all) {
393
+ bool warn = false;
394
+
395
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
396
+ if (batch.logits[i] == 0) {
397
+ warn = true;
398
+ }
399
+ }
400
+
401
+ if (warn) {
402
+ LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
403
+
404
+ output.resize(batch.n_tokens, true);
405
+ batch.logits = output.data();
406
+ }
407
+ }
408
+
409
+ //
410
+ // compute stats
411
+ //
412
+
413
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
414
+ n_outputs += batch.logits[i] != 0;
415
+ }
416
+
417
+ // determine coupled sequences
418
+ // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
419
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
420
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
421
+ seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]);
422
+
423
+ if (s > 0) {
424
+ const llama_seq_id s0 = batch.seq_id[i][0];
425
+ const llama_seq_id s1 = batch.seq_id[i][s];
426
+
427
+ // mark that sequence s1 is coupled to s0
428
+ seq_cpl[s1][s0] = true;
429
+
430
+ // note: the other way around is not necessary for now
431
+ //seq_cpl[s0][s1] = true;
432
+ }
433
+ }
434
+ }
435
+
436
+ if (debug > 0) {
437
+ LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
438
+ LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
439
+ LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) batch.token);
440
+ LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) batch.embd);
441
+ LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) batch.pos);
442
+ LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) batch.n_seq_id);
443
+ LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) batch.seq_id);
444
+ LLAMA_LOG_DEBUG("%s: logits = %p\n", __func__, (void *) batch.logits);
445
+ LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
446
+
447
+ if (debug > 1) {
448
+ int seq_id_max = 0;
449
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
450
+ for (int s = 0; s < batch.n_seq_id[i]; ++s) {
451
+ for (int s = 0; s < batch.n_seq_id[i]; ++s) {
452
+ seq_id_max = std::max(seq_id_max, batch.seq_id[i][s]);
453
+ }
454
+ }
455
+ }
456
+ ++seq_id_max;
457
+
458
+ LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
459
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
460
+ std::vector<int8_t> seq_id(seq_id_max);
461
+
462
+ for (int s = 0; s < batch.n_seq_id[i]; ++s) {
463
+ seq_id[batch.seq_id[i][s]] = 1;
464
+ }
465
+
466
+ std::stringstream ss;
467
+ for (int s = 0; s < seq_id_max; ++s) {
468
+ if (seq_id[s]) {
469
+ ss << s%10;
470
+ } else {
471
+ ss << ".";
472
+ }
473
+ }
474
+
475
+ LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
476
+ __func__, i, batch.token[i], vocab.token_to_piece(batch.token[i]).c_str(),
477
+ batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
478
+ }
479
+ LLAMA_LOG_DEBUG("%s: ]\n", __func__);
480
+
481
+ LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
482
+ for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
483
+ if (seq_pos[s0].empty()) {
484
+ continue;
485
+ }
486
+
487
+ std::stringstream ss;
488
+ for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
489
+ if (seq_cpl[s0][s1]) {
490
+ ss << s1 << " ";
491
+ }
492
+ }
493
+
494
+ LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
495
+ __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
496
+ }
497
+ LLAMA_LOG_DEBUG("%s: ]\n", __func__);
498
+ }
499
+ }
500
+
501
+ //
502
+ // consistency checks
503
+ //
504
+
505
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
506
+ if (seq_pos[s].empty()) {
507
+ continue;
508
+ }
509
+
510
+ if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
511
+ LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
512
+ return false;
513
+ }
514
+
515
+ if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
516
+ LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
517
+ return false;
518
+ }
519
+ }
520
+
521
+ if (memory) {
522
+ for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
523
+ for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
524
+ if (seq_cpl[s0][s1]) {
525
+ if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
526
+ memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
527
+ LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1);
528
+ return false;
529
+ }
530
+ }
531
+ }
532
+ }
533
+ }
534
+
535
+ return true;
536
+ }
537
+
538
+ const llama_batch & llama_batch_allocr::get_batch() const {
539
+ return batch;
540
+ }
541
+
542
+ uint32_t llama_batch_allocr::get_n_outputs() const {
543
+ return n_outputs;
544
+ }
545
+
546
+ llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
547
+ return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
548
+ }
549
+
550
+ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
551
+ return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
552
+ }
553
+
554
+ void llama_batch_allocr::clear() {
555
+ n_outputs = 0;
556
+
557
+ batch = {};
558
+ pos.clear();
559
+ n_seq_id.clear();
560
+ seq_id.clear();
561
+ output.clear();
562
+
563
+ for (auto & cur : seq_pos) {
564
+ cur.clear();
565
+ }
566
+
567
+ for (auto & cur : seq_cpl) {
568
+ std::fill(cur.begin(), cur.end(), false);
311
569
  }
312
570
  }
313
571
 
@@ -4,6 +4,7 @@
4
4
 
5
5
  #include <array>
6
6
  #include <vector>
7
+ #include <set>
7
8
 
8
9
  // very similar to llama_batch,
9
10
  // but has more metadata about sequences
@@ -11,7 +12,7 @@ struct llama_ubatch {
11
12
  bool equal_seqs;
12
13
  // TODO: whole_seqs for embeddings?
13
14
 
14
- uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
15
+ uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
15
16
  uint32_t n_seq_tokens; // tokens per sequence
16
17
  uint32_t n_seqs;
17
18
 
@@ -39,8 +40,6 @@ struct llama_sbatch {
39
40
 
40
41
  size_t n_embd;
41
42
 
42
- bool logits_all; // TODO: remove once lctx.logits_all is removed too
43
-
44
43
  // sorted indices into the batch
45
44
  std::vector<int64_t> ids;
46
45
  // batch indices of the output
@@ -49,13 +48,18 @@ struct llama_sbatch {
49
48
 
50
49
  const llama_batch * batch = nullptr;
51
50
 
52
- // buffers for the ubatch
53
- std::vector<llama_token> ubatch_token;
54
- std::vector<float> ubatch_embd;
55
- std::vector<llama_pos> ubatch_pos;
56
- std::vector<int32_t> ubatch_n_seq_id;
57
- std::vector<llama_seq_id *> ubatch_seq_id;
58
- std::vector<int8_t> ubatch_output;
51
+ // buffers for the ubatches
52
+ // TODO: very hacky, this needs a complete rework
53
+ struct ubatch_data {
54
+ std::vector<llama_token> token;
55
+ std::vector<float> embd;
56
+ std::vector<llama_pos> pos;
57
+ std::vector<int32_t> n_seq_id;
58
+ std::vector<llama_seq_id *> seq_id;
59
+ std::vector<int8_t> output;
60
+ };
61
+
62
+ std::vector<ubatch_data> udatas;
59
63
 
60
64
  llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
61
65
 
@@ -71,19 +75,45 @@ struct llama_sbatch {
71
75
  llama_ubatch split_seq(size_t n_ubatch);
72
76
 
73
77
  llama_sbatch() = default;
74
- llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
78
+ llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
75
79
  };
76
80
 
77
- // temporary allocate memory for the input batch if needed
78
- struct llama_batch_allocr {
79
- struct llama_batch batch;
81
+ // a helper for sanitizing and fulfilling a batch
82
+ class llama_batch_allocr {
83
+ public:
84
+ llama_batch_allocr();
85
+
86
+ // sanitize and auto-gen missing data in the input batch
87
+ // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
88
+ bool init(
89
+ const llama_batch & batch_inp,
90
+ const llama_vocab & vocab,
91
+ const llama_memory_i * memory,
92
+ bool embd_all);
93
+
94
+ const llama_batch & get_batch() const;
95
+
96
+ uint32_t get_n_outputs() const;
97
+
98
+ llama_pos seq_pos_min(llama_seq_id seq_id) const;
99
+ llama_pos seq_pos_max(llama_seq_id seq_id) const;
100
+
101
+ private:
102
+ void clear();
103
+
104
+ llama_batch batch;
105
+
106
+ uint32_t n_outputs;
80
107
 
81
108
  std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
109
+
82
110
  std::vector<llama_pos> pos;
83
111
  std::vector<int32_t> n_seq_id;
84
112
  std::vector<llama_seq_id *> seq_id;
85
- std::vector<int8_t> logits;
113
+ std::vector<int8_t> output;
114
+
115
+ std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
116
+ std::vector<std::vector<bool>> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
86
117
 
87
- // optionally fulfill the batch returned by llama_batch_get_one
88
- llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
118
+ int debug;
89
119
  };
@@ -183,6 +183,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
183
183
  return LLM_CHAT_TEMPLATE_BAILING;
184
184
  } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
185
185
  return LLM_CHAT_TEMPLATE_LLAMA4;
186
+ } else if (tmpl_contains("<|endofuserprompt|>")) {
187
+ return LLM_CHAT_TEMPLATE_DOTS1;
186
188
  }
187
189
  return LLM_CHAT_TEMPLATE_UNKNOWN;
188
190
  }
@@ -331,7 +333,7 @@ int32_t llm_chat_apply_template(
331
333
  std::string role(message->role);
332
334
  if (role == "system") {
333
335
  // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken
334
- system_prompt = trim(message->content);
336
+ system_prompt += trim(message->content);
335
337
  continue;
336
338
  }
337
339
  // in gemma, "assistant" is "model"
@@ -353,7 +355,7 @@ int32_t llm_chat_apply_template(
353
355
  std::string role(message->role);
354
356
  if (role == "system") {
355
357
  // there is no system message support, we will merge it with user prompt
356
- system_prompt = message->content;
358
+ system_prompt += message->content;
357
359
  continue;
358
360
  } else if (role == "user") {
359
361
  ss << "Human: ";
@@ -643,6 +645,21 @@ int32_t llm_chat_apply_template(
643
645
  if (add_ass) {
644
646
  ss << "Assistant:";
645
647
  }
648
+ } else if (tmpl == LLM_CHAT_TEMPLATE_DOTS1) {
649
+ // dots.llm1.inst (DOTS1)
650
+ for (auto message : chat) {
651
+ std::string role(message->role);
652
+ if (role == "system") {
653
+ ss << "<|system|>" << message->content << "<|endofsystem|>";
654
+ } else if (role == "user") {
655
+ ss << "<|userprompt|>" << message->content << "<|endofuserprompt|>";
656
+ } else {
657
+ ss << "<|response|>" << message->content << "<|endofresponse|>";
658
+ }
659
+ }
660
+ if (add_ass) {
661
+ ss << "<|response|>";
662
+ }
646
663
  } else {
647
664
  // template not supported
648
665
  return -1;
@@ -43,6 +43,7 @@ enum llm_chat_template {
43
43
  LLM_CHAT_TEMPLATE_BAILING,
44
44
  LLM_CHAT_TEMPLATE_LLAMA4,
45
45
  LLM_CHAT_TEMPLATE_SMOLVLM,
46
+ LLM_CHAT_TEMPLATE_DOTS1,
46
47
  LLM_CHAT_TEMPLATE_UNKNOWN,
47
48
  };
48
49