@novastera-oss/llamarn 0.2.6 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192)
  1. package/android/src/main/cpp/include/llama.h +134 -36
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +2 -2
  11. package/cpp/LlamaCppModel.h +3 -3
  12. package/cpp/PureCppImpl.cpp +1 -1
  13. package/cpp/PureCppImpl.h +2 -2
  14. package/cpp/build-info.cpp +2 -2
  15. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  16. package/cpp/llama.cpp/Makefile +2 -2
  17. package/cpp/llama.cpp/README.md +32 -13
  18. package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
  19. package/cpp/llama.cpp/common/arg.cpp +30 -6
  20. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  21. package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
  22. package/cpp/llama.cpp/common/chat-parser.h +2 -0
  23. package/cpp/llama.cpp/common/chat.cpp +12 -9
  24. package/cpp/llama.cpp/common/chat.h +1 -1
  25. package/cpp/llama.cpp/common/common.cpp +50 -40
  26. package/cpp/llama.cpp/common/common.h +5 -2
  27. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  28. package/cpp/llama.cpp/convert_hf_to_gguf.py +97 -56
  29. package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
  30. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  31. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +47 -13
  32. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  33. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  34. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  35. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +10 -2
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +5 -8
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +6 -8
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  70. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  72. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
  73. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
  74. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
  76. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  84. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -38
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  94. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +431 -247
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  101. package/cpp/llama.cpp/ggml/src/ggml.c +0 -6
  102. package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
  103. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
  104. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
  105. package/cpp/llama.cpp/include/llama.h +134 -36
  106. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  107. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  108. package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
  109. package/cpp/llama.cpp/src/llama-arch.h +7 -1
  110. package/cpp/llama.cpp/src/llama-batch.cpp +270 -19
  111. package/cpp/llama.cpp/src/llama-batch.h +36 -11
  112. package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
  113. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  114. package/cpp/llama.cpp/src/llama-context.cpp +313 -213
  115. package/cpp/llama.cpp/src/llama-context.h +16 -12
  116. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  117. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  118. package/cpp/llama.cpp/src/llama-graph.cpp +249 -129
  119. package/cpp/llama.cpp/src/llama-graph.h +90 -34
  120. package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
  121. package/cpp/llama.cpp/src/llama-hparams.h +8 -2
  122. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +82 -50
  123. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  124. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +292 -174
  125. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +68 -38
  126. package/cpp/llama.cpp/src/llama-kv-cells.h +18 -13
  127. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
  128. package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
  129. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +266 -282
  130. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +54 -57
  131. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  132. package/cpp/llama.cpp/src/llama-memory.h +64 -23
  133. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  134. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  135. package/cpp/llama.cpp/src/llama-model.cpp +726 -141
  136. package/cpp/llama.cpp/src/llama-model.h +4 -0
  137. package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
  138. package/cpp/llama.cpp/src/llama-vocab.cpp +32 -23
  139. package/cpp/llama.cpp/src/llama.cpp +11 -7
  140. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  141. package/cpp/rn-completion.cpp +2 -2
  142. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  143. package/ios/include/chat.h +1 -1
  144. package/ios/include/common.h +5 -2
  145. package/ios/include/llama.h +134 -36
  146. package/ios/libs/llama.xcframework/Info.plist +18 -18
  147. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  148. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
  149. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +134 -36
  150. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  151. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  152. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  153. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
  154. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  155. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  156. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  157. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  158. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3624
  159. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +134 -36
  160. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +134 -36
  161. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  162. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +134 -36
  163. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  164. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  165. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  166. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
  167. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +134 -36
  168. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  169. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  170. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  171. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
  172. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  173. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  174. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  175. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4725
  176. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +134 -36
  177. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  178. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  179. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4746
  180. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3652
  181. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  182. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  183. package/package.json +1 -2
  184. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  185. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  186. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  187. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
  188. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
  189. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  190. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  191. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  192. /package/cpp/{rn-utils.hpp → rn-utils.h} +0 -0
