@novastera-oss/llamarn 0.2.7 → 0.2.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186) hide show
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +56 -22
  11. package/cpp/build-info.cpp +2 -2
  12. package/cpp/llama.cpp/CMakeLists.txt +1 -1
  13. package/cpp/llama.cpp/common/arg.cpp +7 -0
  14. package/cpp/llama.cpp/common/common.cpp +3 -0
  15. package/cpp/llama.cpp/common/common.h +1 -0
  16. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  17. package/cpp/llama.cpp/convert_hf_to_gguf.py +118 -20
  18. package/cpp/llama.cpp/ggml/CMakeLists.txt +1 -0
  19. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  20. package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
  21. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -0
  22. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  23. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +31 -2
  24. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  25. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  26. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  27. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  28. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  29. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  30. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  31. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  32. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  33. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  34. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  35. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +83 -102
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +192 -67
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +211 -33
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
  48. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +54 -29
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +84 -31
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
  62. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +227 -41
  64. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +362 -182
  65. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  66. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +240 -535
  67. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  68. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  69. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  70. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  71. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  72. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  73. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  74. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  75. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
  76. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  77. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +45 -54
  78. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  79. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  80. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  81. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  82. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  83. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
  84. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  89. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  90. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +57 -1
  91. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  92. package/cpp/llama.cpp/ggml/src/ggml.c +69 -13
  93. package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
  94. package/cpp/llama.cpp/gguf-py/gguf/constants.py +76 -0
  95. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +21 -0
  96. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +64 -0
  97. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
  98. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  99. package/cpp/llama.cpp/include/llama.h +8 -3
  100. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  101. package/cpp/llama.cpp/src/llama-arch.cpp +55 -0
  102. package/cpp/llama.cpp/src/llama-arch.h +18 -0
  103. package/cpp/llama.cpp/src/llama-batch.cpp +570 -359
  104. package/cpp/llama.cpp/src/llama-batch.h +98 -70
  105. package/cpp/llama.cpp/src/llama-chat.cpp +11 -6
  106. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  107. package/cpp/llama.cpp/src/llama-context.h +13 -13
  108. package/cpp/llama.cpp/src/llama-graph.cpp +199 -252
  109. package/cpp/llama.cpp/src/llama-graph.h +44 -32
  110. package/cpp/llama.cpp/src/llama-hparams.cpp +4 -0
  111. package/cpp/llama.cpp/src/llama-hparams.h +8 -0
  112. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +51 -53
  113. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +19 -24
  114. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +110 -104
  115. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +17 -22
  116. package/cpp/llama.cpp/src/llama-kv-cells.h +35 -11
  117. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +66 -67
  118. package/cpp/llama.cpp/src/llama-memory-hybrid.h +16 -21
  119. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +69 -68
  120. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  121. package/cpp/llama.cpp/src/llama-memory.h +18 -22
  122. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  123. package/cpp/llama.cpp/src/llama-model.cpp +1006 -472
  124. package/cpp/llama.cpp/src/llama-model.h +22 -0
  125. package/cpp/llama.cpp/src/llama-quant.cpp +87 -5
  126. package/cpp/llama.cpp/src/llama-vocab.cpp +26 -3
  127. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  128. package/cpp/rn-utils.h +3 -0
  129. package/ios/include/common.h +1 -0
  130. package/ios/include/llama.h +8 -3
  131. package/ios/libs/llama.xcframework/Info.plist +19 -19
  132. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  133. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
  134. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  135. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
  136. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -3
  137. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  138. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  139. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  140. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
  141. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  142. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  143. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  144. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  145. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  146. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  147. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3744
  148. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  149. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
  150. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -3
  151. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  152. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
  153. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -3
  154. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  155. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  156. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
  157. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -3
  158. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  159. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  160. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  161. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
  162. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  163. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
  164. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -3
  165. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  166. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  167. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  168. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
  169. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  170. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  171. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  172. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  173. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  174. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4900
  175. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  176. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
  177. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -3
  178. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  179. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  180. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4871
  181. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3773
  182. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  183. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  184. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  185. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  186. package/package.json +1 -1
@@ -1,7 +1,6 @@
1
1
  #include "llama-batch.h"
2
2
 
3
3
  #include "llama-impl.h"
4
- #include "llama-cparams.h"
5
4
  #include "llama-vocab.h"
6
5
  #include "llama-memory.h"
7
6
 
@@ -10,282 +9,7 @@
10
9
  #include <algorithm>
11
10
  #include <sstream>
12
11
 
