cui-llama.rn 1.4.0 → 1.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. package/README.md +4 -23
  2. package/android/build.gradle +12 -3
  3. package/android/src/main/CMakeLists.txt +13 -7
  4. package/android/src/main/java/com/rnllama/LlamaContext.java +27 -20
  5. package/android/src/main/java/com/rnllama/RNLlama.java +5 -1
  6. package/android/src/main/jni.cpp +15 -12
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  11. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  12. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  15. package/cpp/README.md +1 -1
  16. package/cpp/common.cpp +158 -267
  17. package/cpp/common.h +46 -12
  18. package/cpp/ggml-alloc.c +1042 -1037
  19. package/cpp/ggml-backend-impl.h +255 -256
  20. package/cpp/ggml-backend-reg.cpp +582 -582
  21. package/cpp/ggml-backend.cpp +2002 -2002
  22. package/cpp/ggml-backend.h +354 -352
  23. package/cpp/ggml-common.h +1853 -1853
  24. package/cpp/ggml-cpp.h +39 -39
  25. package/cpp/ggml-cpu-aarch64.cpp +4247 -4247
  26. package/cpp/ggml-cpu-aarch64.h +8 -8
  27. package/cpp/ggml-cpu-impl.h +386 -386
  28. package/cpp/ggml-cpu-quants.c +10920 -10839
  29. package/cpp/ggml-cpu-traits.cpp +36 -36
  30. package/cpp/ggml-cpu-traits.h +38 -38
  31. package/cpp/ggml-cpu.c +329 -60
  32. package/cpp/ggml-cpu.cpp +10 -2
  33. package/cpp/ggml-cpu.h +135 -135
  34. package/cpp/ggml-impl.h +567 -567
  35. package/cpp/ggml-metal-impl.h +17 -17
  36. package/cpp/ggml-metal.m +4884 -4884
  37. package/cpp/ggml-quants.c +5238 -5238
  38. package/cpp/ggml-threading.h +14 -14
  39. package/cpp/ggml.c +6514 -6448
  40. package/cpp/ggml.h +2194 -2163
  41. package/cpp/gguf.cpp +1329 -1325
  42. package/cpp/gguf.h +202 -202
  43. package/cpp/json-schema-to-grammar.cpp +1045 -1045
  44. package/cpp/json-schema-to-grammar.h +8 -8
  45. package/cpp/json.hpp +24766 -24766
  46. package/cpp/llama-adapter.cpp +347 -346
  47. package/cpp/llama-adapter.h +74 -73
  48. package/cpp/llama-arch.cpp +1487 -1434
  49. package/cpp/llama-arch.h +400 -395
  50. package/cpp/llama-batch.cpp +368 -368
  51. package/cpp/llama-batch.h +88 -88
  52. package/cpp/llama-chat.cpp +578 -567
  53. package/cpp/llama-chat.h +52 -51
  54. package/cpp/llama-context.cpp +1775 -1771
  55. package/cpp/llama-context.h +128 -128
  56. package/cpp/llama-cparams.cpp +1 -1
  57. package/cpp/llama-cparams.h +37 -37
  58. package/cpp/llama-cpp.h +30 -30
  59. package/cpp/llama-grammar.cpp +1139 -1139
  60. package/cpp/llama-grammar.h +143 -143
  61. package/cpp/llama-hparams.cpp +71 -71
  62. package/cpp/llama-hparams.h +139 -140
  63. package/cpp/llama-impl.cpp +167 -167
  64. package/cpp/llama-impl.h +61 -61
  65. package/cpp/llama-kv-cache.cpp +718 -718
  66. package/cpp/llama-kv-cache.h +218 -218
  67. package/cpp/llama-mmap.cpp +2 -1
  68. package/cpp/llama-mmap.h +67 -67
  69. package/cpp/llama-model-loader.cpp +1124 -1011
  70. package/cpp/llama-model-loader.h +167 -158
  71. package/cpp/llama-model.cpp +3997 -2202
  72. package/cpp/llama-model.h +370 -391
  73. package/cpp/llama-sampling.cpp +2408 -2406
  74. package/cpp/llama-sampling.h +32 -48
  75. package/cpp/llama-vocab.cpp +3247 -1982
  76. package/cpp/llama-vocab.h +125 -182
  77. package/cpp/llama.cpp +416 -2886
  78. package/cpp/llama.h +1323 -1285
  79. package/cpp/log.cpp +401 -401
  80. package/cpp/log.h +121 -121
  81. package/cpp/rn-llama.cpp +822 -0
  82. package/cpp/rn-llama.h +123 -0
  83. package/cpp/rn-llama.hpp +18 -12
  84. package/cpp/sampling.cpp +505 -500
  85. package/cpp/sgemm.cpp +2597 -2597
  86. package/cpp/speculative.cpp +277 -274
  87. package/cpp/speculative.h +28 -28
  88. package/cpp/unicode.cpp +2 -3
  89. package/ios/CMakeLists.txt +99 -0
  90. package/ios/RNLlama.h +5 -1
  91. package/ios/RNLlama.mm +2 -2
  92. package/ios/RNLlamaContext.h +8 -1
  93. package/ios/RNLlamaContext.mm +15 -11
  94. package/ios/rnllama.xcframework/Info.plist +74 -0
  95. package/jest/mock.js +3 -2
  96. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  97. package/lib/commonjs/index.js +4 -2
  98. package/lib/commonjs/index.js.map +1 -1
  99. package/lib/module/NativeRNLlama.js.map +1 -1
  100. package/lib/module/index.js +4 -2
  101. package/lib/module/index.js.map +1 -1
  102. package/lib/typescript/NativeRNLlama.d.ts +5 -1
  103. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  104. package/lib/typescript/index.d.ts.map +1 -1
  105. package/llama-rn.podspec +8 -2
  106. package/package.json +5 -2
  107. package/src/NativeRNLlama.ts +5 -1
  108. package/src/index.ts +9 -2