@@ -1,8 +1,14 @@
1
1
  #include "llama-batch.h"
2
2
 
3
+ #include "llama-impl.h"
4
+ #include "llama-cparams.h"
5
+ #include "llama-vocab.h"
6
+ #include "llama-memory.h"
7
+
3
8
  #include <cassert>
4
9
  #include <cstring>
5
10
  #include <algorithm>
11
+ #include <sstream>
6
12
 
7
13
  llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
8
14
  // clear empty sequences
@@ -105,12 +111,7 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
105
111
  ubatch.seq_id = batch->seq_id + seq.offset;
106
112
  }
107
113
  }
108
- if (logits_all) {
109
- for (size_t i = 0; i < length; ++i) {
110
- ubatch.output[ubatch.n_tokens + i] = 1;
111
- out_ids.push_back(ids[seq.offset + i]);
112
- }
113
- } else if (batch->logits) {
114
+ if (batch->logits) {
114
115
  if (ubatch.equal_seqs) {
115
116
  for (size_t i = 0; i < length; ++i) {
116
117
  size_t id = ids[seq.offset + i];
@@ -197,11 +198,10 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
197
198
  return ubatch;
198
199
  }
199
200
 
200
- llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
201
+ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
201
202
  GGML_ASSERT(batch.n_tokens >= 0);
202
203
  this->batch = &batch;
203
204
  this->n_embd = n_embd;
204
- this->logits_all = logits_all;
205
205
 
206
206
  n_tokens = batch.n_tokens;
207
207
  ids.resize(n_tokens);
@@ -285,17 +285,56 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
285
285
  );
286
286
  }
287
287
 
288
- llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
289
- batch = in_batch;
288
+ llama_batch_allocr::llama_batch_allocr() {
289
+ const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
290
+ debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
291
+
292
+ seq_pos.resize(LLAMA_MAX_SEQ);
293
+ seq_cpl.resize(LLAMA_MAX_SEQ);
294
+ for (auto & cur : seq_cpl) {
295
+ cur.resize(LLAMA_MAX_SEQ);
296
+ }
297
+ }
298
+
299
+ bool llama_batch_allocr::init(
300
+ const llama_batch & batch_inp,
301
+ const llama_vocab & vocab,
302
+ const llama_memory_i * memory,
303
+ bool embd_all) {
304
+ clear();
305
+
306
+ batch = batch_inp;
307
+
290
308
  GGML_ASSERT(batch.n_tokens > 0);
291
- if (!batch.pos) {
292
- assert(p0 >= 0);
293
- pos.resize(batch.n_tokens);
294
- for (int32_t i = 0; i < batch.n_tokens; i++) {
295
- pos[i] = p0 + i;
309
+
310
+ //
311
+ // validate input batch
312
+ //
313
+
314
+ if (batch.token) {
315
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
316
+ if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
317
+ LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
318
+ return false;
319
+ }
320
+ }
321
+ }
322
+
323
+ if (batch.seq_id) {
324
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
325
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
326
+ if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
327
+ LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
328
+ return false;
329
+ }
330
+ }
296
331
  }
297
- batch.pos = pos.data();
298
332
  }
333
+
334
+ //
335
+ // auto-generate missing fields
336
+ //
337
+
299
338
  if (!batch.n_seq_id) {
300
339
  n_seq_id.resize(batch.n_tokens);
301
340
  for (int32_t i = 0; i < batch.n_tokens; i++) {
@@ -303,6 +342,7 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
303
342
  }
304
343
  batch.n_seq_id = n_seq_id.data();
305
344
  }
345
+
306
346
  if (!batch.seq_id) {
307
347
  seq_id.resize(batch.n_tokens + 1);
308
348
  seq_id[batch.n_tokens] = NULL;
@@ -311,10 +351,221 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
311
351
  }
312
352
  batch.seq_id = seq_id.data();