13
- llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
14
- // clear empty sequences
15
- // the previous ubatch is assumed to be gone,
16
- // so nothing should refer to values in these sequences anymore.
17
- for (size_t i = seq.size(); i-- > 0;) {
18
- if (seq[i].length == 0) {
19
- seq.pop_back();
20
- } else {
21
- break;
22
- }
23
- }
24
-
25
- udatas.push_back({});
26
-
27
- auto & udata = udatas.back();
28
-
29
- udata.token.resize(!has_embd ? n_ubatch : 0);
30
- udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
31
- udata.pos.resize(n_ubatch);
32
- udata.n_seq_id.resize(n_ubatch);
33
- udata.seq_id.resize(n_ubatch);
34
- udata.output.resize(n_ubatch);
35
-
36
- llama_ubatch ubatch = {
37
- /*equal_seqs =*/ true,
38
- /*n_tokens =*/ 0,
39
- /*n_seq_tokens =*/ 0,
40
- /*n_seqs =*/ 0,
41
- /*token =*/ !has_embd ? udata.token.data() : nullptr,
42
- /*embd =*/ has_embd ? udata.embd.data() : nullptr,
43
- /*pos =*/ udata.pos.data(),
44
- /*n_seq_id =*/ udata.n_seq_id.data(),
45
- /*seq_id =*/ udata.seq_id.data(),
46
- /*output =*/ udata.output.data(),
47
- };
48
-
49
- return ubatch;
50
- }
51
-
52
- void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
53
- GGML_ASSERT(batch != nullptr);
54
- GGML_ASSERT(length <= seq.length);
55
- // Can only add sequences of equal lengths to a batch,
56
- // otherwise it isn't clear to which sequence a token belongs
57
- GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
58
- GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
59
- // NOTE: loops are separated for cache-friendliness
60
- if (batch->token) {
61
- if (ubatch.equal_seqs) {
62
- for (size_t i = 0; i < length; ++i) {
63
- ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
64
- }
65
- } else {
66
- // simple split
67
- ubatch.token = batch->token + seq.offset;
68
- }
69
- } else {
70
- ubatch.token = nullptr;
71
- }
72
- if (batch->embd) {
73
- if (ubatch.equal_seqs) {
74
- for (size_t i = 0; i < length; ++i) {
75
- memcpy(
76
- ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
77
- batch->embd + (n_embd * ids[seq.offset + i]),
78
- n_embd * sizeof(float)
79
- );
80
- }
81
- } else {
82
- // simple split
83
- ubatch.embd = batch->embd + (n_embd * seq.offset);
84
- }
85
- } else {
86
- ubatch.embd = nullptr;
87
- }
88
- if (ubatch.equal_seqs) {
89
- for (size_t i = 0; i < length; ++i) {
90
- ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
91
- }
92
- } else {
93
- // simple split
94
- ubatch.pos = batch->pos + seq.offset;
95
- }
96
- if (ubatch.equal_seqs) {
97
- ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
98
- if (seq.seq_id) {
99
- ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
100
- }
101
- } else {
102
- // simple split
103
- if (batch->n_seq_id) {
104
- ubatch.n_seq_id = batch->n_seq_id + seq.offset;
105
- } else {
106
- for (size_t i = 0; i < length; ++i) {
107
- ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
108
- }
109
- }
110
- if (batch->seq_id) {
111
- ubatch.seq_id = batch->seq_id + seq.offset;
112
- }
113
- }
114
- if (batch->logits) {
115
- if (ubatch.equal_seqs) {
116
- for (size_t i = 0; i < length; ++i) {
117
- size_t id = ids[seq.offset + i];
118
- int8_t is_output = batch->logits[id];
119
- ubatch.output[ubatch.n_tokens + i] = is_output;
120
- if (is_output) { out_ids.push_back(id); }
121
- }
122
- } else {
123
- // simple split
124
- ubatch.output = batch->logits + seq.offset;
125
- for (size_t i = 0; i < length; ++i) {
126
- if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
127
- }
128
- }
129
- } else {
130
- // only get last output
131
- for (size_t i = 0; i < length; ++i) {
132
- size_t id = ids[seq.offset + i];
133
- int8_t is_last = id == ids.size() - 1;
134
- ubatch.output[ubatch.n_tokens + i] = is_last;
135
- if (is_last) { out_ids.push_back(id); }
136
- }
137
- }
138
- if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
139
- ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
140
- }
141
- ubatch.n_tokens += length;
142
- ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
143
- seq.offset += length;
144
- seq.length -= length;
145
- n_tokens -= length;
146
- GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
147
- }
148
-
149
- llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
150
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
151
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
152
- ubatch.equal_seqs = false;
153
- if (!seq.empty()) {
154
- llama_sbatch_seq & s = seq[0];
155
- size_t length = s.length < n_ubatch ? s.length : n_ubatch;
156
- GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
157
- add_seq_to_ubatch(ubatch, s, length);
158
- }
159
- return ubatch;
160
- }
161
-
162
- llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
163
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
164
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
165
- if (!seq.empty()) {
166
- size_t length = 0;
167
- size_t n_tokens_in_ubatch = 0;
168
- GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
169
- // smallest first, because it's easier to split this way;
170
- // starting from the end to pop in constant time.
171
- for (size_t i = seq.size(); i-- > 0;) {
172
- llama_sbatch_seq & s = seq[i];
173
- GGML_ASSERT(s.length > 0);
174
- if (length == 0) {
175
- length = s.length < n_ubatch ? s.length : n_ubatch;
176
- }
177
- add_seq_to_ubatch(ubatch, s, length);
178
- n_tokens_in_ubatch += length;
179
- // shared prompts can't be mixed with any of their sequences,
180
- // so it's safer to compute them in their own ubatch
181
- if (s.n_seq_id > 1) { break; }
182
- // stop when there isn't enough space for another sequence
183
- if (length + n_tokens_in_ubatch > n_ubatch) { break; }
184
- }
185
- }
186
- return ubatch;
187
- }
188
-
189
- llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
190
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
191
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
192
- if (!seq.empty()) {
193
- llama_sbatch_seq & s = seq[seq.size() - 1];
194
- size_t length = s.length < n_ubatch ? s.length : n_ubatch;
195
- GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
196
- add_seq_to_ubatch(ubatch, s, length);
197
- }
198
- return ubatch;
199
- }
200
-
201
- llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
202
- GGML_ASSERT(batch.n_tokens >= 0);
203
- this->batch = &batch;
204
- this->n_embd = n_embd;
205
-
206
- n_tokens = batch.n_tokens;
207
- ids.resize(n_tokens);
208
- out_ids.clear();
209
- // TODO: reserve out_ids and seq
210
-
211
- for (size_t i = 0; i < n_tokens; ++i) {
212
- ids[i] = i;
213
- }
214
-
215
- if (simple_split) {
216
- seq.resize(1);
217
- llama_sbatch_seq & s = seq[0];
218
- s.n_seq_id = 0;
219
- s.seq_id = nullptr;
220
- s.offset = 0;
221
- s.length = n_tokens;
222
- return;
223
- }
224
-
225
- std::sort(ids.begin(), ids.end(),
226
- [&batch](size_t a, size_t b) {
227
- int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
228
- int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
229
- // sort by seq_id, then by pos
230
- if (n_seq_a == n_seq_b) {
231
- if (batch.seq_id) {
232
- for (int32_t i = 0; i < n_seq_a; ++i) {
233
- llama_seq_id seq_id_a = batch.seq_id[a][i];
234
- llama_seq_id seq_id_b = batch.seq_id[b][i];
235
- // smaller seq_ids go first
236
- if (seq_id_a != seq_id_b) {
237
- return seq_id_a < seq_id_b;
238
- }
239
- }
240
- }
241
- // when all else is equal, sort by pos
242
- if (batch.pos) {
243
- return batch.pos[a] < batch.pos[b];
244
- }
245
- // no pos, sort by id
246
- return a < b;
247
- }
248
- // shared prompts go first
249
- return n_seq_a > n_seq_b;
250
- }
251
- );
252
-
253
- // init seq
254
- llama_sbatch_seq * last_seq = nullptr;
255
-
256
- for (size_t i = 0; i < n_tokens; ++i) {
257
- const size_t bi = ids[i];
258
- const int32_t n_seqs = batch.n_seq_id[bi];
259
- llama_seq_id * seq_ids = batch.seq_id[bi];
260
- if (last_seq != nullptr) {
261
- bool same = n_seqs == last_seq->n_seq_id;
262
- for (int32_t j = 0; same && j < n_seqs; ++j) {
263
- if (seq_ids[j] != last_seq->seq_id[j]) {
264
- same = false;
265
- }
266
- }
267
- if (same) {
268
- last_seq->length += 1;
269
- continue;
270
- }
271
- }
272
- llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
273
- seq.push_back(new_seq);
274
- last_seq = &seq.back();
275
- }
276
-
277
- // keep shared prompts first at the end, then sort by length descending.
278
- std::sort(seq.begin(), seq.end(),
279
- [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
280
- if (a.n_seq_id == b.n_seq_id) {
281
- return a.length > b.length;
282
- }
283
- return a.n_seq_id < b.n_seq_id;
284
- }
285
- );
286
- }
287
-
288
- llama_batch_allocr::llama_batch_allocr() {
12
+ llama_batch_allocr::llama_batch_allocr(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {
289
13
  const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
290
14
  debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
291
15
 
@@ -294,17 +18,22 @@ llama_batch_allocr::llama_batch_allocr() {
294
18
  for (auto & cur : seq_cpl) {
295
19
  cur.resize(LLAMA_MAX_SEQ);
296
20
  }
21
+
22
+ seq_idx.resize(LLAMA_MAX_SEQ, -1);
297
23
  }
298
24
 
299
25
  bool llama_batch_allocr::init(
300
26
  const llama_batch & batch_inp,
301
27
  const llama_vocab & vocab,
302
28
  const llama_memory_i * memory,
303
- bool embd_all) {
29
+ uint32_t n_embd,
30
+ bool output_all) {
304
31
  clear();
305
32
 
306
33
  batch = batch_inp;
307
34
 
35
+ this->vocab = &vocab;
36
+
308
37
  GGML_ASSERT(batch.n_tokens > 0);
309
38
 
310
39
  //
@@ -359,6 +88,7 @@ bool llama_batch_allocr::init(
359
88
  llama_pos p0[LLAMA_MAX_SEQ];
360
89
  for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
361
90
  if (!memory) {
91
+ // if no memory -> start from 0
362
92
  p0[s] = 0;
363
93
  } else {
364
94
  p0[s] = memory->seq_pos_max(s) + 1;
@@ -370,8 +100,11 @@ bool llama_batch_allocr::init(
370
100
 
371
101
  pos[i] = p0[seq_id];
372
102
 
103
+ // update the starting position for all sequences that are assigned to the this token
373
104
  for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
374
- p0[batch.seq_id[i][s]] = pos[i] + 1;
105
+ const llama_seq_id seq_id = batch.seq_id[i][s];
106
+
107
+ p0[seq_id] = pos[i] + 1;
375
108
  }
376
109
  }
377
110
 
@@ -379,7 +112,7 @@ bool llama_batch_allocr::init(
379
112
  }
380
113
 
381
114
  if (!batch.logits) {
382
- if (embd_all) {
115
+ if (output_all) {
383
116
  // return the output for all tokens
384
117
  output.resize(batch.n_tokens, true);
385
118
  } else {
@@ -389,7 +122,7 @@ bool llama_batch_allocr::init(
389
122
  }
390
123
 
391
124
  batch.logits = output.data();
392
- } else if (embd_all) {
125
+ } else if (output_all) {
393
126
  bool warn = false;
394
127
 
395
128
  for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -410,6 +143,9 @@ bool llama_batch_allocr::init(
410
143
  // compute stats
411
144
  //
412
145
 
146
+ this->n_embd = n_embd;
147
+
148
+ // count the outputs in this batch
413
149
  for (int32_t i = 0; i < batch.n_tokens; ++i) {
414
150
  n_outputs += batch.logits[i] != 0;
415
151
  }
@@ -417,85 +153,86 @@ bool llama_batch_allocr::init(
417
153
  // determine coupled sequences
418
154
  // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
419
155
  for (int32_t i = 0; i < batch.n_tokens; ++i) {
156
+ const llama_seq_id s0 = batch.seq_id[i][0];
157
+
420
158
  for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
421
- seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]);
159
+ const llama_seq_id s1 = batch.seq_id[i][s];
422
160
 
423
- if (s > 0) {
424
- const llama_seq_id s0 = batch.seq_id[i][0];
425
- const llama_seq_id s1 = batch.seq_id[i][s];
161
+ seq_pos[s1].insert(batch.pos[i]);
426
162
 
163
+ if (s > 0) {
427
164
  // mark that sequence s1 is coupled to s0
428
165
  seq_cpl[s1][s0] = true;
429
166
 
430
- // note: the other way around is not necessary for now
167
+ // note: tracking the other way around is not necessary for now
431
168
  //seq_cpl[s0][s1] = true;
432
169
  }
433
170
  }
434
171
  }
435
172
 
436
- if (debug > 0) {
437
- LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
438
- LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
439
- LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) batch.token);
440
- LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) batch.embd);
441
- LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) batch.pos);
442
- LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) batch.n_seq_id);
443
- LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) batch.seq_id);
444
- LLAMA_LOG_DEBUG("%s: logits = %p\n", __func__, (void *) batch.logits);
445
- LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
173
+ // precompute the sequence sets for each token and determine the unique sequence ids that participate in the batch
174
+ {
175
+ seq_set_t seq_set_unq;
446
176
 
447
- if (debug > 1) {
448
- int seq_id_max = 0;
449
- for (int32_t i = 0; i < batch.n_tokens; ++i) {
450
- for (int s = 0; s < batch.n_seq_id[i]; ++s) {
451
- for (int s = 0; s < batch.n_seq_id[i]; ++s) {
452
- seq_id_max = std::max(seq_id_max, batch.seq_id[i][s]);
453
- }
454
- }
177
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
178
+ seq_set_t cur;
179
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
180
+ const llama_seq_id seq_id = batch.seq_id[i][s];
181
+
182
+ cur .set(seq_id);
183
+ seq_set_unq.set(seq_id);
455
184
  }
456
- ++seq_id_max;
457
185
 
458
- LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
459
- for (int32_t i = 0; i < batch.n_tokens; ++i) {
460
- std::vector<int8_t> seq_id(seq_id_max);
186
+ seq_set.push_back(cur);
187
+ seq_set_map[cur].push_back(i);
188
+ }
461
189
 
462
- for (int s = 0; s < batch.n_seq_id[i]; ++s) {
463
- seq_id[batch.seq_id[i][s]] = 1;
464
- }
190
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
191
+ if (seq_set_unq.test(s)) {
192
+ seq_idx[s] = seq_id_unq.size();
193
+ seq_id_unq.push_back(s);
194
+ }
195
+ }
196
+ }
465
197
 
