cui-llama.rn 1.4.4 → 1.4.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197) hide show
  1. package/android/src/main/CMakeLists.txt +2 -2
  2. package/android/src/main/jni.cpp +12 -10
  3. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  11. package/cpp/chat-template.hpp +529 -529
  12. package/cpp/chat.cpp +959 -265
  13. package/cpp/chat.h +135 -0
  14. package/cpp/common.cpp +2064 -1996
  15. package/cpp/common.h +700 -744
  16. package/cpp/ggml-alloc.c +1039 -1030
  17. package/cpp/ggml-alloc.h +1 -1
  18. package/cpp/ggml-backend-impl.h +255 -255
  19. package/cpp/ggml-backend-reg.cpp +586 -582
  20. package/cpp/ggml-backend.cpp +2004 -2002
  21. package/cpp/ggml-backend.h +354 -354
  22. package/cpp/ggml-common.h +1851 -1851
  23. package/cpp/ggml-cpp.h +39 -39
  24. package/cpp/ggml-cpu-aarch64.cpp +4248 -4247
  25. package/cpp/ggml-cpu-aarch64.h +8 -8
  26. package/cpp/ggml-cpu-impl.h +531 -380
  27. package/cpp/ggml-cpu-quants.c +12527 -11517
  28. package/cpp/ggml-cpu-traits.cpp +36 -36
  29. package/cpp/ggml-cpu-traits.h +38 -38
  30. package/cpp/ggml-cpu.c +15766 -14485
  31. package/cpp/ggml-cpu.cpp +655 -633
  32. package/cpp/ggml-cpu.h +138 -135
  33. package/cpp/ggml-impl.h +567 -567
  34. package/cpp/ggml-metal-impl.h +235 -0
  35. package/cpp/ggml-metal.h +66 -66
  36. package/cpp/ggml-metal.m +5146 -5002
  37. package/cpp/ggml-opt.cpp +854 -854
  38. package/cpp/ggml-opt.h +216 -216
  39. package/cpp/ggml-quants.c +5238 -5238
  40. package/cpp/ggml-threading.h +14 -14
  41. package/cpp/ggml.c +6529 -6524
  42. package/cpp/ggml.h +2198 -2194
  43. package/cpp/gguf.cpp +1329 -1329
  44. package/cpp/gguf.h +202 -202
  45. package/cpp/json-schema-to-grammar.cpp +1024 -1025
  46. package/cpp/json-schema-to-grammar.h +21 -22
  47. package/cpp/json.hpp +24766 -24766
  48. package/cpp/llama-adapter.cpp +347 -347
  49. package/cpp/llama-adapter.h +74 -74
  50. package/cpp/llama-arch.cpp +1513 -1492
  51. package/cpp/llama-arch.h +403 -402
  52. package/cpp/llama-batch.cpp +368 -368
  53. package/cpp/llama-batch.h +88 -88
  54. package/cpp/llama-chat.cpp +588 -587
  55. package/cpp/llama-chat.h +53 -53
  56. package/cpp/llama-context.cpp +1775 -1775
  57. package/cpp/llama-context.h +128 -128
  58. package/cpp/llama-cparams.cpp +1 -1
  59. package/cpp/llama-cparams.h +37 -37
  60. package/cpp/llama-cpp.h +30 -30
  61. package/cpp/llama-grammar.cpp +1219 -1219
  62. package/cpp/llama-grammar.h +173 -164
  63. package/cpp/llama-hparams.cpp +71 -71
  64. package/cpp/llama-hparams.h +139 -139
  65. package/cpp/llama-impl.cpp +167 -167
  66. package/cpp/llama-impl.h +61 -61
  67. package/cpp/llama-kv-cache.cpp +718 -718
  68. package/cpp/llama-kv-cache.h +219 -218
  69. package/cpp/llama-mmap.cpp +600 -590
  70. package/cpp/llama-mmap.h +68 -68
  71. package/cpp/llama-model-loader.cpp +1124 -1124
  72. package/cpp/llama-model-loader.h +167 -167
  73. package/cpp/llama-model.cpp +4087 -4023
  74. package/cpp/llama-model.h +370 -370
  75. package/cpp/llama-sampling.cpp +2558 -2525
  76. package/cpp/llama-sampling.h +32 -32
  77. package/cpp/llama-vocab.cpp +3264 -3252
  78. package/cpp/llama-vocab.h +125 -125
  79. package/cpp/llama.cpp +10284 -10137
  80. package/cpp/llama.h +1354 -1340
  81. package/cpp/log.cpp +393 -423
  82. package/cpp/log.h +132 -132
  83. package/cpp/minja/chat-template.hpp +529 -0
  84. package/cpp/minja/minja.hpp +2915 -0
  85. package/cpp/minja.hpp +2915 -2883
  86. package/cpp/rn-llama.cpp +20 -37
  87. package/cpp/rn-llama.h +12 -2
  88. package/cpp/sampling.cpp +570 -532
  89. package/cpp/sgemm.cpp +2598 -2598
  90. package/cpp/sgemm.h +14 -14
  91. package/cpp/speculative.cpp +278 -277
  92. package/cpp/speculative.h +28 -28
  93. package/package.json +1 -1
  94. package/android/src/main/build-arm64/CMakeCache.txt +0 -429
  95. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCCompiler.cmake +0 -81
  96. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCXXCompiler.cmake +0 -101
  97. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_C.bin +0 -0
  98. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_CXX.bin +0 -0
  99. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeSystem.cmake +0 -15
  100. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.c +0 -904
  101. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.o +0 -0
  102. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.cpp +0 -919
  103. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.o +0 -0
  104. package/android/src/main/build-arm64/CMakeFiles/CMakeConfigureLog.yaml +0 -431
  105. package/android/src/main/build-arm64/CMakeFiles/CMakeDirectoryInformation.cmake +0 -16
  106. package/android/src/main/build-arm64/CMakeFiles/Makefile.cmake +0 -165
  107. package/android/src/main/build-arm64/CMakeFiles/Makefile2 +0 -297
  108. package/android/src/main/build-arm64/CMakeFiles/Progress/1 +0 -1
  109. package/android/src/main/build-arm64/CMakeFiles/Progress/2 +0 -1
  110. package/android/src/main/build-arm64/CMakeFiles/Progress/3 +0 -1
  111. package/android/src/main/build-arm64/CMakeFiles/Progress/4 +0 -1
  112. package/android/src/main/build-arm64/CMakeFiles/Progress/5 +0 -1
  113. package/android/src/main/build-arm64/CMakeFiles/Progress/6 +0 -1
  114. package/android/src/main/build-arm64/CMakeFiles/Progress/count.txt +0 -1
  115. package/android/src/main/build-arm64/CMakeFiles/TargetDirectories.txt +0 -8
  116. package/android/src/main/build-arm64/CMakeFiles/cmake.check_cache +0 -1
  117. package/android/src/main/build-arm64/CMakeFiles/progress.marks +0 -1
  118. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o +0 -0
  119. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o.d +0 -58
  120. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o +0 -0
  121. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o.d +0 -756
  122. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o +0 -0
  123. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o.d +0 -709
  124. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o +0 -0
  125. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o.d +0 -714
  126. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o +0 -0
  127. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o.d +0 -62
  128. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o +0 -0
  129. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o.d +0 -708
  130. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o +0 -0
  131. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o.d +0 -113
  132. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o +0 -0
  133. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o.d +0 -713
  134. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o +0 -0
  135. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o.d +0 -763
  136. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o +0 -0
  137. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o.d +0 -61
  138. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o +0 -0
  139. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o.d +0 -707
  140. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o +0 -0
  141. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o.d +0 -104
  142. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o +0 -0
  143. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o.d +0 -714
  144. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o +0 -0
  145. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o.d +0 -723
  146. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/DependInfo.cmake +0 -62
  147. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/build.make +0 -722
  148. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/cmake_clean.cmake +0 -89
  149. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.make +0 -2
  150. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.ts +0 -2
  151. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/depend.make +0 -2
  152. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/flags.make +0 -17
  153. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/progress.make +0 -41
  154. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/DependInfo.cmake +0 -62
  155. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/build.make +0 -722
  156. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/cmake_clean.cmake +0 -89
  157. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.make +0 -2
  158. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.ts +0 -2
  159. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/depend.make +0 -2
  160. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/flags.make +0 -17
  161. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/progress.make +0 -41
  162. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/DependInfo.cmake +0 -62
  163. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/build.make +0 -722
  164. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/cmake_clean.cmake +0 -89
  165. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.make +0 -2
  166. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.ts +0 -2
  167. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/depend.make +0 -2
  168. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/flags.make +0 -17
  169. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/progress.make +0 -41
  170. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/DependInfo.cmake +0 -62
  171. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/build.make +0 -722
  172. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/cmake_clean.cmake +0 -89
  173. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.make +0 -2
  174. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.ts +0 -2
  175. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/depend.make +0 -2
  176. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/flags.make +0 -17
  177. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/progress.make +0 -41
  178. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/DependInfo.cmake +0 -62
  179. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/build.make +0 -722
  180. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/cmake_clean.cmake +0 -89
  181. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.make +0 -2
  182. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.ts +0 -2
  183. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/depend.make +0 -2
  184. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/flags.make +0 -17
  185. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/progress.make +0 -41
  186. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/DependInfo.cmake +0 -62
  187. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/build.make +0 -722
  188. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/cmake_clean.cmake +0 -89
  189. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.make +0 -2
  190. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.ts +0 -2
  191. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/depend.make +0 -2
  192. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/flags.make +0 -17
  193. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/progress.make +0 -41
  194. package/android/src/main/build-arm64/Makefile +0 -1862
  195. package/android/src/main/build-arm64/cmake_install.cmake +0 -66
  196. package/cpp/chat.hpp +0 -55
  197. package/cpp/rn-llama.hpp +0 -913