@@ -1,368 +1,368 @@
1
- #include "llama-batch.h"
2
-
3
- #include <cstring>
4
- #include <algorithm>
5
-
6
- llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
7
- // clear empty sequences
8
- // the previous ubatch is assumed to be gone,
9
- // so nothing should refer to values in these sequences anymore.
10
- for (size_t i = seq.size(); i-- > 0;) {
11
- if (seq[i].length == 0) {
12
- seq.pop_back();
13
- } else {
14
- break;
15
- }
16
- }
17
- ubatch_token.resize(!has_embd ? n_ubatch : 0);
18
- ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
19
- ubatch_pos.resize(n_ubatch);
20
- ubatch_n_seq_id.resize(n_ubatch);
21
- ubatch_seq_id.resize(n_ubatch);
22
- ubatch_output.resize(n_ubatch);
23
- llama_ubatch ubatch = {
24
- /*equal_seqs =*/ true,
25
- /*n_tokens =*/ 0,
26
- /*n_seq_tokens =*/ 0,
27
- /*n_seqs =*/ 0,
28
- /*token =*/ !has_embd ? ubatch_token.data() : nullptr,
29
- /*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
30
- /*pos =*/ ubatch_pos.data(),
31
- /*n_seq_id =*/ ubatch_n_seq_id.data(),
32
- /*seq_id =*/ ubatch_seq_id.data(),
33
- /*output =*/ ubatch_output.data(),
34
- };
35
- return ubatch;
36
- }
37
-
38
- void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
39
- LM_GGML_ASSERT(batch != nullptr);
40
- LM_GGML_ASSERT(length <= seq.length);
41
- // Can only add sequences of equal lengths to a batch,
42
- // otherwise it isn't clear to which sequence a token belongs
43
- LM_GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
44
- LM_GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
45
- // NOTE: loops are separated for cache-friendliness
46
- if (batch->token) {
47
- if (ubatch.equal_seqs) {
48
- for (size_t i = 0; i < length; ++i) {
49
- ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
50
- }
51
- } else {
52
- // simple split
53
- ubatch.token = batch->token + seq.offset;
54
- }
55
- } else {
56
- ubatch.token = nullptr;
57
- }
58
- if (batch->embd) {
59
- if (ubatch.equal_seqs) {
60
- for (size_t i = 0; i < length; ++i) {
61
- memcpy(
62
- ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
63
- batch->embd + (n_embd * ids[seq.offset + i]),
64
- n_embd * sizeof(float)
65
- );
66
- }
67
- } else {
68
- // simple split
69
- ubatch.embd = batch->embd + (n_embd * seq.offset);
70
- }
71
- } else {
72
- ubatch.embd = nullptr;
73
- }
74
- if (ubatch.equal_seqs) {
75
- for (size_t i = 0; i < length; ++i) {
76
- ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
77
- }
78
- } else {
79
- // simple split
80
- ubatch.pos = batch->pos + seq.offset;
81
- }
82
- if (ubatch.equal_seqs) {
83
- ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
84
- if (seq.seq_id) {
85
- ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
86
- }
87
- } else {
88
- // simple split
89
- if (batch->n_seq_id) {
90
- ubatch.n_seq_id = batch->n_seq_id + seq.offset;
91
- } else {
92
- for (size_t i = 0; i < length; ++i) {
93
- ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
94
- }
95
- }
96
- if (batch->seq_id) {
97
- ubatch.seq_id = batch->seq_id + seq.offset;
98
- }
99
- }
100
- if (logits_all) {
101
- for (size_t i = 0; i < length; ++i) {
102
- ubatch.output[ubatch.n_tokens + i] = 1;
103
- out_ids.push_back(ids[seq.offset + i]);
104
- }
105
- } else if (batch->logits) {
106
- if (ubatch.equal_seqs) {
107
- for (size_t i = 0; i < length; ++i) {
108
- size_t id = ids[seq.offset + i];
109
- int8_t is_output = batch->logits[id];
110
- ubatch.output[ubatch.n_tokens + i] = is_output;
111
- if (is_output) { out_ids.push_back(id); }
112
- }
113
- } else {
114
- // simple split
115
- ubatch.output = batch->logits + seq.offset;
116
- for (size_t i = 0; i < length; ++i) {
117
- if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
118
- }
119
- }
120
- } else {
121
- // only get last output
122
- for (size_t i = 0; i < length; ++i) {
123
- size_t id = ids[seq.offset + i];
124
- int8_t is_last = id == ids.size() - 1;
125
- ubatch.output[ubatch.n_tokens + i] = is_last;
126
- if (is_last) { out_ids.push_back(id); }
127
- }
128
- }
129
- if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
130
- ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
131
- }
132
- ubatch.n_tokens += length;
133
- ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
134
- seq.offset += length;
135
- seq.length -= length;
136
- n_tokens -= length;
137
- LM_GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
138
- }
139
-
140
- llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
141
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
142
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
143
- ubatch.equal_seqs = false;
144
- if (!seq.empty()) {
145
- llama_sbatch_seq & s = seq[0];
146
- size_t length = s.length < n_ubatch ? s.length : n_ubatch;
147
- LM_GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
148
- add_seq_to_ubatch(ubatch, s, length);
149
- }
150
- return ubatch;
151
- }
152
-
153
- llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
154
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
155
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
156
- if (!seq.empty()) {
157
- size_t length = 0;
158
- size_t n_tokens_in_ubatch = 0;
159
- LM_GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
160
- // smallest first, because it's easier to split this way;
161
- // starting from the end to pop in constant time.
162
- for (size_t i = seq.size(); i-- > 0;) {
163
- llama_sbatch_seq & s = seq[i];
164
- LM_GGML_ASSERT(s.length > 0);
165
- if (length == 0) {
166
- length = s.length < n_ubatch ? s.length : n_ubatch;
167
- }
168
- add_seq_to_ubatch(ubatch, s, length);
169
- n_tokens_in_ubatch += length;
170
- // shared prompts can't be mixed with any of their sequences,
171
- // so it's safer to compute them in their own ubatch
172
- if (s.n_seq_id > 1) { break; }
173
- // stop when there isn't enough space for another sequence
174
- if (length + n_tokens_in_ubatch > n_ubatch) { break; }
175
- }
176
- }
177
- return ubatch;
178
- }
179
-
180
- llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
181
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
182
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
183
- if (!seq.empty()) {
184
- llama_sbatch_seq & s = seq[seq.size() - 1];
185
- size_t length = s.length < n_ubatch ? s.length : n_ubatch;
186
- LM_GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
187
- add_seq_to_ubatch(ubatch, s, length);
188
- }
189
- return ubatch;
190
- }
191
-
192
- void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
193
- LM_GGML_ASSERT(batch.n_tokens >= 0);
194
- this->batch = &batch;
195
- this->n_embd = n_embd;
196
- this->logits_all = logits_all;
197
-
198
- n_tokens = batch.n_tokens;
199
- ids.resize(n_tokens);
200
- out_ids.clear();
201
- // TODO: reserve out_ids and seq
202
-
203
- for (size_t i = 0; i < n_tokens; ++i) {
204
- ids[i] = i;
205
- }
206
- if (simple_split) {
207
- seq.resize(1);
208
- llama_sbatch_seq & s = seq[0];
209
- s.n_seq_id = 0;
210
- s.seq_id = nullptr;
211
- s.offset = 0;
212
- s.length = n_tokens;
213
- return;
214
- }
215
- std::sort(ids.begin(), ids.end(),
216
- [&batch](size_t a, size_t b) {
217
- int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
218
- int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
219
- // sort by seq_id, then by pos
220
- if (n_seq_a == n_seq_b) {
221
- if (batch.seq_id) {
222
- for (int32_t i = 0; i < n_seq_a; ++i) {
223
- llama_seq_id seq_id_a = batch.seq_id[a][i];
224
- llama_seq_id seq_id_b = batch.seq_id[b][i];
225
- // smaller seq_ids go first
226
- if (seq_id_a != seq_id_b) {
227
- return seq_id_a < seq_id_b;
228
- }
229
- }
230
- }
231
- // when all else is equal, sort by pos
232
- if (batch.pos) {
233
- return batch.pos[a] < batch.pos[b];
234
- }
235
- // no pos, sort by id
236
- return a < b;
237
- }
238
- // shared prompts go first
239
- return n_seq_a > n_seq_b;
240
- }
241
- );
242
- // init seq
243
- llama_sbatch_seq * last_seq = nullptr;
244
-
245
- for (size_t i = 0; i < n_tokens; ++i) {
246
- const size_t bi = ids[i];
247
- const int32_t n_seqs = batch.n_seq_id[bi];
248
- llama_seq_id * seq_ids = batch.seq_id[bi];
249
- if (last_seq != nullptr) {
250
- bool same = n_seqs == last_seq->n_seq_id;
251
- for (int32_t j = 0; same && j < n_seqs; ++j) {
252
- if (seq_ids[j] != last_seq->seq_id[j]) {
253
- same = false;
254
- }
255
- }
256
- if (same) {
257
- last_seq->length += 1;
258
- continue;
259
- }
260
- }
261
- llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
262
- seq.push_back(new_seq);
263
- last_seq = &seq.back();
264
- }
265
- // keep shared prompts first at the end, then sort by length descending.
266
- std::sort(seq.begin(), seq.end(),
267
- [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
268
- if (a.n_seq_id == b.n_seq_id) {
269
- return a.length > b.length;
270
- }
271
- return a.n_seq_id < b.n_seq_id;
272
- }
273
- );
274
- }
275
-
276
- llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
277
- batch = in_batch;
278
- LM_GGML_ASSERT(batch.n_tokens > 0);
279
- if (!batch.pos) {
280
- pos.resize(batch.n_tokens);
281
- for (int32_t i = 0; i < batch.n_tokens; i++) {
282
- pos[i] = i + p0;
283
- }
284
- batch.pos = pos.data();
285
- }
286
- if (!batch.n_seq_id) {
287
- n_seq_id.resize(batch.n_tokens);
288
- for (int32_t i = 0; i < batch.n_tokens; i++) {
289
- n_seq_id[i] = seq_id_0.size();
290
- }
291
- batch.n_seq_id = n_seq_id.data();
292
- }
293
- if (!batch.seq_id) {
294
- seq_id.resize(batch.n_tokens + 1);
295
- seq_id[batch.n_tokens] = NULL;
296
- for (int32_t i = 0; i < batch.n_tokens; i++) {
297
- seq_id[i] = seq_id_0.data();
298
- }
299
- batch.seq_id = seq_id.data();
300
- }
301
- if (!batch.logits) {
302
- logits.resize(batch.n_tokens);
303
- logits[logits.size() - 1] = true;
304
- batch.logits = logits.data();
305
- }
306
- }
307
-
308
- //
309
- // interface implementation
310
- //
311
-
312
- struct llama_batch llama_batch_get_one(
313
- llama_token * tokens,
314
- int32_t n_tokens) {
315
- return {
316
- /*n_tokens =*/ n_tokens,
317
- /*tokens =*/ tokens,
318
- /*embd =*/ nullptr,
319
- /*pos =*/ nullptr,
320
- /*n_seq_id =*/ nullptr,
321
- /*seq_id =*/ nullptr,
322
- /*logits =*/ nullptr,
323
- };
324
- }
325
-
326
- struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
327
- llama_batch batch = {
328
- /*n_tokens =*/ 0,
329
- /*tokens =*/ nullptr,
330
- /*embd =*/ nullptr,
331
- /*pos =*/ nullptr,
332
- /*n_seq_id =*/ nullptr,
333
- /*seq_id =*/ nullptr,
334
- /*logits =*/ nullptr,
335
- };
336
-
337
- if (embd) {
338
- batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
339
- } else {
340
- batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
341
- }
342
-
343
- batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
344
- batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
345
- batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
346
- for (int i = 0; i < n_tokens_alloc; ++i) {
347
- batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
348
- }
349
- batch.seq_id[n_tokens_alloc] = nullptr;
350
-
351
- batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
352
-
353
- return batch;
354
- }
355
-
356
- void llama_batch_free(struct llama_batch batch) {
357
- if (batch.token) free(batch.token);
358
- if (batch.embd) free(batch.embd);
359
- if (batch.pos) free(batch.pos);
360
- if (batch.n_seq_id) free(batch.n_seq_id);
361
- if (batch.seq_id) {
362
- for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
363
- free(batch.seq_id[i]);
364
- }
365
- free(batch.seq_id);
366
- }
367
- if (batch.logits) free(batch.logits);
368
- }
1
+ #include "llama-batch.h"
2
+
3
+ #include <cstring>
4
+ #include <algorithm>
5
+
6
+ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
7
+ // clear empty sequences
8
+ // the previous ubatch is assumed to be gone,
9
+ // so nothing should refer to values in these sequences anymore.
10
+ for (size_t i = seq.size(); i-- > 0;) {
11
+ if (seq[i].length == 0) {
12
+ seq.pop_back();
13
+ } else {
14
+ break;
15
+ }
16
+ }
17
+ ubatch_token.resize(!has_embd ? n_ubatch : 0);
18
+ ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
19
+ ubatch_pos.resize(n_ubatch);
20
+ ubatch_n_seq_id.resize(n_ubatch);
21
+ ubatch_seq_id.resize(n_ubatch);
22
+ ubatch_output.resize(n_ubatch);
23
+ llama_ubatch ubatch = {
24
+ /*equal_seqs =*/ true,
25
+ /*n_tokens =*/ 0,
26
+ /*n_seq_tokens =*/ 0,
27
+ /*n_seqs =*/ 0,
28
+ /*token =*/ !has_embd ? ubatch_token.data() : nullptr,
29
+ /*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
30
+ /*pos =*/ ubatch_pos.data(),
31
+ /*n_seq_id =*/ ubatch_n_seq_id.data(),
32
+ /*seq_id =*/ ubatch_seq_id.data(),
33
+ /*output =*/ ubatch_output.data(),
34
+ };
35
+ return ubatch;
36
+ }
37
+
38
+ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
39
+ LM_GGML_ASSERT(batch != nullptr);
40
+ LM_GGML_ASSERT(length <= seq.length);
41
+ // Can only add sequences of equal lengths to a batch,
42
+ // otherwise it isn't clear to which sequence a token belongs
43
+ LM_GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
44
+ LM_GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
45
+ // NOTE: loops are separated for cache-friendliness
46
+ if (batch->token) {
47
+ if (ubatch.equal_seqs) {
48
+ for (size_t i = 0; i < length; ++i) {
49
+ ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
50
+ }
51
+ } else {
52
+ // simple split
53
+ ubatch.token = batch->token + seq.offset;
54
+ }
55
+ } else {
56
+ ubatch.token = nullptr;
57
+ }
58
+ if (batch->embd) {
59
+ if (ubatch.equal_seqs) {
60
+ for (size_t i = 0; i < length; ++i) {
61
+ memcpy(
62
+ ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
63
+ batch->embd + (n_embd * ids[seq.offset + i]),
64
+ n_embd * sizeof(float)
65
+ );
66
+ }
67
+ } else {
68
+ // simple split
69
+ ubatch.embd = batch->embd + (n_embd * seq.offset);
70
+ }
71
+ } else {
72
+ ubatch.embd = nullptr;
73
+ }
74
+ if (ubatch.equal_seqs) {
75
+ for (size_t i = 0; i < length; ++i) {
76
+ ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
77
+ }
78
+ } else {
79
+ // simple split
80
+ ubatch.pos = batch->pos + seq.offset;
81
+ }
82
+ if (ubatch.equal_seqs) {
83
+ ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
84
+ if (seq.seq_id) {
85
+ ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
86
+ }
87
+ } else {
88
+ // simple split
89
+ if (batch->n_seq_id) {
90
+ ubatch.n_seq_id = batch->n_seq_id + seq.offset;
91
+ } else {
92
+ for (size_t i = 0; i < length; ++i) {
93
+ ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
94
+ }
95
+ }
96
+ if (batch->seq_id) {
97
+ ubatch.seq_id = batch->seq_id + seq.offset;
98
+ }
99
+ }
100
+ if (logits_all) {
101
+ for (size_t i = 0; i < length; ++i) {
102
+ ubatch.output[ubatch.n_tokens + i] = 1;
103
+ out_ids.push_back(ids[seq.offset + i]);
104
+ }
105
+ } else if (batch->logits) {
106
+ if (ubatch.equal_seqs) {
107
+ for (size_t i = 0; i < length; ++i) {
108
+ size_t id = ids[seq.offset + i];
109
+ int8_t is_output = batch->logits[id];
110
+ ubatch.output[ubatch.n_tokens + i] = is_output;
111
+ if (is_output) { out_ids.push_back(id); }
112
+ }
113
+ } else {
114
+ // simple split
115
+ ubatch.output = batch->logits + seq.offset;
116
+ for (size_t i = 0; i < length; ++i) {
117
+ if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
118
+ }
119
+ }
120
+ } else {
121
+ // only get last output
122
+ for (size_t i = 0; i < length; ++i) {
123
+ size_t id = ids[seq.offset + i];
124
+ int8_t is_last = id == ids.size() - 1;
125
+ ubatch.output[ubatch.n_tokens + i] = is_last;
126
+ if (is_last) { out_ids.push_back(id); }
127
+ }
128
+ }
129
+ if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
130
+ ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
131
+ }
132
+ ubatch.n_tokens += length;
133
+ ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
134
+ seq.offset += length;
135
+ seq.length -= length;
136
+ n_tokens -= length;
137
+ LM_GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
138
+ }
139
+
140
+ llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
141
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
142
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
143
+ ubatch.equal_seqs = false;
144
+ if (!seq.empty()) {
145
+ llama_sbatch_seq & s = seq[0];
146
+ size_t length = s.length < n_ubatch ? s.length : n_ubatch;
147
+ LM_GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
148
+ add_seq_to_ubatch(ubatch, s, length);
149
+ }
150
+ return ubatch;
151
+ }
152
+
153
+ llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
154
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
155
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
156
+ if (!seq.empty()) {
157
+ size_t length = 0;
158
+ size_t n_tokens_in_ubatch = 0;
159
+ LM_GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
160
+ // smallest first, because it's easier to split this way;
161
+ // starting from the end to pop in constant time.
162
+ for (size_t i = seq.size(); i-- > 0;) {
163
+ llama_sbatch_seq & s = seq[i];
164
+ LM_GGML_ASSERT(s.length > 0);
165
+ if (length == 0) {
166
+ length = s.length < n_ubatch ? s.length : n_ubatch;
167
+ }
168
+ add_seq_to_ubatch(ubatch, s, length);
169
+ n_tokens_in_ubatch += length;
170
+ // shared prompts can't be mixed with any of their sequences,
171
+ // so it's safer to compute them in their own ubatch
172
+ if (s.n_seq_id > 1) { break; }
173
+ // stop when there isn't enough space for another sequence
174
+ if (length + n_tokens_in_ubatch > n_ubatch) { break; }
175
+ }
176
+ }
177
+ return ubatch;
178
+ }
179
+
180
+ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
181
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
182
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
183
+ if (!seq.empty()) {
184
+ llama_sbatch_seq & s = seq[seq.size() - 1];
185
+ size_t length = s.length < n_ubatch ? s.length : n_ubatch;
186
+ LM_GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
187
+ add_seq_to_ubatch(ubatch, s, length);
188
+ }
189
+ return ubatch;
190
+ }
191
+
192
+ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
193
+ LM_GGML_ASSERT(batch.n_tokens >= 0);
194
+ this->batch = &batch;
195
+ this->n_embd = n_embd;
196
+ this->logits_all = logits_all;
197
+
198
+ n_tokens = batch.n_tokens;
199
+ ids.resize(n_tokens);
200
+ out_ids.clear();
201
+ // TODO: reserve out_ids and seq
202
+
203
+ for (size_t i = 0; i < n_tokens; ++i) {
204
+ ids[i] = i;
205
+ }
206
+ if (simple_split) {
207
+ seq.resize(1);
208
+ llama_sbatch_seq & s = seq[0];
209
+ s.n_seq_id = 0;
210
+ s.seq_id = nullptr;
211
+ s.offset = 0;
212
+ s.length = n_tokens;
213
+ return;
214
+ }
215
+ std::sort(ids.begin(), ids.end(),
216
+ [&batch](size_t a, size_t b) {
217
+ int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
218
+ int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
219
+ // sort by seq_id, then by pos
220
+ if (n_seq_a == n_seq_b) {
221
+ if (batch.seq_id) {
222
+ for (int32_t i = 0; i < n_seq_a; ++i) {
223
+ llama_seq_id seq_id_a = batch.seq_id[a][i];
224
+ llama_seq_id seq_id_b = batch.seq_id[b][i];
225
+ // smaller seq_ids go first
226
+ if (seq_id_a != seq_id_b) {
227
+ return seq_id_a < seq_id_b;
228
+ }
229
+ }
230
+ }
231
+ // when all else is equal, sort by pos
232
+ if (batch.pos) {
233
+ return batch.pos[a] < batch.pos[b];
234
+ }
235
+ // no pos, sort by id
236
+ return a < b;
237
+ }
238
+ // shared prompts go first
239
+ return n_seq_a > n_seq_b;
240
+ }
241
+ );
242
+ // init seq
243
+ llama_sbatch_seq * last_seq = nullptr;
244
+
245
+ for (size_t i = 0; i < n_tokens; ++i) {
246
+ const size_t bi = ids[i];
247
+ const int32_t n_seqs = batch.n_seq_id[bi];
248
+ llama_seq_id * seq_ids = batch.seq_id[bi];
249
+ if (last_seq != nullptr) {
250
+ bool same = n_seqs == last_seq->n_seq_id;
251
+ for (int32_t j = 0; same && j < n_seqs; ++j) {
252
+ if (seq_ids[j] != last_seq->seq_id[j]) {
253
+ same = false;
254
+ }
255
+ }
256
+ if (same) {
257
+ last_seq->length += 1;
258
+ continue;
259
+ }
260
+ }
261
+ llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
262
+ seq.push_back(new_seq);
263
+ last_seq = &seq.back();
264
+ }
265
+ // keep shared prompts first at the end, then sort by length descending.
266
+ std::sort(seq.begin(), seq.end(),
267
+ [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
268
+ if (a.n_seq_id == b.n_seq_id) {
269
+ return a.length > b.length;
270
+ }
271
+ return a.n_seq_id < b.n_seq_id;
272
+ }
273
+ );
274
+ }
275
+
276
+ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
277
+ batch = in_batch;
278
+ LM_GGML_ASSERT(batch.n_tokens > 0);
279
+ if (!batch.pos) {
280
+ pos.resize(batch.n_tokens);
281
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
282
+ pos[i] = i + p0;
283
+ }
284
+ batch.pos = pos.data();
285
+ }
286
+ if (!batch.n_seq_id) {
287
+ n_seq_id.resize(batch.n_tokens);
288
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
289
+ n_seq_id[i] = seq_id_0.size();
290
+ }
291
+ batch.n_seq_id = n_seq_id.data();
292
+ }
293
+ if (!batch.seq_id) {
294
+ seq_id.resize(batch.n_tokens + 1);
295
+ seq_id[batch.n_tokens] = NULL;
296
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
297
+ seq_id[i] = seq_id_0.data();
298
+ }
299
+ batch.seq_id = seq_id.data();
300
+ }
301
+ if (!batch.logits) {
302
+ logits.resize(batch.n_tokens);
303
+ logits[logits.size() - 1] = true;
304
+ batch.logits = logits.data();
305
+ }
306
+ }
307
+
308
+ //
309
+ // interface implementation
310
+ //
311
+
312
+ struct llama_batch llama_batch_get_one(
313
+ llama_token * tokens,
314
+ int32_t n_tokens) {
315
+ return {
316
+ /*n_tokens =*/ n_tokens,
317
+ /*tokens =*/ tokens,
318
+ /*embd =*/ nullptr,
319
+ /*pos =*/ nullptr,
320
+ /*n_seq_id =*/ nullptr,
321
+ /*seq_id =*/ nullptr,
322
+ /*logits =*/ nullptr,
323
+ };
324
+ }
325
+
326
+ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
327
+ llama_batch batch = {
328
+ /*n_tokens =*/ 0,
329
+ /*tokens =*/ nullptr,
330
+ /*embd =*/ nullptr,
331
+ /*pos =*/ nullptr,
332
+ /*n_seq_id =*/ nullptr,
333
+ /*seq_id =*/ nullptr,
334
+ /*logits =*/ nullptr,
335
+ };
336
+
337
+ if (embd) {
338
+ batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
339
+ } else {
340
+ batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
341
+ }
342
+
343
+ batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
344
+ batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
345
+ batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
346
+ for (int i = 0; i < n_tokens_alloc; ++i) {
347
+ batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
348
+ }
349
+ batch.seq_id[n_tokens_alloc] = nullptr;
350
+
351
+ batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
352
+
353
+ return batch;
354
+ }
355
+
356
+ void llama_batch_free(struct llama_batch batch) {
357
+ if (batch.token) free(batch.token);
358
+ if (batch.embd) free(batch.embd);
359
+ if (batch.pos) free(batch.pos);
360
+ if (batch.n_seq_id) free(batch.n_seq_id);
361
+ if (batch.seq_id) {
362
+ for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
363
+ free(batch.seq_id[i]);
364
+ }
365
+ free(batch.seq_id);
366
+ }
367
+ if (batch.logits) free(batch.logits);
368
+ }