466
- std::stringstream ss;
467
- for (int s = 0; s < seq_id_max; ++s) {
468
- if (seq_id[s]) {
469
- ss << s%10;
470
- } else {
471
- ss << ".";
472
- }
473
- }
198
+ if (debug > 0) {
199
+ LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
474
200
 
475
- LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
476
- __func__, i, batch.token[i], vocab.token_to_piece(batch.token[i]).c_str(),
477
- batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
201
+ llama_ubatch ubatch {
202
+ /*.equal_seqs =*/ false,
203
+ /*.n_tokens =*/ (uint32_t) batch.n_tokens,
204
+ /*.n_seq_tokens =*/ (uint32_t) 1,
205
+ /*.n_seqs =*/ (uint32_t) batch.n_tokens,
206
+ /*.n_seqs_unq =*/ (uint32_t) this->seq_id_unq.size(),
207
+ /*.token =*/ batch.token,
208
+ /*.embd =*/ batch.embd,
209
+ /*.pos =*/ batch.pos,
210
+ /*.n_seq_id =*/ batch.n_seq_id,
211
+ /*.seq_id =*/ batch.seq_id,
212
+ /*.seq_id_unq =*/ this->seq_id_unq.data(),
213
+ /*.seq_idx =*/ this->seq_idx.data(),
214
+ /*.output =*/ batch.logits,
215
+ };
216
+
217
+ ubatch_print(ubatch, debug);
218
+
219
+ LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
220
+ for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
221
+ if (seq_pos[s0].empty()) {
222
+ continue;
478
223
  }
479
- LLAMA_LOG_DEBUG("%s: ]\n", __func__);
480
-
481
- LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
482
- for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
483
- if (seq_pos[s0].empty()) {
484
- continue;
485
- }
486
224
 
487
- std::stringstream ss;
488
- for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
489
- if (seq_cpl[s0][s1]) {
490
- ss << s1 << " ";
491
- }
225
+ std::stringstream ss;
226
+ for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
227
+ if (seq_cpl[s0][s1]) {
228
+ ss << s1 << " ";
492
229
  }
493
-
494
- LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
495
- __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
496
230
  }
