cui-llama.rn 1.3.6 → 1.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101) hide show
  1. package/README.md +22 -1
  2. package/android/src/main/CMakeLists.txt +25 -26
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +31 -9
  4. package/android/src/main/java/com/rnllama/RNLlama.java +98 -0
  5. package/android/src/main/jni-utils.h +94 -0
  6. package/android/src/main/jni.cpp +133 -63
  7. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +15 -0
  8. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +15 -0
  9. package/cpp/common.cpp +2085 -1982
  10. package/cpp/common.h +696 -664
  11. package/cpp/ggml-alloc.c +1042 -1037
  12. package/cpp/ggml-backend-impl.h +255 -256
  13. package/cpp/ggml-backend-reg.cpp +582 -582
  14. package/cpp/ggml-backend.cpp +2002 -2002
  15. package/cpp/ggml-backend.h +354 -352
  16. package/cpp/ggml-common.h +1853 -1853
  17. package/cpp/ggml-cpp.h +39 -39
  18. package/cpp/ggml-cpu-aarch64.cpp +4247 -4247
  19. package/cpp/ggml-cpu-aarch64.h +8 -8
  20. package/cpp/ggml-cpu-impl.h +386 -386
  21. package/cpp/ggml-cpu-quants.c +10920 -10839
  22. package/cpp/ggml-cpu-traits.cpp +36 -36
  23. package/cpp/ggml-cpu-traits.h +38 -38
  24. package/cpp/ggml-cpu.c +14391 -14122
  25. package/cpp/ggml-cpu.cpp +635 -627
  26. package/cpp/ggml-cpu.h +135 -135
  27. package/cpp/ggml-impl.h +567 -567
  28. package/cpp/ggml-metal-impl.h +288 -0
  29. package/cpp/ggml-metal.m +4884 -4884
  30. package/cpp/ggml-opt.cpp +854 -0
  31. package/cpp/ggml-opt.h +216 -0
  32. package/cpp/ggml-quants.c +5238 -5238
  33. package/cpp/ggml-threading.h +14 -14
  34. package/cpp/ggml.c +6514 -6448
  35. package/cpp/ggml.h +2194 -2163
  36. package/cpp/gguf.cpp +1329 -1325
  37. package/cpp/gguf.h +202 -202
  38. package/cpp/json-schema-to-grammar.cpp +1045 -1045
  39. package/cpp/json-schema-to-grammar.h +8 -8
  40. package/cpp/json.hpp +24766 -24766
  41. package/cpp/llama-adapter.cpp +347 -346
  42. package/cpp/llama-adapter.h +74 -73
  43. package/cpp/llama-arch.cpp +1487 -1434
  44. package/cpp/llama-arch.h +400 -395
  45. package/cpp/llama-batch.cpp +368 -368
  46. package/cpp/llama-batch.h +88 -88
  47. package/cpp/llama-chat.cpp +578 -567
  48. package/cpp/llama-chat.h +52 -51
  49. package/cpp/llama-context.cpp +1775 -1771
  50. package/cpp/llama-context.h +128 -128
  51. package/cpp/llama-cparams.cpp +1 -1
  52. package/cpp/llama-cparams.h +37 -37
  53. package/cpp/llama-cpp.h +30 -30
  54. package/cpp/llama-grammar.cpp +1139 -1139
  55. package/cpp/llama-grammar.h +143 -143
  56. package/cpp/llama-hparams.cpp +71 -71
  57. package/cpp/llama-hparams.h +139 -140
  58. package/cpp/llama-impl.cpp +167 -167
  59. package/cpp/llama-impl.h +61 -61
  60. package/cpp/llama-kv-cache.cpp +718 -718
  61. package/cpp/llama-kv-cache.h +218 -218
  62. package/cpp/llama-mmap.cpp +590 -589
  63. package/cpp/llama-mmap.h +67 -67
  64. package/cpp/llama-model-loader.cpp +1124 -1011
  65. package/cpp/llama-model-loader.h +167 -158
  66. package/cpp/llama-model.cpp +3997 -2202
  67. package/cpp/llama-model.h +370 -391
  68. package/cpp/llama-sampling.cpp +2408 -2406
  69. package/cpp/llama-sampling.h +32 -48
  70. package/cpp/llama-vocab.cpp +3247 -1982
  71. package/cpp/llama-vocab.h +125 -182
  72. package/cpp/llama.cpp +10077 -12544
  73. package/cpp/llama.h +1323 -1285
  74. package/cpp/log.cpp +401 -401
  75. package/cpp/log.h +121 -121
  76. package/cpp/rn-llama.hpp +123 -116
  77. package/cpp/sampling.cpp +505 -500
  78. package/cpp/sgemm.cpp +2597 -2597
  79. package/cpp/sgemm.h +14 -14
  80. package/cpp/speculative.cpp +277 -274
  81. package/cpp/speculative.h +28 -28
  82. package/cpp/unicode.cpp +2 -3
  83. package/ios/RNLlama.mm +47 -0
  84. package/ios/RNLlamaContext.h +3 -1
  85. package/ios/RNLlamaContext.mm +71 -14
  86. package/jest/mock.js +15 -3
  87. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  88. package/lib/commonjs/index.js +33 -37
  89. package/lib/commonjs/index.js.map +1 -1
  90. package/lib/module/NativeRNLlama.js.map +1 -1
  91. package/lib/module/index.js +31 -35
  92. package/lib/module/index.js.map +1 -1
  93. package/lib/typescript/NativeRNLlama.d.ts +26 -6
  94. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  95. package/lib/typescript/index.d.ts +21 -36
  96. package/lib/typescript/index.d.ts.map +1 -1
  97. package/llama-rn.podspec +4 -18
  98. package/package.json +2 -3
  99. package/src/NativeRNLlama.ts +32 -13
  100. package/src/index.ts +52 -47
  101. package/cpp/llama.cpp.rej +0 -23