313
353
  }
354
+
355
+ if (!batch.pos) {
356
+ pos.resize(batch.n_tokens);
357
+
358
+ // initialize the starting position for each sequence based on the positions in the memory
359
+ llama_pos p0[LLAMA_MAX_SEQ];
360
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
361
+ if (!memory) {
362
+ p0[s] = 0;
363
+ } else {
364
+ p0[s] = memory->seq_pos_max(s) + 1;
365
+ }
366
+ }
367
+
368
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
369
+ const llama_seq_id seq_id = batch.seq_id[i][0];
370
+
371
+ pos[i] = p0[seq_id];
372
+
373
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
374
+ p0[batch.seq_id[i][s]] = pos[i] + 1;
375
+ }
376
+ }
377
+
378
+ batch.pos = pos.data();
379
+ }
380
+
314
381
  if (!batch.logits) {
315
- logits.resize(batch.n_tokens);
316
- logits[logits.size() - 1] = true;
317
- batch.logits = logits.data();
382
+ if (embd_all) {
383
+ // return the output for all tokens
384
+ output.resize(batch.n_tokens, true);
385
+ } else {
386
+ // return the output only for the last token
387
+ output.resize(batch.n_tokens, false);
388
+ output[output.size() - 1] = true;
389
+ }
390
+
391
+ batch.logits = output.data();
392
+ } else if (embd_all) {
393
+ bool warn = false;
394
+
395
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
396
+ if (batch.logits[i] == 0) {
397
+ warn = true;
398
+ }
399
+ }
400
+
401
+ if (warn) {
402
+ LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
403
+
404
+ output.resize(batch.n_tokens, true);
405
+ batch.logits = output.data();
406
+ }
407
+ }
408
+
409
+ //
410
+ // compute stats
411
+ //
412
+
413
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
414
+ n_outputs += batch.logits[i] != 0;
415
+ }
416
+
417
+ // determine coupled sequences
418
+ // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
419
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
420
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
421
+ seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]);
422
+
423
+ if (s > 0) {
424
+ const llama_seq_id s0 = batch.seq_id[i][0];
425
+ const llama_seq_id s1 = batch.seq_id[i][s];
426
+
427
+ // mark that sequence s1 is coupled to s0
428
+ seq_cpl[s1][s0] = true;
429
+
430
+ // note: the other way around is not necessary for now
431
+ //seq_cpl[s0][s1] = true;
432
+ }
433
+ }
434
+ }
435
+
436
+ if (debug > 0) {
437
+ LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
438
+ LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
439
+ LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) batch.token);
440
+ LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) batch.embd);
441
+ LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) batch.pos);
442
+ LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) batch.n_seq_id);
443
+ LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) batch.seq_id);
444
+ LLAMA_LOG_DEBUG("%s: logits = %p\n", __func__, (void *) batch.logits);
445
+ LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
446
+
447
+ if (debug > 1) {
448
+ int seq_id_max = 0;
449
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
450
+ for (int s = 0; s < batch.n_seq_id[i]; ++s) {
451
+ for (int s = 0; s < batch.n_seq_id[i]; ++s) {
452
+ seq_id_max = std::max(seq_id_max, batch.seq_id[i][s]);
453
+ }
454
+ }
455
+ }
456
+ ++seq_id_max;
457
+
458
+ LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
459
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
460
+ std::vector<int8_t> seq_id(seq_id_max);
461
+
462
+ for (int s = 0; s < batch.n_seq_id[i]; ++s) {
463
+ seq_id[batch.seq_id[i][s]] = 1;
464
+ }
465
+
466
+ std::stringstream ss;
467
+ for (int s = 0; s < seq_id_max; ++s) {
468
+ if (seq_id[s]) {
469
+ ss << s%10;
470
+ } else {
471
+ ss << ".";
472
+ }
473
+ }
474
+
475
+ LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
476
+ __func__, i, batch.token[i], vocab.token_to_piece(batch.token[i]).c_str(),
477
+ batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
478
+ }
479
+ LLAMA_LOG_DEBUG("%s: ]\n", __func__);
480
+
481
+ LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
482
+ for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
483
+ if (seq_pos[s0].empty()) {
484
+ continue;
485
+ }
486
+
487
+ std::stringstream ss;
488
+ for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
489
+ if (seq_cpl[s0][s1]) {
490
+ ss << s1 << " ";
491
+ }
492
+ }
493
+
494
+ LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
495
+ __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
496
+ }
497
+ LLAMA_LOG_DEBUG("%s: ]\n", __func__);
498
+ }
499
+ }
500
+
501
+ //
502
+ // consistency checks
503
+ //
504
+
505
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
506
+ if (seq_pos[s].empty()) {
507
+ continue;
508
+ }
509
+
510
+ if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
511
+ LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
512
+ return false;
513
+ }
514
+
515
+ if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
516
+ LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
517
+ return false;
518
+ }
519
+ }
520
+
521
+ if (memory) {
522
+ for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
523
+ for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
524
+ if (seq_cpl[s0][s1]) {
525
+ if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
526
+ memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
527
+ LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1);
528
+ return false;
529
+ }
530
+ }
531
+ }
532
+ }
533
+ }
534
+
535
+ return true;
536
+ }
537
+
538
+ const llama_batch & llama_batch_allocr::get_batch() const {
539
+ return batch;
540
+ }
541
+
542
+ uint32_t llama_batch_allocr::get_n_outputs() const {
543
+ return n_outputs;
544
+ }
545
+
546
+ llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
547
+ return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
548
+ }
549
+
550
+ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
551
+ return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
552
+ }
553
+
554
+ void llama_batch_allocr::clear() {
555
+ n_outputs = 0;
556
+
557
+ batch = {};
558
+ pos.clear();
559
+ n_seq_id.clear();
560
+ seq_id.clear();
561
+ output.clear();
562
+
563
+ for (auto & cur : seq_pos) {
564
+ cur.clear();
565
+ }
566
+
567
+ for (auto & cur : seq_cpl) {
568
+ std::fill(cur.begin(), cur.end(), false);
318
569
  }
319
570
  }