497
- LLAMA_LOG_DEBUG("%s: ]\n", __func__);
231
+
232
+ LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
233
+ __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
498
234
  }
235
+ LLAMA_LOG_DEBUG("%s: ]\n", __func__);
499
236
  }
500
237
 
501
238
  //
@@ -507,9 +244,35 @@ bool llama_batch_allocr::init(
507
244
  continue;
508
245
  }
509
246
 
510
- if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
511
- LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
512
- return false;
247
+ const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
248
+
249
+ if (p0 >= 0) {
250
+ bool ok = true;
251
+
252
+ if (batch.token) {
253
+ if (seq_pos_min(s) != p0 + 1) {
254
+ ok = false;
255
+ }
256
+ } else {
257
+ assert(batch.embd);
258
+
259
+ // for embeddings (typically used as vision input), we allow them to have repeating positions
260
+ // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
261
+ if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
262
+ ok = false;
263
+ }
264
+ }
265
+
266
+ if (!ok) {
267
+ LLAMA_LOG_ERROR(
268
+ "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
269
+ " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
270
+ " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
271
+ " it is required that the sequence positions remain consecutive: Y = X + 1\n",
272
+ __func__, s, s, p0, s, seq_pos_min(s));
273
+
274
+ return false;
275
+ }
513
276
  }