@@ -1,368 +1,368 @@
1
- #include "llama-batch.h"
2
-
3
- #include <cstring>
4
- #include <algorithm>
5
-
6
- llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
7
- // clear empty sequences
8
- // the previous ubatch is assumed to be gone,
9
- // so nothing should refer to values in these sequences anymore.
10
- for (size_t i = seq.size(); i-- > 0;) {
11
- if (seq[i].length == 0) {
12
- seq.pop_back();
13
- } else {
14
- break;
15
- }
16
- }
17
- ubatch_token.resize(!has_embd ? n_ubatch : 0);
18
- ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
19
- ubatch_pos.resize(n_ubatch);
20
- ubatch_n_seq_id.resize(n_ubatch);
21
- ubatch_seq_id.resize(n_ubatch);
22
- ubatch_output.resize(n_ubatch);
23
- llama_ubatch ubatch = {
24
- /*equal_seqs =*/ true,
25
- /*n_tokens =*/ 0,
26
- /*n_seq_tokens =*/ 0,
27
- /*n_seqs =*/ 0,
28
- /*token =*/ !has_embd ? ubatch_token.data() : nullptr,
29
- /*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
30
- /*pos =*/ ubatch_pos.data(),
31
- /*n_seq_id =*/ ubatch_n_seq_id.data(),
32
- /*seq_id =*/ ubatch_seq_id.data(),
33
- /*output =*/ ubatch_output.data(),
34
- };
35
- return ubatch;
36
- }
37
-
38
- void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
39
- LM_GGML_ASSERT(batch != nullptr);
40
- LM_GGML_ASSERT(length <= seq.length);
41
- // Can only add sequences of equal lengths to a batch,
42
- // otherwise it isn't clear to which sequence a token belongs
43
- LM_GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
44
- LM_GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
45
- // NOTE: loops are separated for cache-friendliness
46
- if (batch->token) {
47
- if (ubatch.equal_seqs) {
48
- for (size_t i = 0; i < length; ++i) {
49
- ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
50
- }
51
- } else {
52
- // simple split
53
- ubatch.token = batch->token + seq.offset;
54
- }
55
- } else {
56
- ubatch.token = nullptr;
57
- }
58
- if (batch->embd) {
59
- if (ubatch.equal_seqs) {
60
- for (size_t i = 0; i < length; ++i) {
61
- memcpy(
62
- ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
63
- batch->embd + (n_embd * ids[seq.offset + i]),
64
- n_embd * sizeof(float)
65
- );
66
- }
67
- } else {
68
- // simple split
69
- ubatch.embd = batch->embd + (n_embd * seq.offset);
70
- }
71
- } else {
72
- ubatch.embd = nullptr;
73
- }
74
- if (ubatch.equal_seqs) {
75
- for (size_t i = 0; i < length; ++i) {
76
- ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
77
- }
78
- } else {
79
- // simple split
80
- ubatch.pos = batch->pos + seq.offset;
81
- }
82
- if (ubatch.equal_seqs) {
83
- ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
84
- if (seq.seq_id) {
85
- ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
86
- }
87
- } else {
88
- // simple split
89
- if (batch->n_seq_id) {
90
- ubatch.n_seq_id = batch->n_seq_id + seq.offset;
91
- } else {
92
- for (size_t i = 0; i < length; ++i) {
93
- ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
94
- }
95
- }
96
- if (batch->seq_id) {
97
- ubatch.seq_id = batch->seq_id + seq.offset;
98
- }
99
- }
100
- if (logits_all) {
101
- for (size_t i = 0; i < length; ++i) {
102
- ubatch.output[ubatch.n_tokens + i] = 1;
103
- out_ids.push_back(ids[seq.offset + i]);
104
- }
105
- } else if (batch->logits) {
106
- if (ubatch.equal_seqs) {
107
- for (size_t i = 0; i < length; ++i) {
108
- size_t id = ids[seq.offset + i];
109
- int8_t is_output = batch->logits[id];
110
- ubatch.output[ubatch.n_tokens + i] = is_output;
111
- if (is_output) { out_ids.push_back(id); }
112
- }
113
- } else {
114
- // simple split
115
- ubatch.output = batch->logits + seq.offset;
116
- for (size_t i = 0; i < length; ++i) {
117
- if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
118
- }
119
- }
120
- } else {
121
- // only get last output
122
- for (size_t i = 0; i < length; ++i) {
123
- size_t id = ids[seq.offset + i];
124
- int8_t is_last = id == ids.size() - 1;
125
- ubatch.output[ubatch.n_tokens + i] = is_last;
126
- if (is_last) { out_ids.push_back(id); }
127
- }
128
- }
129
- if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
130
- ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
131
- }
132
- ubatch.n_tokens += length;
133
- ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
134
- seq.offset += length;
135
- seq.length -= length;
136
- n_tokens -= length;
137
- LM_GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
138
- }
139
-
140
- llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
141
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
142
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
143
- ubatch.equal_seqs = false;
144
- if (!seq.empty()) {
145
- llama_sbatch_seq & s = seq[0];
146
- size_t length = s.length < n_ubatch ? s.length : n_ubatch;
147
- LM_GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
148
- add_seq_to_ubatch(ubatch, s, length);
149
- }
150
- return ubatch;
151
- }
152
-
153
- llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
154
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
155
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
156
- if (!seq.empty()) {
157
- size_t length = 0;
158
- size_t n_tokens_in_ubatch = 0;
159
- LM_GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
160
- // smallest first, because it's easier to split this way;
161
- // starting from the end to pop in constant time.
162
- for (size_t i = seq.size(); i-- > 0;) {
163
- llama_sbatch_seq & s = seq[i];
164
- LM_GGML_ASSERT(s.length > 0);
165
- if (length == 0) {
166
- length = s.length < n_ubatch ? s.length : n_ubatch;
167
- }
168
- add_seq_to_ubatch(ubatch, s, length);
169
- n_tokens_in_ubatch += length;
170
- // shared prompts can't be mixed with any of their sequences,
171
- // so it's safer to compute them in their own ubatch
172
- if (s.n_seq_id > 1) { break; }
173
- // stop when there isn't enough space for another sequence
174
- if (length + n_tokens_in_ubatch > n_ubatch) { break; }
175
- }
176
- }
177
- return ubatch;
178
- }
179
-
180
- llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
181
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
182
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
183
- if (!seq.empty()) {
184
- llama_sbatch_seq & s = seq[seq.size() - 1];
185
- size_t length = s.length < n_ubatch ? s.length : n_ubatch;
186
- LM_GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
187
- add_seq_to_ubatch(ubatch, s, length);
188
- }
189
- return ubatch;
190
- }
191
-
192
- void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
193
- LM_GGML_ASSERT(batch.n_tokens >= 0);
194
- this->batch = &batch;
195
- this->n_embd = n_embd;
196
- this->logits_all = logits_all;
197
-
198
- n_tokens = batch.n_tokens;
199
- ids.resize(n_tokens);
200
- out_ids.clear();
201
- // TODO: reserve out_ids and seq
202
-
203
- for (size_t i = 0; i < n_tokens; ++i) {
204
- ids[i] = i;
205
- }
206
- if (simple_split) {
207
- seq.resize(1);
208
- llama_sbatch_seq & s = seq[0];
209
- s.n_seq_id = 0;
210
- s.seq_id = nullptr;
211
- s.offset = 0;
212
- s.length = n_tokens;
213
- return;
214
- }
215
- std::sort(ids.begin(), ids.end(),
216
- [&batch](size_t a, size_t b) {
217
- int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
218
- int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
219
- // sort by seq_id, then by pos
220
- if (n_seq_a == n_seq_b) {
221
- if (batch.seq_id) {
222
- for (int32_t i = 0; i < n_seq_a; ++i) {
223
- llama_seq_id seq_id_a = batch.seq_id[a][i];
224
- llama_seq_id seq_id_b = batch.seq_id[b][i];
225
- // smaller seq_ids go first
226
- if (seq_id_a != seq_id_b) {
227
- return seq_id_a < seq_id_b;
228
- }
229
- }
230
- }
231
- // when all else is equal, sort by pos
232
- if (batch.pos) {
233
- return batch.pos[a] < batch.pos[b];
234
- }
235
- // no pos, sort by id
236
- return a < b;
237
- }
238
- // shared prompts go first
239
- return n_seq_a > n_seq_b;
240
- }
241
- );
242
- // init seq
243
- llama_sbatch_seq * last_seq = nullptr;
244
-
245
- for (size_t i = 0; i < n_tokens; ++i) {
246
- const size_t bi = ids[i];
247
- const int32_t n_seqs = batch.n_seq_id[bi];
248
- llama_seq_id * seq_ids = batch.seq_id[bi];
249
- if (last_seq != nullptr) {
250
- bool same = n_seqs == last_seq->n_seq_id;
251
- for (int32_t j = 0; same && j < n_seqs; ++j) {
252
- if (seq_ids[j] != last_seq->seq_id[j]) {
253
- same = false;
254
- }
255
- }
256
- if (same) {
257
- last_seq->length += 1;
258
- continue;
259
- }
260
- }
261
- llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
262
- seq.push_back(new_seq);
263
- last_seq = &seq.back();
264
- }
265
- // keep shared prompts first at the end, then sort by length descending.
266
- std::sort(seq.begin(), seq.end(),
267
- [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
268
- if (a.n_seq_id == b.n_seq_id) {
269
- return a.length > b.length;
270
- }
271
- return a.n_seq_id < b.n_seq_id;
272
- }
273
- );
274
- }
275
-
276
- llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
277
- batch = in_batch;
278
- LM_GGML_ASSERT(batch.n_tokens > 0);
279
- if (!batch.pos) {
280
- pos.resize(batch.n_tokens);
281
- for (int32_t i = 0; i < batch.n_tokens; i++) {
282
- pos[i] = i + p0;
283
- }
284
- batch.pos = pos.data();
285
- }
286
- if (!batch.n_seq_id) {
287
- n_seq_id.resize(batch.n_tokens);
288
- for (int32_t i = 0; i < batch.n_tokens; i++) {
289
- n_seq_id[i] = seq_id_0.size();
290
- }
291
- batch.n_seq_id = n_seq_id.data();
292
- }
293
- if (!batch.seq_id) {
294
- seq_id.resize(batch.n_tokens + 1);
295
- seq_id[batch.n_tokens] = NULL;
296
- for (int32_t i = 0; i < batch.n_tokens; i++) {
297
- seq_id[i] = seq_id_0.data();
298
- }
299
- batch.seq_id = seq_id.data();
300
- }
301
- if (!batch.logits) {
302
- logits.resize(batch.n_tokens);
303
- logits[logits.size() - 1] = true;
304
- batch.logits = logits.data();
305
- }
306
- }
307
-
308
- //
309
- // interface implementation
310
- //
311
-
312
- struct llama_batch llama_batch_get_one(
313
- llama_token * tokens,
314
- int32_t n_tokens) {
315
- return {
316
- /*n_tokens =*/ n_tokens,
317
- /*tokens =*/ tokens,
318
- /*embd =*/ nullptr,
319
- /*pos =*/ nullptr,
320
- /*n_seq_id =*/ nullptr,
321
- /*seq_id =*/ nullptr,
322
- /*logits =*/ nullptr,
323
- };
324
- }
325
-
326
- struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
327
- llama_batch batch = {
328
- /*n_tokens =*/ 0,
329
- /*tokens =*/ nullptr,
330
- /*embd =*/ nullptr,
331
- /*pos =*/ nullptr,
332
- /*n_seq_id =*/ nullptr,
333
- /*seq_id =*/ nullptr,
334
- /*logits =*/ nullptr,
335
- };
336
-
337
- if (embd) {
338
- batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
339
- } else {
340
- batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
341
- }
342
-
343
- batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
344
- batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
345
- batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
346
- for (int i = 0; i < n_tokens_alloc; ++i) {
347
- batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
348
- }
349
- batch.seq_id[n_tokens_alloc] = nullptr;
350
-
351
- batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
352
-
353
- return batch;
354
- }
355
-
356
- void llama_batch_free(struct llama_batch batch) {
357
- if (batch.token) free(batch.token);
358
- if (batch.embd) free(batch.embd);
359
- if (batch.pos) free(batch.pos);
360
- if (batch.n_seq_id) free(batch.n_seq_id);
361
- if (batch.seq_id) {
362
- for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
363
- free(batch.seq_id[i]);
364
- }
365
- free(batch.seq_id);
366
- }
367
- if (batch.logits) free(batch.logits);
368
- }
1
+ #include "llama-batch.h"
2
+
3
+ #include <cstring>
4
+ #include <algorithm>
5
+
6
+ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
7
+ // clear empty sequences
8
+ // the previous ubatch is assumed to be gone,
9
+ // so nothing should refer to values in these sequences anymore.
10
+ for (size_t i = seq.size(); i-- > 0;) {
11
+ if (seq[i].length == 0) {
12
+ seq.pop_back();
13
+ } else {
14
+ break;
15
+ }
16
+ }
17
+ ubatch_token.resize(!has_embd ? n_ubatch : 0);
18
+ ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
19
+ ubatch_pos.resize(n_ubatch);
20
+ ubatch_n_seq_id.resize(n_ubatch);
21
+ ubatch_seq_id.resize(n_ubatch);
22
+ ubatch_output.resize(n_ubatch);
23
+ llama_ubatch ubatch = {
24
+ /*equal_seqs =*/ true,
25
+ /*n_tokens =*/ 0,
26
+ /*n_seq_tokens =*/ 0,
27
+ /*n_seqs =*/ 0,
28
+ /*token =*/ !has_embd ? ubatch_token.data() : nullptr,
29
+ /*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
30
+ /*pos =*/ ubatch_pos.data(),
31
+ /*n_seq_id =*/ ubatch_n_seq_id.data(),
32
+ /*seq_id =*/ ubatch_seq_id.data(),
33
+ /*output =*/ ubatch_output.data(),
34
+ };
35
+ return ubatch;
36
+ }
37
+
38
+ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
39
+ LM_GGML_ASSERT(batch != nullptr);
40
+ LM_GGML_ASSERT(length <= seq.length);
41
+ // Can only add sequences of equal lengths to a batch,
42
+ // otherwise it isn't clear to which sequence a token belongs
43
+ LM_GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
44
+ LM_GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
45
+ // NOTE: loops are separated for cache-friendliness
46
+ if (batch->token) {
47
+ if (ubatch.equal_seqs) {
48
+ for (size_t i = 0; i < length; ++i) {
49
+ ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
50
+ }
51
+ } else {
52
+ // simple split
53
+ ubatch.token = batch->token + seq.offset;
54
+ }
55
+ } else {
56
+ ubatch.token = nullptr;
57
+ }
58
+ if (batch->embd) {
59
+ if (ubatch.equal_seqs) {
60
+ for (size_t i = 0; i < length; ++i) {
61
+ memcpy(
62
+ ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
63
+ batch->embd + (n_embd * ids[seq.offset + i]),
64
+ n_embd * sizeof(float)
65
+ );
66
+ }
67
+ } else {
68
+ // simple split
69
+ ubatch.embd = batch->embd + (n_embd * seq.offset);
70
+ }
71
+ } else {
72
+ ubatch.embd = nullptr;
73
+ }
74
+ if (ubatch.equal_seqs) {
75
+ for (size_t i = 0; i < length; ++i) {
76
+ ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
77
+ }
78
+ } else {
79
+ // simple split
80
+ ubatch.pos = batch->pos + seq.offset;
81
+ }
82
+ if (ubatch.equal_seqs) {
83
+ ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
84
+ if (seq.seq_id) {
85
+ ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
86
+ }
87
+ } else {
88
+ // simple split
89
+ if (batch->n_seq_id) {
90
+ ubatch.n_seq_id = batch->n_seq_id + seq.offset;
91
+ } else {
92
+ for (size_t i = 0; i < length; ++i) {
93
+ ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
94
+ }
95
+ }
96
+ if (batch->seq_id) {
97
+ ubatch.seq_id = batch->seq_id + seq.offset;
98
+ }
99
+ }
100
+ if (logits_all) {
101
+ for (size_t i = 0; i < length; ++i) {
102
+ ubatch.output[ubatch.n_tokens + i] = 1;
103
+ out_ids.push_back(ids[seq.offset + i]);
104
+ }
105
+ } else if (batch->logits) {
106
+ if (ubatch.equal_seqs) {
107
+ for (size_t i = 0; i < length; ++i) {
108
+ size_t id = ids[seq.offset + i];
109
+ int8_t is_output = batch->logits[id];
110
+ ubatch.output[ubatch.n_tokens + i] = is_output;
111
+ if (is_output) { out_ids.push_back(id); }
112
+ }
113
+ } else {
114
+ // simple split
115
+ ubatch.output = batch->logits + seq.offset;
116
+ for (size_t i = 0; i < length; ++i) {
117
+ if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
118
+ }
119
+ }
120
+ } else {
121
+ // only get last output
122
+ for (size_t i = 0; i < length; ++i) {
123
+ size_t id = ids[seq.offset + i];
124
+ int8_t is_last = id == ids.size() - 1;
125
+ ubatch.output[ubatch.n_tokens + i] = is_last;
126
+ if (is_last) { out_ids.push_back(id); }
127
+ }
128
+ }
129
+ if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
130
+ ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
131
+ }
132
+ ubatch.n_tokens += length;
133
+ ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
134
+ seq.offset += length;
135
+ seq.length -= length;
136
+ n_tokens -= length;
137
+ LM_GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
138
+ }
139
+
140
+ llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
141
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
142
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
143
+ ubatch.equal_seqs = false;
144
+ if (!seq.empty()) {
145
+ llama_sbatch_seq & s = seq[0];
146
+ size_t length = s.length < n_ubatch ? s.length : n_ubatch;
147
+ LM_GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
148
+ add_seq_to_ubatch(ubatch, s, length);
149
+ }
150
+ return ubatch;
151
+ }
152
+
153
+ llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
154
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
155
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
156
+ if (!seq.empty()) {
157
+ size_t length = 0;
158
+ size_t n_tokens_in_ubatch = 0;
159
+ LM_GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
160
+ // smallest first, because it's easier to split this way;
161
+ // starting from the end to pop in constant time.
162
+ for (size_t i = seq.size(); i-- > 0;) {
163
+ llama_sbatch_seq & s = seq[i];
164
+ LM_GGML_ASSERT(s.length > 0);
165
+ if (length == 0) {
166
+ length = s.length < n_ubatch ? s.length : n_ubatch;
167
+ }
168
+ add_seq_to_ubatch(ubatch, s, length);
169
+ n_tokens_in_ubatch += length;
170
+ // shared prompts can't be mixed with any of their sequences,
171
+ // so it's safer to compute them in their own ubatch
172
+ if (s.n_seq_id > 1) { break; }
173
+ // stop when there isn't enough space for another sequence
174
+ if (length + n_tokens_in_ubatch > n_ubatch) { break; }
175
+ }
176
+ }
177
+ return ubatch;
178
+ }
179
+
180
+ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
181
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
182
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
183
+ if (!seq.empty()) {
184
+ llama_sbatch_seq & s = seq[seq.size() - 1];
185
+ size_t length = s.length < n_ubatch ? s.length : n_ubatch;
186
+ LM_GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
187
+ add_seq_to_ubatch(ubatch, s, length);
188
+ }
189
+ return ubatch;
190
+ }
191
+
192
+ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
193
+ LM_GGML_ASSERT(batch.n_tokens >= 0);
194
+ this->batch = &batch;
195
+ this->n_embd = n_embd;
196
+ this->logits_all = logits_all;
197
+
198
+ n_tokens = batch.n_tokens;
199
+ ids.resize(n_tokens);
200
+ out_ids.clear();
201
+ // TODO: reserve out_ids and seq
202
+
203
+ for (size_t i = 0; i < n_tokens; ++i) {
204
+ ids[i] = i;
205
+ }
206
+ if (simple_split) {
207
+ seq.resize(1);
208
+ llama_sbatch_seq & s = seq[0];
209
+ s.n_seq_id = 0;
210
+ s.seq_id = nullptr;
211
+ s.offset = 0;
212
+ s.length = n_tokens;
213
+ return;
214
+ }
215
+ std::sort(ids.begin(), ids.end(),
216
+ [&batch](size_t a, size_t b) {
217
+ int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
218
+ int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
219
+ // sort by seq_id, then by pos
220
+ if (n_seq_a == n_seq_b) {
221
+ if (batch.seq_id) {
222
+ for (int32_t i = 0; i < n_seq_a; ++i) {
223
+ llama_seq_id seq_id_a = batch.seq_id[a][i];
224
+ llama_seq_id seq_id_b = batch.seq_id[b][i];
225
+ // smaller seq_ids go first
226
+ if (seq_id_a != seq_id_b) {
227
+ return seq_id_a < seq_id_b;
228
+ }
229
+ }
230
+ }
231
+ // when all else is equal, sort by pos
232
+ if (batch.pos) {
233
+ return batch.pos[a] < batch.pos[b];
234
+ }
235
+ // no pos, sort by id
236
+ return a < b;
237
+ }
238
+ // shared prompts go first
239
+ return n_seq_a > n_seq_b;
240
+ }
241
+ );
242
+ // init seq
243
+ llama_sbatch_seq * last_seq = nullptr;
244
+
245
+ for (size_t i = 0; i < n_tokens; ++i) {
246
+ const size_t bi = ids[i];
247
+ const int32_t n_seqs = batch.n_seq_id[bi];
248
+ llama_seq_id * seq_ids = batch.seq_id[bi];
249
+ if (last_seq != nullptr) {
250
+ bool same = n_seqs == last_seq->n_seq_id;
251
+ for (int32_t j = 0; same && j < n_seqs; ++j) {
252
+ if (seq_ids[j] != last_seq->seq_id[j]) {
253
+ same = false;
254
+ }
255
+ }
256
+ if (same) {
257
+ last_seq->length += 1;
258
+ continue;
259
+ }
260
+ }
261
+ llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
262
+ seq.push_back(new_seq);
263
+ last_seq = &seq.back();
264
+ }
265
+ // keep shared prompts first at the end, then sort by length descending.
266
+ std::sort(seq.begin(), seq.end(),
267
+ [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
268
+ if (a.n_seq_id == b.n_seq_id) {
269
+ return a.length > b.length;
270
+ }
271
+ return a.n_seq_id < b.n_seq_id;
272
+ }
273
+ );
274
+ }
275
+
276
+ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
277
+ batch = in_batch;
278
+ LM_GGML_ASSERT(batch.n_tokens > 0);
279
+ if (!batch.pos) {
280
+ pos.resize(batch.n_tokens);
281
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
282
+ pos[i] = i + p0;
283
+ }
284
+ batch.pos = pos.data();
285
+ }
286
+ if (!batch.n_seq_id) {
287
+ n_seq_id.resize(batch.n_tokens);
288
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
289
+ n_seq_id[i] = seq_id_0.size();
290
+ }
291
+ batch.n_seq_id = n_seq_id.data();
292
+ }
293
+ if (!batch.seq_id) {
294
+ seq_id.resize(batch.n_tokens + 1);
295
+ seq_id[batch.n_tokens] = NULL;
296
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
297
+ seq_id[i] = seq_id_0.data();
298
+ }
299
+ batch.seq_id = seq_id.data();
300
+ }
301
+ if (!batch.logits) {
302
+ logits.resize(batch.n_tokens);
303
+ logits[logits.size() - 1] = true;
304
+ batch.logits = logits.data();
305
+ }
306
+ }
307
+
308
+ //
309
+ // interface implementation
310
+ //
311
+
312
+ struct llama_batch llama_batch_get_one(
313
+ llama_token * tokens,
314
+ int32_t n_tokens) {
315
+ return {
316
+ /*n_tokens =*/ n_tokens,
317
+ /*tokens =*/ tokens,
318
+ /*embd =*/ nullptr,
319
+ /*pos =*/ nullptr,
320
+ /*n_seq_id =*/ nullptr,
321
+ /*seq_id =*/ nullptr,
322
+ /*logits =*/ nullptr,
323
+ };
324
+ }
325
+
326
+ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
327
+ llama_batch batch = {
328
+ /*n_tokens =*/ 0,
329
+ /*tokens =*/ nullptr,
330
+ /*embd =*/ nullptr,
331
+ /*pos =*/ nullptr,
332
+ /*n_seq_id =*/ nullptr,
333
+ /*seq_id =*/ nullptr,
334
+ /*logits =*/ nullptr,
335
+ };
336
+
337
+ if (embd) {
338
+ batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
339
+ } else {
340
+ batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
341
+ }
342
+
343
+ batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
344
+ batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
345
+ batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
346
+ for (int i = 0; i < n_tokens_alloc; ++i) {
347
+ batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
348
+ }
349
+ batch.seq_id[n_tokens_alloc] = nullptr;
350
+
351
+ batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
352
+
353
+ return batch;
354
+ }
355
+
356
+ void llama_batch_free(struct llama_batch batch) {
357
+ if (batch.token) free(batch.token);
358
+ if (batch.embd) free(batch.embd);
359
+ if (batch.pos) free(batch.pos);
360
+ if (batch.n_seq_id) free(batch.n_seq_id);
361
+ if (batch.seq_id) {
362
+ for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
363
+ free(batch.seq_id[i]);
364
+ }
365
+ free(batch.seq_id);
366
+ }
367
+ if (batch.logits) free(batch.logits);
368
+ }