@@ -1,368 +1,368 @@
1
- #include "llama-batch.h"
2
-
3
- #include <cstring>
4
- #include <algorithm>
5
-
6
- llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
7
- // clear empty sequences
8
- // the previous ubatch is assumed to be gone,
9
- // so nothing should refer to values in these sequences anymore.
10
- for (size_t i = seq.size(); i-- > 0;) {
11
- if (seq[i].length == 0) {
12
- seq.pop_back();
13
- } else {
14
- break;
15
- }
16
- }
17
- ubatch_token.resize(!has_embd ? n_ubatch : 0);
18
- ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
19
- ubatch_pos.resize(n_ubatch);
20
- ubatch_n_seq_id.resize(n_ubatch);
21
- ubatch_seq_id.resize(n_ubatch);
22
- ubatch_output.resize(n_ubatch);
23
- llama_ubatch ubatch = {
24
- /*equal_seqs =*/ true,
25
- /*n_tokens =*/ 0,
26
- /*n_seq_tokens =*/ 0,
27
- /*n_seqs =*/ 0,
28
- /*token =*/ !has_embd ? ubatch_token.data() : nullptr,
29
- /*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
30
- /*pos =*/ ubatch_pos.data(),
31
- /*n_seq_id =*/ ubatch_n_seq_id.data(),
32
- /*seq_id =*/ ubatch_seq_id.data(),
33
- /*output =*/ ubatch_output.data(),
34
- };
35
- return ubatch;
36
- }
37
-
38
- void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
39
- LM_GGML_ASSERT(batch != nullptr);
40
- LM_GGML_ASSERT(length <= seq.length);
41
- // Can only add sequences of equal lengths to a batch,
42
- // otherwise it isn't clear to which sequence a token belongs
43
- LM_GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
44
- LM_GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
45
- // NOTE: loops are separated for cache-friendliness
46
- if (batch->token) {
47
- if (ubatch.equal_seqs) {
48
- for (size_t i = 0; i < length; ++i) {
49
- ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
50
- }
51
- } else {
52
- // simple split
53
- ubatch.token = batch->token + seq.offset;
54
- }
55
- } else {
56
- ubatch.token = nullptr;
57
- }
58
- if (batch->embd) {
59
- if (ubatch.equal_seqs) {
60
- for (size_t i = 0; i < length; ++i) {
61
- memcpy(
62
- ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
63
- batch->embd + (n_embd * ids[seq.offset + i]),
64
- n_embd * sizeof(float)
65
- );
66
- }
67
- } else {
68
- // simple split
69
- ubatch.embd = batch->embd + (n_embd * seq.offset);
70
- }
71
- } else {
72
- ubatch.embd = nullptr;
73
- }
74
- if (ubatch.equal_seqs) {
75
- for (size_t i = 0; i < length; ++i) {
76
- ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
77
- }
78
- } else {
79
- // simple split
80
- ubatch.pos = batch->pos + seq.offset;
81
- }
82
- if (ubatch.equal_seqs) {
83
- ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
84
- if (seq.seq_id) {
85
- ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
86
- }
87
- } else {
88
- // simple split
89
- if (batch->n_seq_id) {
90
- ubatch.n_seq_id = batch->n_seq_id + seq.offset;
91
- } else {
92
- for (size_t i = 0; i < length; ++i) {
93
- ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
94
- }
95
- }
96
- if (batch->seq_id) {
97
- ubatch.seq_id = batch->seq_id + seq.offset;
98
- }
99
- }
100
- if (logits_all) {
101
- for (size_t i = 0; i < length; ++i) {
102
- ubatch.output[ubatch.n_tokens + i] = 1;
103
- out_ids.push_back(ids[seq.offset + i]);
104
- }
105
- } else if (batch->logits) {
106
- if (ubatch.equal_seqs) {
107
- for (size_t i = 0; i < length; ++i) {
108
- size_t id = ids[seq.offset + i];
109
- int8_t is_output = batch->logits[id];
110
- ubatch.output[ubatch.n_tokens + i] = is_output;
111
- if (is_output) { out_ids.push_back(id); }
112
- }
113
- } else {
114
- // simple split
115
- ubatch.output = batch->logits + seq.offset;
116
- for (size_t i = 0; i < length; ++i) {
117
- if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
118
- }
119
- }
120
- } else {
121
- // only get last output
122
- for (size_t i = 0; i < length; ++i) {
123
- size_t id = ids[seq.offset + i];
124
- int8_t is_last = id == ids.size() - 1;
125
- ubatch.output[ubatch.n_tokens + i] = is_last;
126
- if (is_last) { out_ids.push_back(id); }
127
- }
128
- }
129
- if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
130
- ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
131
- }
132
- ubatch.n_tokens += length;
133
- ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
134
- seq.offset += length;
135
- seq.length -= length;
136
- n_tokens -= length;
137
- LM_GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
138
- }
139
-
140
- llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
141
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
142
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
143
- ubatch.equal_seqs = false;
144
- if (!seq.empty()) {
145
- llama_sbatch_seq & s = seq[0];
146
- size_t length = s.length < n_ubatch ? s.length : n_ubatch;
147
- LM_GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
148
- add_seq_to_ubatch(ubatch, s, length);
149
- }
150
- return ubatch;
151
- }
152
-
153
- llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
154
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
155
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
156
- if (!seq.empty()) {
157
- size_t length = 0;
158
- size_t n_tokens_in_ubatch = 0;
159
- LM_GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
160
- // smallest first, because it's easier to split this way;
161
- // starting from the end to pop in constant time.
162
- for (size_t i = seq.size(); i-- > 0;) {
163
- llama_sbatch_seq & s = seq[i];
164
- LM_GGML_ASSERT(s.length > 0);
165
- if (length == 0) {
166
- length = s.length < n_ubatch ? s.length : n_ubatch;
167
- }
168
- add_seq_to_ubatch(ubatch, s, length);
169
- n_tokens_in_ubatch += length;
170
- // shared prompts can't be mixed with any of their sequences,
171
- // so it's safer to compute them in their own ubatch
172
- if (s.n_seq_id > 1) { break; }
173
- // stop when there isn't enough space for another sequence
174
- if (length + n_tokens_in_ubatch > n_ubatch) { break; }
175
- }
176
- }
177
- return ubatch;
178
- }
179
-
180
- llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
181
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
182
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
183
- if (!seq.empty()) {
184
- llama_sbatch_seq & s = seq[seq.size() - 1];
185
- size_t length = s.length < n_ubatch ? s.length : n_ubatch;
186
- LM_GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
187
- add_seq_to_ubatch(ubatch, s, length);
188
- }
189
- return ubatch;
190
- }
191
-
192
- void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
193
- LM_GGML_ASSERT(batch.n_tokens >= 0);
194
- this->batch = &batch;
195
- this->n_embd = n_embd;
196
- this->logits_all = logits_all;
197
-
198
- n_tokens = batch.n_tokens;
199
- ids.resize(n_tokens);
200
- out_ids.clear();
201
- // TODO: reserve out_ids and seq
202
-
203
- for (size_t i = 0; i < n_tokens; ++i) {
204
- ids[i] = i;
205
- }
206
- if (simple_split) {
207
- seq.resize(1);
208
- llama_sbatch_seq & s = seq[0];
209
- s.n_seq_id = 0;
210
- s.seq_id = nullptr;
211
- s.offset = 0;
212
- s.length = n_tokens;
213
- return;
214
- }
215
- std::sort(ids.begin(), ids.end(),
216
- [&batch](size_t a, size_t b) {
217
- int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
218
- int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
219
- // sort by seq_id, then by pos
220
- if (n_seq_a == n_seq_b) {
221
- if (batch.seq_id) {
222
- for (int32_t i = 0; i < n_seq_a; ++i) {
223
- llama_seq_id seq_id_a = batch.seq_id[a][i];
224
- llama_seq_id seq_id_b = batch.seq_id[b][i];
225
- // smaller seq_ids go first
226
- if (seq_id_a != seq_id_b) {
227
- return seq_id_a < seq_id_b;
228
- }
229
- }
230
- }
231
- // when all else is equal, sort by pos
232
- if (batch.pos) {
233
- return batch.pos[a] < batch.pos[b];
234
- }
235
- // no pos, sort by id
236
- return a < b;
237
- }
238
- // shared prompts go first
239
- return n_seq_a > n_seq_b;
240
- }
241
- );
242
- // init seq
243
- llama_sbatch_seq * last_seq = nullptr;
244
-
245
- for (size_t i = 0; i < n_tokens; ++i) {
246
- const size_t bi = ids[i];
247
- const int32_t n_seqs = batch.n_seq_id[bi];
248
- llama_seq_id * seq_ids = batch.seq_id[bi];
249
- if (last_seq != nullptr) {
250
- bool same = n_seqs == last_seq->n_seq_id;
251
- for (int32_t j = 0; same && j < n_seqs; ++j) {
252
- if (seq_ids[j] != last_seq->seq_id[j]) {
253
- same = false;
254
- }
255
- }
256
- if (same) {
257
- last_seq->length += 1;
258
- continue;
259
- }
260
- }
261
- llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
262
- seq.push_back(new_seq);
263
- last_seq = &seq.back();
264
- }
265
- // keep shared prompts first at the end, then sort by length descending.
266
- std::sort(seq.begin(), seq.end(),
267
- [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
268
- if (a.n_seq_id == b.n_seq_id) {
269
- return a.length > b.length;
270
- }
271
- return a.n_seq_id < b.n_seq_id;
272
- }
273
- );
274
- }
275
-
276
- llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
277
- batch = in_batch;
278
- LM_GGML_ASSERT(batch.n_tokens > 0);
279
- if (!batch.pos) {
280
- pos.resize(batch.n_tokens);
281
- for (int32_t i = 0; i < batch.n_tokens; i++) {
282
- pos[i] = i + p0;
283
- }
284
- batch.pos = pos.data();
285
- }
286
- if (!batch.n_seq_id) {
287
- n_seq_id.resize(batch.n_tokens);
288
- for (int32_t i = 0; i < batch.n_tokens; i++) {
289
- n_seq_id[i] = seq_id_0.size();
290
- }
291
- batch.n_seq_id = n_seq_id.data();
292
- }
293
- if (!batch.seq_id) {
294
- seq_id.resize(batch.n_tokens + 1);
295
- seq_id[batch.n_tokens] = NULL;
296
- for (int32_t i = 0; i < batch.n_tokens; i++) {
297
- seq_id[i] = seq_id_0.data();
298
- }
299
- batch.seq_id = seq_id.data();
300
- }
301
- if (!batch.logits) {
302
- logits.resize(batch.n_tokens);
303
- logits[logits.size() - 1] = true;
304
- batch.logits = logits.data();
305
- }
306
- }
307
-
308
- //
309
- // interface implementation
310
- //
311
-
312
- struct llama_batch llama_batch_get_one(
313
- llama_token * tokens,
314
- int32_t n_tokens) {
315
- return {
316
- /*n_tokens =*/ n_tokens,
317
- /*tokens =*/ tokens,
318
- /*embd =*/ nullptr,
319
- /*pos =*/ nullptr,
320
- /*n_seq_id =*/ nullptr,
321
- /*seq_id =*/ nullptr,
322
- /*logits =*/ nullptr,
323
- };
324
- }
325
-
326
- struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
327
- llama_batch batch = {
328
- /*n_tokens =*/ 0,
329
- /*tokens =*/ nullptr,
330
- /*embd =*/ nullptr,
331
- /*pos =*/ nullptr,
332
- /*n_seq_id =*/ nullptr,
333
- /*seq_id =*/ nullptr,
334
- /*logits =*/ nullptr,
335
- };
336
-
337
- if (embd) {
338
- batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
339
- } else {
340
- batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
341
- }
342
-
343
- batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
344
- batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
345
- batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
346
- for (int i = 0; i < n_tokens_alloc; ++i) {
347
- batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
348
- }
349
- batch.seq_id[n_tokens_alloc] = nullptr;
350
-
351
- batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
352
-
353
- return batch;
354
- }
355
-
356
- void llama_batch_free(struct llama_batch batch) {
357
- if (batch.token) free(batch.token);
358
- if (batch.embd) free(batch.embd);
359
- if (batch.pos) free(batch.pos);
360
- if (batch.n_seq_id) free(batch.n_seq_id);
361
- if (batch.seq_id) {
362
- for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
363
- free(batch.seq_id[i]);
364
- }
365
- free(batch.seq_id);
366
- }
367
- if (batch.logits) free(batch.logits);
368
- }
1
+ #include "llama-batch.h"
2
+
3
+ #include <cstring>
4
+ #include <algorithm>
5
+
6
+ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
7
+ // clear empty sequences
8
+ // the previous ubatch is assumed to be gone,
9
+ // so nothing should refer to values in these sequences anymore.
10
+ for (size_t i = seq.size(); i-- > 0;) {
11
+ if (seq[i].length == 0) {
12
+ seq.pop_back();
13
+ } else {
14
+ break;
15
+ }
16
+ }
17
+ ubatch_token.resize(!has_embd ? n_ubatch : 0);
18
+ ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
19
+ ubatch_pos.resize(n_ubatch);
20
+ ubatch_n_seq_id.resize(n_ubatch);
21
+ ubatch_seq_id.resize(n_ubatch);
22
+ ubatch_output.resize(n_ubatch);
23
+ llama_ubatch ubatch = {
24
+ /*equal_seqs =*/ true,
25
+ /*n_tokens =*/ 0,
26
+ /*n_seq_tokens =*/ 0,
27
+ /*n_seqs =*/ 0,
28
+ /*token =*/ !has_embd ? ubatch_token.data() : nullptr,
29
+ /*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
30
+ /*pos =*/ ubatch_pos.data(),
31
+ /*n_seq_id =*/ ubatch_n_seq_id.data(),
32
+ /*seq_id =*/ ubatch_seq_id.data(),
33
+ /*output =*/ ubatch_output.data(),
34
+ };
35
+ return ubatch;
36
+ }
37
+
38
+ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
39
+ LM_GGML_ASSERT(batch != nullptr);
40
+ LM_GGML_ASSERT(length <= seq.length);
41
+ // Can only add sequences of equal lengths to a batch,
42
+ // otherwise it isn't clear to which sequence a token belongs
43
+ LM_GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
44
+ LM_GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
45
+ // NOTE: loops are separated for cache-friendliness
46
+ if (batch->token) {
47
+ if (ubatch.equal_seqs) {
48
+ for (size_t i = 0; i < length; ++i) {
49
+ ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
50
+ }
51
+ } else {
52
+ // simple split
53
+ ubatch.token = batch->token + seq.offset;
54
+ }
55
+ } else {
56
+ ubatch.token = nullptr;
57
+ }
58
+ if (batch->embd) {
59
+ if (ubatch.equal_seqs) {
60
+ for (size_t i = 0; i < length; ++i) {
61
+ memcpy(
62
+ ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
63
+ batch->embd + (n_embd * ids[seq.offset + i]),
64
+ n_embd * sizeof(float)
65
+ );
66
+ }
67
+ } else {
68
+ // simple split
69
+ ubatch.embd = batch->embd + (n_embd * seq.offset);
70
+ }
71
+ } else {
72
+ ubatch.embd = nullptr;
73
+ }
74
+ if (ubatch.equal_seqs) {
75
+ for (size_t i = 0; i < length; ++i) {
76
+ ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
77
+ }
78
+ } else {
79
+ // simple split
80
+ ubatch.pos = batch->pos + seq.offset;
81
+ }
82
+ if (ubatch.equal_seqs) {
83
+ ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
84
+ if (seq.seq_id) {
85
+ ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
86
+ }
87
+ } else {
88
+ // simple split
89
+ if (batch->n_seq_id) {
90
+ ubatch.n_seq_id = batch->n_seq_id + seq.offset;
91
+ } else {
92
+ for (size_t i = 0; i < length; ++i) {
93
+ ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
94
+ }
95
+ }
96
+ if (batch->seq_id) {
97
+ ubatch.seq_id = batch->seq_id + seq.offset;
98
+ }
99
+ }
100
+ if (logits_all) {
101
+ for (size_t i = 0; i < length; ++i) {
102
+ ubatch.output[ubatch.n_tokens + i] = 1;
103
+ out_ids.push_back(ids[seq.offset + i]);
104
+ }
105
+ } else if (batch->logits) {
106
+ if (ubatch.equal_seqs) {
107
+ for (size_t i = 0; i < length; ++i) {
108
+ size_t id = ids[seq.offset + i];
109
+ int8_t is_output = batch->logits[id];
110
+ ubatch.output[ubatch.n_tokens + i] = is_output;
111
+ if (is_output) { out_ids.push_back(id); }
112
+ }
113
+ } else {
114
+ // simple split
115
+ ubatch.output = batch->logits + seq.offset;
116
+ for (size_t i = 0; i < length; ++i) {
117
+ if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
118
+ }
119
+ }
120
+ } else {
121
+ // only get last output
122
+ for (size_t i = 0; i < length; ++i) {
123
+ size_t id = ids[seq.offset + i];
124
+ int8_t is_last = id == ids.size() - 1;
125
+ ubatch.output[ubatch.n_tokens + i] = is_last;
126
+ if (is_last) { out_ids.push_back(id); }
127
+ }
128
+ }
129
+ if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
130
+ ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
131
+ }
132
+ ubatch.n_tokens += length;
133
+ ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
134
+ seq.offset += length;
135
+ seq.length -= length;
136
+ n_tokens -= length;
137
+ LM_GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
138
+ }
139
+
140
+ llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
141
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
142
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
143
+ ubatch.equal_seqs = false;
144
+ if (!seq.empty()) {
145
+ llama_sbatch_seq & s = seq[0];
146
+ size_t length = s.length < n_ubatch ? s.length : n_ubatch;
147
+ LM_GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
148
+ add_seq_to_ubatch(ubatch, s, length);
149
+ }
150
+ return ubatch;
151
+ }
152
+
153
+ llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
154
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
155
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
156
+ if (!seq.empty()) {
157
+ size_t length = 0;
158
+ size_t n_tokens_in_ubatch = 0;
159
+ LM_GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
160
+ // smallest first, because it's easier to split this way;
161
+ // starting from the end to pop in constant time.
162
+ for (size_t i = seq.size(); i-- > 0;) {
163
+ llama_sbatch_seq & s = seq[i];
164
+ LM_GGML_ASSERT(s.length > 0);
165
+ if (length == 0) {
166
+ length = s.length < n_ubatch ? s.length : n_ubatch;
167
+ }
168
+ add_seq_to_ubatch(ubatch, s, length);
169
+ n_tokens_in_ubatch += length;
170
+ // shared prompts can't be mixed with any of their sequences,
171
+ // so it's safer to compute them in their own ubatch
172
+ if (s.n_seq_id > 1) { break; }
173
+ // stop when there isn't enough space for another sequence
174
+ if (length + n_tokens_in_ubatch > n_ubatch) { break; }
175
+ }
176
+ }
177
+ return ubatch;
178
+ }
179
+
180
+ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
181
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
182
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
183
+ if (!seq.empty()) {
184
+ llama_sbatch_seq & s = seq[seq.size() - 1];
185
+ size_t length = s.length < n_ubatch ? s.length : n_ubatch;
186
+ LM_GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
187
+ add_seq_to_ubatch(ubatch, s, length);
188
+ }
189
+ return ubatch;
190
+ }
191
+
192
+ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
193
+ LM_GGML_ASSERT(batch.n_tokens >= 0);
194
+ this->batch = &batch;
195
+ this->n_embd = n_embd;
196
+ this->logits_all = logits_all;
197
+
198
+ n_tokens = batch.n_tokens;
199
+ ids.resize(n_tokens);
200
+ out_ids.clear();
201
+ // TODO: reserve out_ids and seq
202
+
203
+ for (size_t i = 0; i < n_tokens; ++i) {
204
+ ids[i] = i;
205
+ }
206
+ if (simple_split) {
207
+ seq.resize(1);
208
+ llama_sbatch_seq & s = seq[0];
209
+ s.n_seq_id = 0;
210
+ s.seq_id = nullptr;
211
+ s.offset = 0;
212
+ s.length = n_tokens;
213
+ return;
214
+ }
215
+ std::sort(ids.begin(), ids.end(),
216
+ [&batch](size_t a, size_t b) {
217
+ int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
218
+ int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
219
+ // sort by seq_id, then by pos
220
+ if (n_seq_a == n_seq_b) {
221
+ if (batch.seq_id) {
222
+ for (int32_t i = 0; i < n_seq_a; ++i) {
223
+ llama_seq_id seq_id_a = batch.seq_id[a][i];
224
+ llama_seq_id seq_id_b = batch.seq_id[b][i];
225
+ // smaller seq_ids go first
226
+ if (seq_id_a != seq_id_b) {
227
+ return seq_id_a < seq_id_b;
228
+ }
229
+ }
230
+ }
231
+ // when all else is equal, sort by pos
232
+ if (batch.pos) {
233
+ return batch.pos[a] < batch.pos[b];
234
+ }
235
+ // no pos, sort by id
236
+ return a < b;
237
+ }
238
+ // shared prompts go first
239
+ return n_seq_a > n_seq_b;
240
+ }
241
+ );
242
+ // init seq
243
+ llama_sbatch_seq * last_seq = nullptr;
244
+
245
+ for (size_t i = 0; i < n_tokens; ++i) {
246
+ const size_t bi = ids[i];
247
+ const int32_t n_seqs = batch.n_seq_id[bi];
248
+ llama_seq_id * seq_ids = batch.seq_id[bi];
249
+ if (last_seq != nullptr) {
250
+ bool same = n_seqs == last_seq->n_seq_id;
251
+ for (int32_t j = 0; same && j < n_seqs; ++j) {
252
+ if (seq_ids[j] != last_seq->seq_id[j]) {
253
+ same = false;
254
+ }
255
+ }
256
+ if (same) {
257
+ last_seq->length += 1;
258
+ continue;
259
+ }
260
+ }
261
+ llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
262
+ seq.push_back(new_seq);
263
+ last_seq = &seq.back();
264
+ }
265
+ // keep shared prompts first at the end, then sort by length descending.
266
+ std::sort(seq.begin(), seq.end(),
267
+ [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
268
+ if (a.n_seq_id == b.n_seq_id) {
269
+ return a.length > b.length;
270
+ }
271
+ return a.n_seq_id < b.n_seq_id;
272
+ }
273
+ );
274
+ }
275
+
276
+ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
277
+ batch = in_batch;
278
+ LM_GGML_ASSERT(batch.n_tokens > 0);
279
+ if (!batch.pos) {
280
+ pos.resize(batch.n_tokens);
281
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
282
+ pos[i] = i + p0;
283
+ }
284
+ batch.pos = pos.data();
285
+ }
286
+ if (!batch.n_seq_id) {
287
+ n_seq_id.resize(batch.n_tokens);
288
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
289
+ n_seq_id[i] = seq_id_0.size();
290
+ }
291
+ batch.n_seq_id = n_seq_id.data();
292
+ }
293
+ if (!batch.seq_id) {
294
+ seq_id.resize(batch.n_tokens + 1);
295
+ seq_id[batch.n_tokens] = NULL;
296
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
297
+ seq_id[i] = seq_id_0.data();
298
+ }
299
+ batch.seq_id = seq_id.data();
300
+ }
301
+ if (!batch.logits) {
302
+ logits.resize(batch.n_tokens);
303
+ logits[logits.size() - 1] = true;
304
+ batch.logits = logits.data();
305
+ }
306
+ }
307
+
308
+ //
309
+ // interface implementation
310
+ //
311
+
312
+ struct llama_batch llama_batch_get_one(
313
+ llama_token * tokens,
314
+ int32_t n_tokens) {
315
+ return {
316
+ /*n_tokens =*/ n_tokens,
317
+ /*tokens =*/ tokens,
318
+ /*embd =*/ nullptr,
319
+ /*pos =*/ nullptr,
320
+ /*n_seq_id =*/ nullptr,
321
+ /*seq_id =*/ nullptr,
322
+ /*logits =*/ nullptr,
323
+ };
324
+ }
325
+
326
+ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
327
+ llama_batch batch = {
328
+ /*n_tokens =*/ 0,
329
+ /*tokens =*/ nullptr,
330
+ /*embd =*/ nullptr,
331
+ /*pos =*/ nullptr,
332
+ /*n_seq_id =*/ nullptr,
333
+ /*seq_id =*/ nullptr,
334
+ /*logits =*/ nullptr,
335
+ };
336
+
337
+ if (embd) {
338
+ batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
339
+ } else {
340
+ batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
341
+ }
342
+
343
+ batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
344
+ batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
345
+ batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
346
+ for (int i = 0; i < n_tokens_alloc; ++i) {
347
+ batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
348
+ }
349
+ batch.seq_id[n_tokens_alloc] = nullptr;
350
+
351
+ batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
352
+
353
+ return batch;
354
+ }
355
+
356
+ void llama_batch_free(struct llama_batch batch) {
357
+ if (batch.token) free(batch.token);
358
+ if (batch.embd) free(batch.embd);
359
+ if (batch.pos) free(batch.pos);
360
+ if (batch.n_seq_id) free(batch.n_seq_id);
361
+ if (batch.seq_id) {
362
+ for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
363
+ free(batch.seq_id[i]);
364
+ }
365
+ free(batch.seq_id);
366
+ }
367
+ if (batch.logits) free(batch.logits);
368
+ }