cui-llama.rn 1.4.3 → 1.4.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134) hide show
  1. package/README.md +93 -114
  2. package/android/src/main/CMakeLists.txt +5 -0
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +91 -17
  4. package/android/src/main/java/com/rnllama/RNLlama.java +37 -4
  5. package/android/src/main/jni-utils.h +6 -0
  6. package/android/src/main/jni.cpp +289 -31
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  11. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  12. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  15. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +7 -2
  16. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +7 -2
  17. package/cpp/chat-template.hpp +529 -0
  18. package/cpp/chat.cpp +1779 -0
  19. package/cpp/chat.h +135 -0
  20. package/cpp/common.cpp +2064 -1873
  21. package/cpp/common.h +700 -699
  22. package/cpp/ggml-alloc.c +1039 -1042
  23. package/cpp/ggml-alloc.h +1 -1
  24. package/cpp/ggml-backend-impl.h +255 -255
  25. package/cpp/ggml-backend-reg.cpp +586 -582
  26. package/cpp/ggml-backend.cpp +2004 -2002
  27. package/cpp/ggml-backend.h +354 -354
  28. package/cpp/ggml-common.h +1851 -1853
  29. package/cpp/ggml-cpp.h +39 -39
  30. package/cpp/ggml-cpu-aarch64.cpp +4248 -4247
  31. package/cpp/ggml-cpu-aarch64.h +8 -8
  32. package/cpp/ggml-cpu-impl.h +531 -386
  33. package/cpp/ggml-cpu-quants.c +12527 -10920
  34. package/cpp/ggml-cpu-traits.cpp +36 -36
  35. package/cpp/ggml-cpu-traits.h +38 -38
  36. package/cpp/ggml-cpu.c +15766 -14391
  37. package/cpp/ggml-cpu.cpp +655 -635
  38. package/cpp/ggml-cpu.h +138 -135
  39. package/cpp/ggml-impl.h +567 -567
  40. package/cpp/ggml-metal-impl.h +235 -0
  41. package/cpp/ggml-metal.h +1 -1
  42. package/cpp/ggml-metal.m +5146 -4884
  43. package/cpp/ggml-opt.cpp +854 -854
  44. package/cpp/ggml-opt.h +216 -216
  45. package/cpp/ggml-quants.c +5238 -5238
  46. package/cpp/ggml-threading.h +14 -14
  47. package/cpp/ggml.c +6529 -6514
  48. package/cpp/ggml.h +2198 -2194
  49. package/cpp/gguf.cpp +1329 -1329
  50. package/cpp/gguf.h +202 -202
  51. package/cpp/json-schema-to-grammar.cpp +1024 -1045
  52. package/cpp/json-schema-to-grammar.h +21 -8
  53. package/cpp/json.hpp +24766 -24766
  54. package/cpp/llama-adapter.cpp +347 -347
  55. package/cpp/llama-adapter.h +74 -74
  56. package/cpp/llama-arch.cpp +1513 -1487
  57. package/cpp/llama-arch.h +403 -400
  58. package/cpp/llama-batch.cpp +368 -368
  59. package/cpp/llama-batch.h +88 -88
  60. package/cpp/llama-chat.cpp +588 -578
  61. package/cpp/llama-chat.h +53 -52
  62. package/cpp/llama-context.cpp +1775 -1775
  63. package/cpp/llama-context.h +128 -128
  64. package/cpp/llama-cparams.cpp +1 -1
  65. package/cpp/llama-cparams.h +37 -37
  66. package/cpp/llama-cpp.h +30 -30
  67. package/cpp/llama-grammar.cpp +1219 -1139
  68. package/cpp/llama-grammar.h +173 -143
  69. package/cpp/llama-hparams.cpp +71 -71
  70. package/cpp/llama-hparams.h +139 -139
  71. package/cpp/llama-impl.cpp +167 -167
  72. package/cpp/llama-impl.h +61 -61
  73. package/cpp/llama-kv-cache.cpp +718 -718
  74. package/cpp/llama-kv-cache.h +219 -218
  75. package/cpp/llama-mmap.cpp +600 -590
  76. package/cpp/llama-mmap.h +68 -67
  77. package/cpp/llama-model-loader.cpp +1124 -1124
  78. package/cpp/llama-model-loader.h +167 -167
  79. package/cpp/llama-model.cpp +4087 -3997
  80. package/cpp/llama-model.h +370 -370
  81. package/cpp/llama-sampling.cpp +2558 -2408
  82. package/cpp/llama-sampling.h +32 -32
  83. package/cpp/llama-vocab.cpp +3264 -3247
  84. package/cpp/llama-vocab.h +125 -125
  85. package/cpp/llama.cpp +10284 -10077
  86. package/cpp/llama.h +1354 -1323
  87. package/cpp/log.cpp +393 -401
  88. package/cpp/log.h +132 -121
  89. package/cpp/minja/chat-template.hpp +529 -0
  90. package/cpp/minja/minja.hpp +2915 -0
  91. package/cpp/minja.hpp +2915 -0
  92. package/cpp/rn-llama.cpp +66 -6
  93. package/cpp/rn-llama.h +26 -1
  94. package/cpp/sampling.cpp +570 -505
  95. package/cpp/sampling.h +3 -0
  96. package/cpp/sgemm.cpp +2598 -2597
  97. package/cpp/sgemm.h +14 -14
  98. package/cpp/speculative.cpp +278 -277
  99. package/cpp/speculative.h +28 -28
  100. package/cpp/unicode.cpp +9 -2
  101. package/ios/CMakeLists.txt +6 -0
  102. package/ios/RNLlama.h +0 -8
  103. package/ios/RNLlama.mm +27 -3
  104. package/ios/RNLlamaContext.h +10 -1
  105. package/ios/RNLlamaContext.mm +269 -57
  106. package/jest/mock.js +21 -2
  107. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  108. package/lib/commonjs/grammar.js +3 -0
  109. package/lib/commonjs/grammar.js.map +1 -1
  110. package/lib/commonjs/index.js +87 -13
  111. package/lib/commonjs/index.js.map +1 -1
  112. package/lib/module/NativeRNLlama.js.map +1 -1
  113. package/lib/module/grammar.js +3 -0
  114. package/lib/module/grammar.js.map +1 -1
  115. package/lib/module/index.js +86 -13
  116. package/lib/module/index.js.map +1 -1
  117. package/lib/typescript/NativeRNLlama.d.ts +107 -2
  118. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  119. package/lib/typescript/grammar.d.ts.map +1 -1
  120. package/lib/typescript/index.d.ts +32 -7
  121. package/lib/typescript/index.d.ts.map +1 -1
  122. package/llama-rn.podspec +1 -1
  123. package/package.json +3 -2
  124. package/src/NativeRNLlama.ts +115 -3
  125. package/src/grammar.ts +3 -0
  126. package/src/index.ts +138 -21
  127. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCCompiler.cmake +0 -81
  128. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeSystem.cmake +0 -15
  129. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.c +0 -904
  130. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.o +0 -0
  131. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.cpp +0 -919
  132. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.o +0 -0
  133. package/android/src/main/build-arm64/CMakeFiles/CMakeConfigureLog.yaml +0 -55
  134. package/cpp/rn-llama.hpp +0 -913