320
571
 
@@ -4,6 +4,7 @@
4
4
 
5
5
  #include <array>
6
6
  #include <vector>
7
+ #include <set>
7
8
 
8
9
  // very similar to llama_batch,
9
10
  // but has more metadata about sequences
@@ -18,8 +19,8 @@ struct llama_ubatch {
18
19
  llama_token * token; // [n_tokens]
19
20
  float * embd; // [n_embd, n_tokens]
20
21
  llama_pos * pos; // [n_tokens]
21
- int32_t * n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
22
- llama_seq_id ** seq_id; // [n_seqs] // TODO: become llama_seq_id * seq_id;
22
+ int32_t * n_seq_id; // [n_seqs]
23
+ llama_seq_id ** seq_id; // [n_seqs]
23
24
  int8_t * output; // [n_tokens]
24
25
  };
25
26
 
@@ -39,8 +40,6 @@ struct llama_sbatch {
39
40
 
40
41
  size_t n_embd;
41
42
 
42
- bool logits_all; // TODO: remove once lctx.logits_all is removed too
43
-
44
43
  // sorted indices into the batch
45
44
  std::vector<int64_t> ids;
46
45
  // batch indices of the output
@@ -76,19 +75,45 @@ struct llama_sbatch {
76
75
  llama_ubatch split_seq(size_t n_ubatch);
77
76
 
78
77
  llama_sbatch() = default;
79
- llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
78
+ llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
80
79
  };
81
80
 
82
- // temporary allocate memory for the input batch if needed
83
- struct llama_batch_allocr {
84
- struct llama_batch batch;
81
+ // a helper for sanitizing and fulfilling a batch
82
+ class llama_batch_allocr {
83
+ public:
84
+ llama_batch_allocr();
85
+
86
+ // sanitize and auto-gen missing data in the input batch
87
+ // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
88
+ bool init(
89
+ const llama_batch & batch_inp,
90
+ const llama_vocab & vocab,
91
+ const llama_memory_i * memory,
92
+ bool embd_all);
93
+
94
+ const llama_batch & get_batch() const;
95
+
96
+ uint32_t get_n_outputs() const;
97
+
98
+ llama_pos seq_pos_min(llama_seq_id seq_id) const;
99
+ llama_pos seq_pos_max(llama_seq_id seq_id) const;
100
+
101
+ private:
102
+ void clear();
103
+
104
+ llama_batch batch;
105
+
106
+ uint32_t n_outputs;
85
107
 
86
108
  std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
109
+
87
110
  std::vector<llama_pos> pos;
88
111
  std::vector<int32_t> n_seq_id;
89
112
  std::vector<llama_seq_id *> seq_id;
90
- std::vector<int8_t> logits;
113
+ std::vector<int8_t> output;
114
+
115
+ std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
116
+ std::vector<std::vector<bool>> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
91
117
 
92
- // optionally fulfill the batch returned by llama_batch_get_one
93
- llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
118
+ int debug;
94
119
  };
@@ -183,6 +183,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
183
183
  return LLM_CHAT_TEMPLATE_BAILING;
184
184
  } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
185
185
  return LLM_CHAT_TEMPLATE_LLAMA4;
186
+ } else if (tmpl_contains("<|endofuserprompt|>")) {
187
+ return LLM_CHAT_TEMPLATE_DOTS1;
186
188
  }