514
277
 
515
278
  if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
@@ -532,17 +295,120 @@ bool llama_batch_allocr::init(
532
295
  }
533
296
  }
534
297
 
298
+ // disallow partial sequence sub-sets:
299
+ //
300
+ // invalid: x
301
+ // i: 0 1 2 ...
302
+ // ---------------------------------------
303
+ // seq_id[i][0]: 0 0 1
304
+ // seq_id[i][1]: 1 1 2
305
+ // seq_id[i][2]: 2
306
+ //
307
+ // disallow decreasing sequence positions:
308
+ //
309
+ // invalid: x
310
+ // i: 0 1 2 3 4 5 6 ...
311
+ // ---------------------------------------
312
+ // pos[i]: 4 5 0 1 6 2 3
313
+ // seq_id[i][0]: 0 0 1 1 0 1 0
314
+ //
315
+ {
316
+ seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
317
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
318
+ cur_seq_set[s].set();
319
+ }
320
+
321
+ llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
322
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
323
+ cur_seq_pos[s] = -1;
324
+ }
325
+
326
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
327
+ const llama_pos pos = batch.pos[i];
328
+
329
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
330
+ const llama_seq_id seq_id = batch.seq_id[i][s];
331
+
332
+ cur_seq_set[seq_id] &= seq_set[i];
333
+
334
+ if (cur_seq_set[seq_id].none()) {
335
+ LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets (not allowed)\n", __func__, seq_id);
336
+ return false;
337
+ }
338
+
339
+ if (pos < cur_seq_pos[seq_id]) {
340
+ LLAMA_LOG_ERROR("%s: sequence %d positions are decreasing (not allowed)\n", __func__, seq_id);
341
+ return false;
342
+ }
343
+ }
344
+ }
345
+ }
346
+
347
+ split_reset();
348
+
535
349
  return true;
536
350
  }
537
351
 
352
+ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs) {
353
+ const uint32_t n_tokens = n_seq_tokens*n_seqs;
354
+
355
+ clear();
356
+ split_reset();
357
+
358
+ ubatches.emplace_back();
359
+
360
+ auto & ubatch = ubatches.back();
361
+
362
+ ubatch.token .resize(n_tokens);
363
+ ubatch.embd .clear();
364
+ ubatch.pos .resize(n_tokens);
365
+ ubatch.n_seq_id .resize(n_tokens);
366
+ ubatch.seq_id .resize(n_tokens);
367
+ ubatch.seq_id_unq.resize(0);
368
+ ubatch.seq_idx .resize(LLAMA_MAX_SEQ, -1);
369
+ ubatch.output .resize(n_tokens);
370
+
371
+ for (uint32_t s = 0; s < n_seqs; ++s) {
372
+ ubatch.seq_idx[s] = s;
373
+ ubatch.seq_id_unq.push_back(s);
374
+ }
375
+
376
+ llama_ubatch res {
377
+ /*.equal_seqs =*/ true,
378
+ /*.n_tokens =*/ n_tokens,
379
+ /*.n_seq_tokens =*/ n_seq_tokens,
380
+ /*.n_seqs =*/ n_seqs,
381
+ /*.n_seqs_unq =*/ n_seqs,
382
+
383
+ /*.token =*/ ubatch.token.data(),
384
+ /*.embd =*/ nullptr,
385
+ /*.pos =*/ ubatch.pos.data(),
386
+ /*.n_seq_id =*/ ubatch.n_seq_id.data(),
387
+ /*.seq_id =*/ ubatch.seq_id.data(),
388
+ /*.seq_id_unq =*/ ubatch.seq_id_unq.data(),
389
+ /*.seq_idx =*/ ubatch.seq_idx.data(),
390
+ /*.output =*/ ubatch.output.data(),
391
+ };
392
+
393
+ return res;
394
+ }
395
+
538
396
  const llama_batch & llama_batch_allocr::get_batch() const {
539
397
  return batch;
540
398
  }