@@ -1,368 +1,368 @@
1
- #include "llama-batch.h"
2
-
3
- #include <cstring>
4
- #include <algorithm>
5
-
6
- llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
7
- // clear empty sequences
8
- // the previous ubatch is assumed to be gone,
9
- // so nothing should refer to values in these sequences anymore.
10
- for (size_t i = seq.size(); i-- > 0;) {
11
- if (seq[i].length == 0) {
12
- seq.pop_back();
13
- } else {
14
- break;
15
- }
16
- }
17
- ubatch_token.resize(!has_embd ? n_ubatch : 0);
18
- ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
19
- ubatch_pos.resize(n_ubatch);
20
- ubatch_n_seq_id.resize(n_ubatch);
21
- ubatch_seq_id.resize(n_ubatch);
22
- ubatch_output.resize(n_ubatch);
23
- llama_ubatch ubatch = {
24
- /*equal_seqs =*/ true,
25
- /*n_tokens =*/ 0,
26
- /*n_seq_tokens =*/ 0,
27
- /*n_seqs =*/ 0,
28
- /*token =*/ !has_embd ? ubatch_token.data() : nullptr,
29
- /*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
30
- /*pos =*/ ubatch_pos.data(),
31
- /*n_seq_id =*/ ubatch_n_seq_id.data(),
32
- /*seq_id =*/ ubatch_seq_id.data(),
33
- /*output =*/ ubatch_output.data(),
34
- };
35
- return ubatch;
36
- }
37
-
38
- void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
39
- LM_GGML_ASSERT(batch != nullptr);
40
- LM_GGML_ASSERT(length <= seq.length);
41
- // Can only add sequences of equal lengths to a batch,
42
- // otherwise it isn't clear to which sequence a token belongs
43
- LM_GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
44
- LM_GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
45
- // NOTE: loops are separated for cache-friendliness
46
- if (batch->token) {
47
- if (ubatch.equal_seqs) {
48
- for (size_t i = 0; i < length; ++i) {
49
- ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
50
- }
51
- } else {
52
- // simple split
53
- ubatch.token = batch->token + seq.offset;
54
- }
55
- } else {
56
- ubatch.token = nullptr;
57
- }
58
- if (batch->embd) {
59
- if (ubatch.equal_seqs) {
60
- for (size_t i = 0; i < length; ++i) {
61
- memcpy(
62
- ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
63
- batch->embd + (n_embd * ids[seq.offset + i]),
64
- n_embd * sizeof(float)
65
- );
66
- }
67
- } else {
68
- // simple split
69
- ubatch.embd = batch->embd + (n_embd * seq.offset);
70
- }
71
- } else {
72
- ubatch.embd = nullptr;
73
- }
74
- if (ubatch.equal_seqs) {
75
- for (size_t i = 0; i < length; ++i) {
76
- ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
77
- }
78
- } else {
79
- // simple split
80
- ubatch.pos = batch->pos + seq.offset;
81
- }
82
- if (ubatch.equal_seqs) {
83
- ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
84
- if (seq.seq_id) {
85
- ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
86
- }
87
- } else {
88
- // simple split
89
- if (batch->n_seq_id) {
90
- ubatch.n_seq_id = batch->n_seq_id + seq.offset;
91
- } else {
92
- for (size_t i = 0; i < length; ++i) {
93
- ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
94
- }
95
- }
96
- if (batch->seq_id) {
97
- ubatch.seq_id = batch->seq_id + seq.offset;
98
- }
99
- }
100
- if (logits_all) {
101
- for (size_t i = 0; i < length; ++i) {
102
- ubatch.output[ubatch.n_tokens + i] = 1;
103
- out_ids.push_back(ids[seq.offset + i]);
104
- }
105
- } else if (batch->logits) {
106
- if (ubatch.equal_seqs) {
107
- for (size_t i = 0; i < length; ++i) {
108
- size_t id = ids[seq.offset + i];
109
- int8_t is_output = batch->logits[id];
110
- ubatch.output[ubatch.n_tokens + i] = is_output;
111
- if (is_output) { out_ids.push_back(id); }
112
- }
113
- } else {
114
- // simple split
115
- ubatch.output = batch->logits + seq.offset;
116
- for (size_t i = 0; i < length; ++i) {
117
- if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
118
- }
119
- }
120
- } else {
121
- // only get last output
122
- for (size_t i = 0; i < length; ++i) {
123
- size_t id = ids[seq.offset + i];
124
- int8_t is_last = id == ids.size() - 1;
125
- ubatch.output[ubatch.n_tokens + i] = is_last;
126
- if (is_last) { out_ids.push_back(id); }
127
- }
128
- }
129
- if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
130
- ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
131
- }
132
- ubatch.n_tokens += length;
133
- ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
134
- seq.offset += length;
135
- seq.length -= length;
136
- n_tokens -= length;
137
- LM_GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
138
- }
139
-
140
- llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
141
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
142
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
143
- ubatch.equal_seqs = false;
144
- if (!seq.empty()) {
145
- llama_sbatch_seq & s = seq[0];
146
- size_t length = s.length < n_ubatch ? s.length : n_ubatch;
147
- LM_GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
148
- add_seq_to_ubatch(ubatch, s, length);
149
- }
150
- return ubatch;
151
- }
152
-
153
- llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
154
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
155
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
156
- if (!seq.empty()) {
157
- size_t length = 0;
158
- size_t n_tokens_in_ubatch = 0;
159
- LM_GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
160
- // smallest first, because it's easier to split this way;
161
- // starting from the end to pop in constant time.
162
- for (size_t i = seq.size(); i-- > 0;) {
163
- llama_sbatch_seq & s = seq[i];
164
- LM_GGML_ASSERT(s.length > 0);
165
- if (length == 0) {
166
- length = s.length < n_ubatch ? s.length : n_ubatch;
167
- }
168
- add_seq_to_ubatch(ubatch, s, length);
169
- n_tokens_in_ubatch += length;
170
- // shared prompts can't be mixed with any of their sequences,
171
- // so it's safer to compute them in their own ubatch
172
- if (s.n_seq_id > 1) { break; }
173
- // stop when there isn't enough space for another sequence
174
- if (length + n_tokens_in_ubatch > n_ubatch) { break; }
175
- }
176
- }
177
- return ubatch;
178
- }
179
-
180
- llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
181
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
182
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
183
- if (!seq.empty()) {
184
- llama_sbatch_seq & s = seq[seq.size() - 1];
185
- size_t length = s.length < n_ubatch ? s.length : n_ubatch;
186
- LM_GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
187
- add_seq_to_ubatch(ubatch, s, length);
188
- }
189
- return ubatch;
190
- }
191
-
192
- void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
193
- LM_GGML_ASSERT(batch.n_tokens >= 0);
194
- this->batch = &batch;
195
- this->n_embd = n_embd;
196
- this->logits_all = logits_all;
197
-
198
- n_tokens = batch.n_tokens;
199
- ids.resize(n_tokens);
200
- out_ids.clear();
201
- // TODO: reserve out_ids and seq
202
-
203
- for (size_t i = 0; i < n_tokens; ++i) {
204
- ids[i] = i;
205
- }
206
- if (simple_split) {
207
- seq.resize(1);
208
- llama_sbatch_seq & s = seq[0];
209
- s.n_seq_id = 0;
210
- s.seq_id = nullptr;
211
- s.offset = 0;
212
- s.length = n_tokens;
213
- return;
214
- }
215
- std::sort(ids.begin(), ids.end(),
216
- [&batch](size_t a, size_t b) {
217
- int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
218
- int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
219
- // sort by seq_id, then by pos
220
- if (n_seq_a == n_seq_b) {
221
- if (batch.seq_id) {
222
- for (int32_t i = 0; i < n_seq_a; ++i) {
223
- llama_seq_id seq_id_a = batch.seq_id[a][i];
224
- llama_seq_id seq_id_b = batch.seq_id[b][i];
225
- // smaller seq_ids go first
226
- if (seq_id_a != seq_id_b) {
227
- return seq_id_a < seq_id_b;
228
- }
229
- }
230
- }
231
- // when all else is equal, sort by pos
232
- if (batch.pos) {
233
- return batch.pos[a] < batch.pos[b];
234
- }
235
- // no pos, sort by id
236
- return a < b;
237
- }
238
- // shared prompts go first
239
- return n_seq_a > n_seq_b;
240
- }
241
- );
242
- // init seq
243
- llama_sbatch_seq * last_seq = nullptr;
244
-
245
- for (size_t i = 0; i < n_tokens; ++i) {
246
- const size_t bi = ids[i];
247
- const int32_t n_seqs = batch.n_seq_id[bi];
248
- llama_seq_id * seq_ids = batch.seq_id[bi];
249
- if (last_seq != nullptr) {
250
- bool same = n_seqs == last_seq->n_seq_id;
251
- for (int32_t j = 0; same && j < n_seqs; ++j) {
252
- if (seq_ids[j] != last_seq->seq_id[j]) {
253
- same = false;
254
- }
255
- }
256
- if (same) {
257
- last_seq->length += 1;
258
- continue;
259
- }
260
- }
261
- llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
262
- seq.push_back(new_seq);
263
- last_seq = &seq.back();
264
- }
265
- // keep shared prompts first at the end, then sort by length descending.
266
- std::sort(seq.begin(), seq.end(),
267
- [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
268
- if (a.n_seq_id == b.n_seq_id) {
269
- return a.length > b.length;
270
- }
271
- return a.n_seq_id < b.n_seq_id;
272
- }
273
- );
274
- }
275
-
276
- llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
277
- batch = in_batch;
278
- LM_GGML_ASSERT(batch.n_tokens > 0);
279
- if (!batch.pos) {
280
- pos.resize(batch.n_tokens);
281
- for (int32_t i = 0; i < batch.n_tokens; i++) {
282
- pos[i] = i + p0;
283
- }
284
- batch.pos = pos.data();
285
- }
286
- if (!batch.n_seq_id) {
287
- n_seq_id.resize(batch.n_tokens);
288
- for (int32_t i = 0; i < batch.n_tokens; i++) {
289
- n_seq_id[i] = seq_id_0.size();
290
- }
291
- batch.n_seq_id = n_seq_id.data();
292
- }
293
- if (!batch.seq_id) {
294
- seq_id.resize(batch.n_tokens + 1);
295
- seq_id[batch.n_tokens] = NULL;
296
- for (int32_t i = 0; i < batch.n_tokens; i++) {
297
- seq_id[i] = seq_id_0.data();
298
- }
299
- batch.seq_id = seq_id.data();
300
- }
301
- if (!batch.logits) {
302
- logits.resize(batch.n_tokens);
303
- logits[logits.size() - 1] = true;
304
- batch.logits = logits.data();
305
- }
306
- }
307
-
308
- //
309
- // interface implementation
310
- //
311
-
312
- struct llama_batch llama_batch_get_one(
313
- llama_token * tokens,
314
- int32_t n_tokens) {
315
- return {
316
- /*n_tokens =*/ n_tokens,
317
- /*tokens =*/ tokens,
318
- /*embd =*/ nullptr,
319
- /*pos =*/ nullptr,
320
- /*n_seq_id =*/ nullptr,
321
- /*seq_id =*/ nullptr,
322
- /*logits =*/ nullptr,
323
- };
324
- }
325
-
326
- struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
327
- llama_batch batch = {
328
- /*n_tokens =*/ 0,
329
- /*tokens =*/ nullptr,
330
- /*embd =*/ nullptr,
331
- /*pos =*/ nullptr,
332
- /*n_seq_id =*/ nullptr,
333
- /*seq_id =*/ nullptr,
334
- /*logits =*/ nullptr,
335
- };
336
-
337
- if (embd) {
338
- batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
339
- } else {
340
- batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
341
- }
342
-
343
- batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
344
- batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
345
- batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
346
- for (int i = 0; i < n_tokens_alloc; ++i) {
347
- batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
348
- }
349
- batch.seq_id[n_tokens_alloc] = nullptr;
350
-
351
- batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
352
-
353
- return batch;
354
- }
355
-
356
- void llama_batch_free(struct llama_batch batch) {
357
- if (batch.token) free(batch.token);
358
- if (batch.embd) free(batch.embd);
359
- if (batch.pos) free(batch.pos);
360
- if (batch.n_seq_id) free(batch.n_seq_id);
361
- if (batch.seq_id) {
362
- for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
363
- free(batch.seq_id[i]);
364
- }
365
- free(batch.seq_id);
366
- }
367
- if (batch.logits) free(batch.logits);
368
- }
1
+ #include "llama-batch.h"
2
+
3
+ #include <cstring>
4
+ #include <algorithm>
5
+
6
+ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
7
+ // clear empty sequences
8
+ // the previous ubatch is assumed to be gone,
9
+ // so nothing should refer to values in these sequences anymore.
10
+ for (size_t i = seq.size(); i-- > 0;) {
11
+ if (seq[i].length == 0) {
12
+ seq.pop_back();
13
+ } else {
14
+ break;
15
+ }
16
+ }
17
+ ubatch_token.resize(!has_embd ? n_ubatch : 0);
18
+ ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
19
+ ubatch_pos.resize(n_ubatch);
20
+ ubatch_n_seq_id.resize(n_ubatch);
21
+ ubatch_seq_id.resize(n_ubatch);
22
+ ubatch_output.resize(n_ubatch);
23
+ llama_ubatch ubatch = {
24
+ /*equal_seqs =*/ true,
25
+ /*n_tokens =*/ 0,
26
+ /*n_seq_tokens =*/ 0,
27
+ /*n_seqs =*/ 0,
28
+ /*token =*/ !has_embd ? ubatch_token.data() : nullptr,
29
+ /*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
30
+ /*pos =*/ ubatch_pos.data(),
31
+ /*n_seq_id =*/ ubatch_n_seq_id.data(),
32
+ /*seq_id =*/ ubatch_seq_id.data(),
33
+ /*output =*/ ubatch_output.data(),
34
+ };
35
+ return ubatch;
36
+ }
37
+
38
+ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
39
+ LM_GGML_ASSERT(batch != nullptr);
40
+ LM_GGML_ASSERT(length <= seq.length);
41
+ // Can only add sequences of equal lengths to a batch,
42
+ // otherwise it isn't clear to which sequence a token belongs
43
+ LM_GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
44
+ LM_GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
45
+ // NOTE: loops are separated for cache-friendliness
46
+ if (batch->token) {
47
+ if (ubatch.equal_seqs) {
48
+ for (size_t i = 0; i < length; ++i) {
49
+ ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
50
+ }
51
+ } else {
52
+ // simple split
53
+ ubatch.token = batch->token + seq.offset;
54
+ }
55
+ } else {
56
+ ubatch.token = nullptr;
57
+ }
58
+ if (batch->embd) {
59
+ if (ubatch.equal_seqs) {
60
+ for (size_t i = 0; i < length; ++i) {
61
+ memcpy(
62
+ ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
63
+ batch->embd + (n_embd * ids[seq.offset + i]),
64
+ n_embd * sizeof(float)
65
+ );
66
+ }
67
+ } else {
68
+ // simple split
69
+ ubatch.embd = batch->embd + (n_embd * seq.offset);
70
+ }
71
+ } else {
72
+ ubatch.embd = nullptr;
73
+ }
74
+ if (ubatch.equal_seqs) {
75
+ for (size_t i = 0; i < length; ++i) {
76
+ ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
77
+ }
78
+ } else {
79
+ // simple split
80
+ ubatch.pos = batch->pos + seq.offset;
81
+ }
82
+ if (ubatch.equal_seqs) {
83
+ ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
84
+ if (seq.seq_id) {
85
+ ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
86
+ }
87
+ } else {
88
+ // simple split
89
+ if (batch->n_seq_id) {
90
+ ubatch.n_seq_id = batch->n_seq_id + seq.offset;
91
+ } else {
92
+ for (size_t i = 0; i < length; ++i) {
93
+ ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
94
+ }
95
+ }
96
+ if (batch->seq_id) {
97
+ ubatch.seq_id = batch->seq_id + seq.offset;
98
+ }
99
+ }
100
+ if (logits_all) {
101
+ for (size_t i = 0; i < length; ++i) {
102
+ ubatch.output[ubatch.n_tokens + i] = 1;
103
+ out_ids.push_back(ids[seq.offset + i]);
104
+ }
105
+ } else if (batch->logits) {
106
+ if (ubatch.equal_seqs) {
107
+ for (size_t i = 0; i < length; ++i) {
108
+ size_t id = ids[seq.offset + i];
109
+ int8_t is_output = batch->logits[id];
110
+ ubatch.output[ubatch.n_tokens + i] = is_output;
111
+ if (is_output) { out_ids.push_back(id); }
112
+ }
113
+ } else {
114
+ // simple split
115
+ ubatch.output = batch->logits + seq.offset;
116
+ for (size_t i = 0; i < length; ++i) {
117
+ if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
118
+ }
119
+ }
120
+ } else {
121
+ // only get last output
122
+ for (size_t i = 0; i < length; ++i) {
123
+ size_t id = ids[seq.offset + i];
124
+ int8_t is_last = id == ids.size() - 1;
125
+ ubatch.output[ubatch.n_tokens + i] = is_last;
126
+ if (is_last) { out_ids.push_back(id); }
127
+ }
128
+ }
129
+ if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
130
+ ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
131
+ }
132
+ ubatch.n_tokens += length;
133
+ ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
134
+ seq.offset += length;
135
+ seq.length -= length;
136
+ n_tokens -= length;
137
+ LM_GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
138
+ }
139
+
140
+ llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
141
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
142
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
143
+ ubatch.equal_seqs = false;
144
+ if (!seq.empty()) {
145
+ llama_sbatch_seq & s = seq[0];
146
+ size_t length = s.length < n_ubatch ? s.length : n_ubatch;
147
+ LM_GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
148
+ add_seq_to_ubatch(ubatch, s, length);
149
+ }
150
+ return ubatch;
151
+ }
152
+
153
+ llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
154
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
155
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
156
+ if (!seq.empty()) {
157
+ size_t length = 0;
158
+ size_t n_tokens_in_ubatch = 0;
159
+ LM_GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
160
+ // smallest first, because it's easier to split this way;
161
+ // starting from the end to pop in constant time.
162
+ for (size_t i = seq.size(); i-- > 0;) {
163
+ llama_sbatch_seq & s = seq[i];
164
+ LM_GGML_ASSERT(s.length > 0);
165
+ if (length == 0) {
166
+ length = s.length < n_ubatch ? s.length : n_ubatch;
167
+ }
168
+ add_seq_to_ubatch(ubatch, s, length);
169
+ n_tokens_in_ubatch += length;
170
+ // shared prompts can't be mixed with any of their sequences,
171
+ // so it's safer to compute them in their own ubatch
172
+ if (s.n_seq_id > 1) { break; }
173
+ // stop when there isn't enough space for another sequence
174
+ if (length + n_tokens_in_ubatch > n_ubatch) { break; }
175
+ }
176
+ }
177
+ return ubatch;
178
+ }
179
+
180
+ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
181
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
182
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
183
+ if (!seq.empty()) {
184
+ llama_sbatch_seq & s = seq[seq.size() - 1];
185
+ size_t length = s.length < n_ubatch ? s.length : n_ubatch;
186
+ LM_GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
187
+ add_seq_to_ubatch(ubatch, s, length);
188
+ }
189
+ return ubatch;
190
+ }
191
+
192
+ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
193
+ LM_GGML_ASSERT(batch.n_tokens >= 0);
194
+ this->batch = &batch;
195
+ this->n_embd = n_embd;
196
+ this->logits_all = logits_all;
197
+
198
+ n_tokens = batch.n_tokens;
199
+ ids.resize(n_tokens);
200
+ out_ids.clear();
201
+ // TODO: reserve out_ids and seq
202
+
203
+ for (size_t i = 0; i < n_tokens; ++i) {
204
+ ids[i] = i;
205
+ }
206
+ if (simple_split) {
207
+ seq.resize(1);
208
+ llama_sbatch_seq & s = seq[0];
209
+ s.n_seq_id = 0;
210
+ s.seq_id = nullptr;
211
+ s.offset = 0;
212
+ s.length = n_tokens;
213
+ return;
214
+ }
215
+ std::sort(ids.begin(), ids.end(),
216
+ [&batch](size_t a, size_t b) {
217
+ int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
218
+ int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
219
+ // sort by seq_id, then by pos
220
+ if (n_seq_a == n_seq_b) {
221
+ if (batch.seq_id) {
222
+ for (int32_t i = 0; i < n_seq_a; ++i) {
223
+ llama_seq_id seq_id_a = batch.seq_id[a][i];
224
+ llama_seq_id seq_id_b = batch.seq_id[b][i];
225
+ // smaller seq_ids go first
226
+ if (seq_id_a != seq_id_b) {
227
+ return seq_id_a < seq_id_b;
228
+ }
229
+ }
230
+ }
231
+ // when all else is equal, sort by pos
232
+ if (batch.pos) {
233
+ return batch.pos[a] < batch.pos[b];
234
+ }
235
+ // no pos, sort by id
236
+ return a < b;
237
+ }
238
+ // shared prompts go first
239
+ return n_seq_a > n_seq_b;
240
+ }
241
+ );
242
+ // init seq
243
+ llama_sbatch_seq * last_seq = nullptr;
244
+
245
+ for (size_t i = 0; i < n_tokens; ++i) {
246
+ const size_t bi = ids[i];
247
+ const int32_t n_seqs = batch.n_seq_id[bi];
248
+ llama_seq_id * seq_ids = batch.seq_id[bi];
249
+ if (last_seq != nullptr) {
250
+ bool same = n_seqs == last_seq->n_seq_id;
251
+ for (int32_t j = 0; same && j < n_seqs; ++j) {
252
+ if (seq_ids[j] != last_seq->seq_id[j]) {
253
+ same = false;
254
+ }
255
+ }
256
+ if (same) {
257
+ last_seq->length += 1;
258
+ continue;
259
+ }
260
+ }
261
+ llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
262
+ seq.push_back(new_seq);
263
+ last_seq = &seq.back();
264
+ }
265
+ // keep shared prompts first at the end, then sort by length descending.
266
+ std::sort(seq.begin(), seq.end(),
267
+ [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
268
+ if (a.n_seq_id == b.n_seq_id) {
269
+ return a.length > b.length;
270
+ }
271
+ return a.n_seq_id < b.n_seq_id;
272
+ }
273
+ );
274
+ }
275
+
276
+ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
277
+ batch = in_batch;
278
+ LM_GGML_ASSERT(batch.n_tokens > 0);
279
+ if (!batch.pos) {
280
+ pos.resize(batch.n_tokens);
281
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
282
+ pos[i] = i + p0;
283
+ }
284
+ batch.pos = pos.data();
285
+ }
286
+ if (!batch.n_seq_id) {
287
+ n_seq_id.resize(batch.n_tokens);
288
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
289
+ n_seq_id[i] = seq_id_0.size();
290
+ }
291
+ batch.n_seq_id = n_seq_id.data();
292
+ }
293
+ if (!batch.seq_id) {
294
+ seq_id.resize(batch.n_tokens + 1);
295
+ seq_id[batch.n_tokens] = NULL;
296
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
297
+ seq_id[i] = seq_id_0.data();
298
+ }
299
+ batch.seq_id = seq_id.data();
300
+ }
301
+ if (!batch.logits) {
302
+ logits.resize(batch.n_tokens);
303
+ logits[logits.size() - 1] = true;
304
+ batch.logits = logits.data();
305
+ }
306
+ }
307
+
308
+ //
309
+ // interface implementation
310
+ //
311
+
312
+ struct llama_batch llama_batch_get_one(
313
+ llama_token * tokens,
314
+ int32_t n_tokens) {
315
+ return {
316
+ /*n_tokens =*/ n_tokens,
317
+ /*tokens =*/ tokens,
318
+ /*embd =*/ nullptr,
319
+ /*pos =*/ nullptr,
320
+ /*n_seq_id =*/ nullptr,
321
+ /*seq_id =*/ nullptr,
322
+ /*logits =*/ nullptr,
323
+ };
324
+ }
325
+
326
+ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
327
+ llama_batch batch = {
328
+ /*n_tokens =*/ 0,
329
+ /*tokens =*/ nullptr,
330
+ /*embd =*/ nullptr,
331
+ /*pos =*/ nullptr,
332
+ /*n_seq_id =*/ nullptr,
333
+ /*seq_id =*/ nullptr,
334
+ /*logits =*/ nullptr,
335
+ };
336
+
337
+ if (embd) {
338
+ batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
339
+ } else {
340
+ batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
341
+ }
342
+
343
+ batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
344
+ batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
345
+ batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
346
+ for (int i = 0; i < n_tokens_alloc; ++i) {
347
+ batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
348
+ }
349
+ batch.seq_id[n_tokens_alloc] = nullptr;
350
+
351
+ batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
352
+
353
+ return batch;
354
+ }
355
+
356
+ void llama_batch_free(struct llama_batch batch) {
357
+ if (batch.token) free(batch.token);
358
+ if (batch.embd) free(batch.embd);
359
+ if (batch.pos) free(batch.pos);
360
+ if (batch.n_seq_id) free(batch.n_seq_id);
361
+ if (batch.seq_id) {
362
+ for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
363
+ free(batch.seq_id[i]);
364
+ }
365
+ free(batch.seq_id);
366
+ }
367
+ if (batch.logits) free(batch.logits);
368
+ }