187
189
  return LLM_CHAT_TEMPLATE_UNKNOWN;
188
190
  }
@@ -331,7 +333,7 @@ int32_t llm_chat_apply_template(
331
333
  std::string role(message->role);
332
334
  if (role == "system") {
333
335
  // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken
334
- system_prompt = trim(message->content);
336
+ system_prompt += trim(message->content);
335
337
  continue;
336
338
  }
337
339
  // in gemma, "assistant" is "model"
@@ -353,7 +355,7 @@ int32_t llm_chat_apply_template(
353
355
  std::string role(message->role);
354
356
  if (role == "system") {
355
357
  // there is no system message support, we will merge it with user prompt
356
- system_prompt = message->content;
358
+ system_prompt += message->content;
357
359
  continue;
358
360
  } else if (role == "user") {
359
361
  ss << "Human: ";
@@ -643,6 +645,21 @@ int32_t llm_chat_apply_template(
643
645
  if (add_ass) {
644
646
  ss << "Assistant:";
645
647
  }
648
+ } else if (tmpl == LLM_CHAT_TEMPLATE_DOTS1) {
649
+ // dots.llm1.inst (DOTS1)
650
+ for (auto message : chat) {
651
+ std::string role(message->role);
652
+ if (role == "system") {
653
+ ss << "<|system|>" << message->content << "<|endofsystem|>";
654
+ } else if (role == "user") {
655
+ ss << "<|userprompt|>" << message->content << "<|endofuserprompt|>";
656
+ } else {
657
+ ss << "<|response|>" << message->content << "<|endofresponse|>";
658
+ }
659
+ }
660
+ if (add_ass) {
661
+ ss << "<|response|>";
662
+ }
646
663
  } else {
647
664
  // template not supported
648
665
  return -1;
@@ -43,6 +43,7 @@ enum llm_chat_template {
43
43
  LLM_CHAT_TEMPLATE_BAILING,
44
44
  LLM_CHAT_TEMPLATE_LLAMA4,
45
45
  LLM_CHAT_TEMPLATE_SMOLVLM,
46
+ LLM_CHAT_TEMPLATE_DOTS1,
46
47
  LLM_CHAT_TEMPLATE_UNKNOWN,
47
48
  };
48
49