541
399
 
400
+ uint32_t llama_batch_allocr::get_n_tokens() const {
401
+ return batch.n_tokens;
402
+ }
403
+
542
404
  uint32_t llama_batch_allocr::get_n_outputs() const {
543
405
  return n_outputs;
544
406
  }
545
407
 
408
+ std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
409
+ return out_ids;
410
+ }
411
+
546
412
  llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
547
413
  return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
548
414
  }
@@ -551,14 +417,188 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
551
417
  return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
552
418
  }
553
419
 
420
+ void llama_batch_allocr::split_reset() {
421
+ out_ids.clear();
422
+
423
+ used.clear();
424
+ used.resize(get_n_tokens(), false);
425
+
426
+ ubatches.clear();
427
+ }
428
+
429
+ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
430
+ // find the first unused token
431
+ uint32_t cur_idx = 0;
432
+ while (cur_idx < used.size() && used[cur_idx]) {
433
+ ++cur_idx;
434
+ }
435
+
436
+ // we are done
437
+ if (cur_idx >= used.size()) {
438
+ return {};
439
+ }
440
+
441
+ std::vector<int32_t> idxs;
442
+
443
+ while (true) {
444
+ idxs.push_back(cur_idx);
445
+
446
+ used[cur_idx] = true;
447
+
448
+ ++cur_idx;
449
+
450
+ if (cur_idx >= used.size()) {
451
+ break;
452
+ }
453
+
454
+ if (idxs.size() >= n_ubatch) {
455
+ break;
456
+ }
457
+ }
458
+
459
+ return ubatch_add(idxs, idxs.size(), false);
460
+ }
461
+
462
+ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
463
+ std::vector<seq_set_t> cur_seq_set;
464
+
465
+ // determine the non-overlapping sequence sets participating in this ubatch
466
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
467
+ if (used[i]) {
468
+ continue;
469
+ }
470
+
471
+ bool add = true;
472
+
473
+ for (uint32_t s = 0; s < cur_seq_set.size(); ++s) {
474
+ // no overlap with existing sequence sets:
475
+ if (!(cur_seq_set[s] & seq_set[i]).none()) {
476
+ add = false;
477
+ break;
478
+ }
479
+ }
480
+
481
+ if (add) {
482
+ cur_seq_set.push_back(seq_set[i]);
483
+
484
+ if (cur_seq_set.size() > n_ubatch) {
485
+ break;
486
+ }
487
+ }
488
+ }
489
+
490
+ const uint32_t n_seqs = cur_seq_set.size();
491
+
492
+ // we are done
493
+ if (n_seqs == 0) {
494
+ return {};
495
+ }
496
+
497
+ // the current batch index of each sequence set
498
+ std::vector<int32_t> cur_idx(n_seqs, 0);
499
+
500
+ for (uint32_t s = 0; s < n_seqs; ++s) {
501
+ while (used[seq_set_map[cur_seq_set[s]][cur_idx[s]]]) {
502
+ ++cur_idx[s];
503
+ }
504
+ }
505
+
506
+ // the list of batch indices for each sequence set
507
+ // at the end we will concat these to get the final ubatch
508
+ std::vector<idx_vec_t> idxs_per_seq(n_seqs);
509
+
510
+ while (true) {
511
+ // we can only add new n_seq_tokens tokens if all the sequence sets have at least one more unused token and
512
+ // if we haven't reached n_ubatch
513
+ bool can_expand = true;
514
+
515
+ for (uint32_t s = 0; s < n_seqs; ++s) {
516
+ if (cur_idx[s] >= (int32_t) seq_set_map[cur_seq_set[s]].size()) {
517
+ can_expand = false;
518
+ break;
519
+ }
520
+ }
521
+
522
+ if (!can_expand) {
523
+ break;
524
+ }
525
+
526
+ for (uint32_t s = 0; s < n_seqs; ++s) {
527
+ const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];
528
+
529
+ idxs_per_seq[s].push_back(idx);
530
+
531
+ used[idx] = true;
532
+
533
+ ++cur_idx[s];
534
+ }
535
+
536
+ if ((idxs_per_seq[0].size() + 1)*n_seqs > n_ubatch) {
537
+ break;
538
+ }
539
+ }
540
+
541
+ // concat the per-sequence-set lists
542
+ std::vector<int32_t> idxs;
543
+
544
+ for (uint32_t s = 0; s < n_seqs; ++s) {
545
+ idxs.insert(idxs.end(), idxs_per_seq[s].begin(), idxs_per_seq[s].end());
546
+ }
547
+
548
+ return ubatch_add(idxs, n_seqs, true);
549
+ }
550
+
551
+ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
552
+ // find the first unused token
553
+ uint32_t cur_idx = 0;
554
+ while (cur_idx < used.size() && used[cur_idx]) {
555
+ ++cur_idx;
556
+ }
557
+
558
+ // we are done
559
+ if (cur_idx >= used.size()) {
560
+ return {};
561
+ }
562
+
563
+ // this is the starting sequence set
564
+ // we allow adding tokens only if their sequence set is a subset of the current sequence set
565
+ auto cur_seq_set = seq_set[cur_idx];
566
+
567
+ std::vector<int32_t> idxs;
568
+
569
+ while (true) {
570
+ idxs.push_back(cur_idx);
571
+
572
+ used[cur_idx] = true;
573
+
574
+ if (idxs.size() >= n_ubatch) {
575
+ break;
576
+ }
577
+
578
+ do {
579
+ ++cur_idx;
580
+ } while (cur_idx < get_n_tokens() && (used[cur_idx] || ((cur_seq_set & seq_set[cur_idx]) != seq_set[cur_idx])));
581
+
582
+ if (cur_idx == get_n_tokens()) {
583
+ break;
584
+ }
585
+
586
+ cur_seq_set = seq_set[cur_idx];
587
+ }
588
+
589
+ return ubatch_add(idxs, 1, true);
590
+ }
591
+
554
592
  void llama_batch_allocr::clear() {
555
593
  n_outputs = 0;
556
594
 
557
595
  batch = {};
558
- pos.clear();
559
- n_seq_id.clear();
560
- seq_id.clear();
561
- output.clear();
596
+
597
+ pos .clear();
598
+ n_seq_id .clear();
599
+ seq_id .clear();
600
+ seq_id_unq.clear();
601
+ output .clear();
562
602
 
563
603
  for (auto & cur : seq_pos) {
564
604
  cur.clear();
@@ -567,6 +607,177 @@ void llama_batch_allocr::clear() {
567
607
  for (auto & cur : seq_cpl) {
568
608
  std::fill(cur.begin(), cur.end(), false);
569
609
  }
610
+
611
+ seq_set.clear();
612
+
613
+ seq_set_map.clear();
614
+
615
+ std::fill(seq_idx.begin(), seq_idx.end(), -1);
616
+ }
617
+
618
+ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {
619
+ const uint32_t n_tokens = idxs.size();
620
+
621
+ assert(n_tokens%n_seqs == 0);
622
+
623
+ ubatches.emplace_back();
624
+
625
+ auto & ubatch = ubatches.back();
626
+
627
+ const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
628
+
629
+ const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
630
+ const int64_t n_pos_all = (int64_t) n_tokens*n_pos_cur;
631
+
632
+ ubatch.token .resize(n_tokens);
633
+ ubatch.embd .resize(n_embd_all);
634
+ ubatch.pos .resize(n_pos_all);
635
+ ubatch.n_seq_id .resize(n_tokens);
636
+ ubatch.seq_id .resize(n_tokens);
637
+ ubatch.seq_id_unq.resize(0);
638
+ ubatch.seq_idx .resize(LLAMA_MAX_SEQ, -1);
639
+ ubatch.output .resize(n_tokens);
640
+
641
+ seq_set_t seq_set_unq;
642
+
643
+ for (size_t i = 0; i < idxs.size(); ++i) {
644
+ if (batch.token) {
645
+ ubatch.token[i] = batch.token[idxs[i]];
646
+ }
647
+
648
+ if (batch.embd) {
649
+ memcpy(ubatch.embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
650
+ }
651
+
652
+ for (int j = 0; j < n_pos_cur; ++j) {
653
+ ubatch.pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
654
+ }
655
+
656
+ ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]];
657
+ ubatch.seq_id[i] = batch.seq_id[idxs[i]];
658
+ ubatch.output[i] = batch.logits[idxs[i]];
659
+
660
+ for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
661
+ seq_set_unq.set(ubatch.seq_id[i][s]);
662
+ }
663
+
664
+ if (ubatch.output[i]) {
665
+ out_ids.push_back(idxs[i]);
666
+ }
667
+ }
668
+
669
+ for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
670
+ if (seq_set_unq.test(s)) {
671
+ ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
672
+ ubatch.seq_id_unq.push_back(s);
673
+ }
674
+ }
675
+
676
+ llama_ubatch res {
677
+ /*.equal_seqs =*/ equal_seqs,
678
+ /*.n_tokens =*/ n_tokens,
679
+ /*.n_seq_tokens =*/ n_tokens/n_seqs,
680
+ /*.n_seqs =*/ n_seqs,
681
+ /*.n_seqs_unq =*/ (uint32_t) ubatch.seq_id_unq.size(),
682
+
683
+ /*.token =*/ batch.token ? ubatch.token.data() : nullptr,
684
+ /*.embd =*/ batch.embd ? ubatch.embd.data() : nullptr,
685
+ /*.pos =*/ ubatch.pos.data(),
686
+ /*.n_seq_id =*/ ubatch.n_seq_id.data(),
687
+ /*.seq_id =*/ ubatch.seq_id.data(),
688
+ /*.seq_id_unq =*/ ubatch.seq_id_unq.data(),
689
+ /*.seq_idx =*/ ubatch.seq_idx.data(),
690
+ /*.output =*/ ubatch.output.data(),
691
+ };
692
+
693
+ if (debug > 0) {
694
+ LLAMA_LOG_DEBUG("%s: added ubatch %d to split:\n", __func__, (int) ubatches.size() - 1);
695
+
696
+ ubatch_print(res, debug);
697
+ }
698
+
699
+ return res;
700
+ }
701
+
702
+ void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
703
+ if (debug > 0) {
704
+ LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs);
705
+ LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, ubatch.n_tokens);
706
+ LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
707
+ LLAMA_LOG_DEBUG("%s: n_seqs = %d\n", __func__, ubatch.n_seqs);
708
+ LLAMA_LOG_DEBUG("%s: n_seqs_unq = %d\n", __func__, ubatch.n_seqs_unq);
709
+
710
+ std::stringstream ss_seq_id_unq;
711
+ std::stringstream ss_seq_idx;
712
+
713
+ ss_seq_id_unq << "[ ";
714
+ ss_seq_idx << "[";
715
+
716
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
717
+ ss_seq_id_unq << ubatch.seq_id_unq[s] << " ";
718
+ }
719
+
720
+ for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
721
+ if (ubatch.seq_idx[s] >= 0) {
722
+ ss_seq_idx << ubatch.seq_idx[s]%10;
723
+ } else {
724
+ ss_seq_idx << ".";
725
+ }
726
+ }
727
+
728
+ ss_seq_id_unq << "]";
729
+ ss_seq_idx << "]";
730
+
731
+ LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) ubatch.token);
732
+ LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) ubatch.embd);
733
+ LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) ubatch.pos);
734
+ LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) ubatch.n_seq_id);
735
+ LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) ubatch.seq_id);
736
+ LLAMA_LOG_DEBUG("%s: seq_id_unq = %s\n", __func__, ss_seq_id_unq.str().c_str());
737
+ LLAMA_LOG_DEBUG("%s: seq_idx = %s\n", __func__, ss_seq_idx.str().c_str());
738
+ LLAMA_LOG_DEBUG("%s: output = %p\n", __func__, (void *) ubatch.output);
739
+ LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
740
+
741
+ if (debug > 1) {
742
+ int seq_id_max = 0;
743
+ for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
744
+ for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
745
+ for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
746
+ seq_id_max = std::max(seq_id_max, ubatch.seq_id[i][s]);
747
+ }
748
+ }
749
+ }
750
+ ++seq_id_max;
751
+
752
+ LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
753
+ for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
754
+ std::vector<int8_t> seq_id(seq_id_max);
755
+
756
+ for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
757
+ seq_id[ubatch.seq_id[i][s]] = 1;
758
+ }
759
+
760
+ std::stringstream ss;
761
+ for (int s = 0; s < seq_id_max; ++s) {
762
+ if (seq_id[s]) {
763
+ ss << s%10;
764
+ } else {
765
+ ss << ".";
766
+ }
767
+ }
768
+
769
+ if (ubatch.token) {
770
+ LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
771
+ __func__, i, ubatch.token[i], vocab->token_to_piece(ubatch.token[i]).c_str(),
772
+ ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
773
+ } else {
774
+ LLAMA_LOG_DEBUG("%s: %4d: [embd], pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
775
+ __func__, i, ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
776
+ }
777
+ }
778
+ LLAMA_LOG_DEBUG("%s: ]\n", __func__);
779
+ }
780
+ }
570
781
  }
571
782
 
572
783
  //
@@ -577,25 +788,25 @@ struct llama_batch llama_batch_get_one(
577
788
  llama_token * tokens,
578
789
  int32_t n_tokens) {
579
790
  return {
580
- /*n_tokens =*/ n_tokens,
581
- /*tokens =*/ tokens,
582
- /*embd =*/ nullptr,
583
- /*pos =*/ nullptr,
584
- /*n_seq_id =*/ nullptr,
585
- /*seq_id =*/ nullptr,
586
- /*logits =*/ nullptr,
791
+ /*n_tokens =*/ n_tokens,
792
+ /*tokens =*/ tokens,
793
+ /*embd =*/ nullptr,
794
+ /*pos =*/ nullptr,
795
+ /*n_seq_id =*/ nullptr,
796
+ /*seq_id =*/ nullptr,
797
+ /*logits =*/ nullptr,
587
798
  };
588
799
  }
589
800
 
590
801
  struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
591
802
  llama_batch batch = {
592
- /*n_tokens =*/ 0,
593
- /*tokens =*/ nullptr,
594
- /*embd =*/ nullptr,
595
- /*pos =*/ nullptr,
596
- /*n_seq_id =*/ nullptr,
597
- /*seq_id =*/ nullptr,
598
- /*logits =*/ nullptr,
803
+ /*n_tokens =*/ 0,
804
+ /*tokens =*/ nullptr,
805
+ /*embd =*/ nullptr,
806
+ /*pos =*/ nullptr,
807
+ /*n_seq_id =*/ nullptr,
808
+ /*seq_id =*/ nullptr,
809
+ /*logits =*/ nullptr,
599
810
  };
600
811
 
601
812
  if (embd) {