cui-llama.rn 1.4.0 → 1.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. package/android/src/main/jni.cpp +9 -9
  2. package/cpp/common.cpp +163 -60
  3. package/cpp/common.h +43 -12
  4. package/cpp/ggml-alloc.c +1042 -1037
  5. package/cpp/ggml-backend-impl.h +255 -256
  6. package/cpp/ggml-backend-reg.cpp +582 -582
  7. package/cpp/ggml-backend.cpp +2002 -2002
  8. package/cpp/ggml-backend.h +354 -352
  9. package/cpp/ggml-common.h +1853 -1853
  10. package/cpp/ggml-cpp.h +39 -39
  11. package/cpp/ggml-cpu-aarch64.cpp +4247 -4247
  12. package/cpp/ggml-cpu-aarch64.h +8 -8
  13. package/cpp/ggml-cpu-impl.h +386 -386
  14. package/cpp/ggml-cpu-quants.c +10920 -10839
  15. package/cpp/ggml-cpu-traits.cpp +36 -36
  16. package/cpp/ggml-cpu-traits.h +38 -38
  17. package/cpp/ggml-cpu.c +329 -60
  18. package/cpp/ggml-cpu.cpp +10 -2
  19. package/cpp/ggml-cpu.h +135 -135
  20. package/cpp/ggml-impl.h +567 -567
  21. package/cpp/ggml-metal-impl.h +17 -17
  22. package/cpp/ggml-metal.m +4884 -4884
  23. package/cpp/ggml-quants.c +5238 -5238
  24. package/cpp/ggml-threading.h +14 -14
  25. package/cpp/ggml.c +6514 -6448
  26. package/cpp/ggml.h +2194 -2163
  27. package/cpp/gguf.cpp +1329 -1325
  28. package/cpp/gguf.h +202 -202
  29. package/cpp/json-schema-to-grammar.cpp +1045 -1045
  30. package/cpp/json-schema-to-grammar.h +8 -8
  31. package/cpp/json.hpp +24766 -24766
  32. package/cpp/llama-adapter.cpp +347 -346
  33. package/cpp/llama-adapter.h +74 -73
  34. package/cpp/llama-arch.cpp +1487 -1434
  35. package/cpp/llama-arch.h +400 -395
  36. package/cpp/llama-batch.cpp +368 -368
  37. package/cpp/llama-batch.h +88 -88
  38. package/cpp/llama-chat.cpp +578 -567
  39. package/cpp/llama-chat.h +52 -51
  40. package/cpp/llama-context.cpp +1775 -1771
  41. package/cpp/llama-context.h +128 -128
  42. package/cpp/llama-cparams.cpp +1 -1
  43. package/cpp/llama-cparams.h +37 -37
  44. package/cpp/llama-cpp.h +30 -30
  45. package/cpp/llama-grammar.cpp +1139 -1139
  46. package/cpp/llama-grammar.h +143 -143
  47. package/cpp/llama-hparams.cpp +71 -71
  48. package/cpp/llama-hparams.h +139 -140
  49. package/cpp/llama-impl.cpp +167 -167
  50. package/cpp/llama-impl.h +61 -61
  51. package/cpp/llama-kv-cache.cpp +718 -718
  52. package/cpp/llama-kv-cache.h +218 -218
  53. package/cpp/llama-mmap.cpp +2 -1
  54. package/cpp/llama-mmap.h +67 -67
  55. package/cpp/llama-model-loader.cpp +1124 -1011
  56. package/cpp/llama-model-loader.h +167 -158
  57. package/cpp/llama-model.cpp +3997 -2202
  58. package/cpp/llama-model.h +370 -391
  59. package/cpp/llama-sampling.cpp +2408 -2406
  60. package/cpp/llama-sampling.h +32 -48
  61. package/cpp/llama-vocab.cpp +3247 -1982
  62. package/cpp/llama-vocab.h +125 -182
  63. package/cpp/llama.cpp +416 -2886
  64. package/cpp/llama.h +1323 -1285
  65. package/cpp/log.cpp +401 -401
  66. package/cpp/log.h +121 -121
  67. package/cpp/rn-llama.hpp +18 -12
  68. package/cpp/sampling.cpp +505 -500
  69. package/cpp/sgemm.cpp +2597 -2597
  70. package/cpp/speculative.cpp +277 -274
  71. package/cpp/speculative.h +28 -28
  72. package/cpp/unicode.cpp +2 -3
  73. package/package.json +1 -1
@@ -1,368 +1,368 @@
1
- #include "llama-batch.h"
2
-
3
- #include <cstring>
4
- #include <algorithm>
5
-
6
- llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
7
- // clear empty sequences
8
- // the previous ubatch is assumed to be gone,
9
- // so nothing should refer to values in these sequences anymore.
10
- for (size_t i = seq.size(); i-- > 0;) {
11
- if (seq[i].length == 0) {
12
- seq.pop_back();
13
- } else {
14
- break;
15
- }
16
- }
17
- ubatch_token.resize(!has_embd ? n_ubatch : 0);
18
- ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
19
- ubatch_pos.resize(n_ubatch);
20
- ubatch_n_seq_id.resize(n_ubatch);
21
- ubatch_seq_id.resize(n_ubatch);
22
- ubatch_output.resize(n_ubatch);
23
- llama_ubatch ubatch = {
24
- /*equal_seqs =*/ true,
25
- /*n_tokens =*/ 0,
26
- /*n_seq_tokens =*/ 0,
27
- /*n_seqs =*/ 0,
28
- /*token =*/ !has_embd ? ubatch_token.data() : nullptr,
29
- /*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
30
- /*pos =*/ ubatch_pos.data(),
31
- /*n_seq_id =*/ ubatch_n_seq_id.data(),
32
- /*seq_id =*/ ubatch_seq_id.data(),
33
- /*output =*/ ubatch_output.data(),
34
- };
35
- return ubatch;
36
- }
37
-
38
- void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
39
- LM_GGML_ASSERT(batch != nullptr);
40
- LM_GGML_ASSERT(length <= seq.length);
41
- // Can only add sequences of equal lengths to a batch,
42
- // otherwise it isn't clear to which sequence a token belongs
43
- LM_GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
44
- LM_GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
45
- // NOTE: loops are separated for cache-friendliness
46
- if (batch->token) {
47
- if (ubatch.equal_seqs) {
48
- for (size_t i = 0; i < length; ++i) {
49
- ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
50
- }
51
- } else {
52
- // simple split
53
- ubatch.token = batch->token + seq.offset;
54
- }
55
- } else {
56
- ubatch.token = nullptr;
57
- }
58
- if (batch->embd) {
59
- if (ubatch.equal_seqs) {
60
- for (size_t i = 0; i < length; ++i) {
61
- memcpy(
62
- ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
63
- batch->embd + (n_embd * ids[seq.offset + i]),
64
- n_embd * sizeof(float)
65
- );
66
- }
67
- } else {
68
- // simple split
69
- ubatch.embd = batch->embd + (n_embd * seq.offset);
70
- }
71
- } else {
72
- ubatch.embd = nullptr;
73
- }
74
- if (ubatch.equal_seqs) {
75
- for (size_t i = 0; i < length; ++i) {
76
- ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
77
- }
78
- } else {
79
- // simple split
80
- ubatch.pos = batch->pos + seq.offset;
81
- }
82
- if (ubatch.equal_seqs) {
83
- ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
84
- if (seq.seq_id) {
85
- ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
86
- }
87
- } else {
88
- // simple split
89
- if (batch->n_seq_id) {
90
- ubatch.n_seq_id = batch->n_seq_id + seq.offset;
91
- } else {
92
- for (size_t i = 0; i < length; ++i) {
93
- ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
94
- }
95
- }
96
- if (batch->seq_id) {
97
- ubatch.seq_id = batch->seq_id + seq.offset;
98
- }
99
- }
100
- if (logits_all) {
101
- for (size_t i = 0; i < length; ++i) {
102
- ubatch.output[ubatch.n_tokens + i] = 1;
103
- out_ids.push_back(ids[seq.offset + i]);
104
- }
105
- } else if (batch->logits) {
106
- if (ubatch.equal_seqs) {
107
- for (size_t i = 0; i < length; ++i) {
108
- size_t id = ids[seq.offset + i];
109
- int8_t is_output = batch->logits[id];
110
- ubatch.output[ubatch.n_tokens + i] = is_output;
111
- if (is_output) { out_ids.push_back(id); }
112
- }
113
- } else {
114
- // simple split
115
- ubatch.output = batch->logits + seq.offset;
116
- for (size_t i = 0; i < length; ++i) {
117
- if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
118
- }
119
- }
120
- } else {
121
- // only get last output
122
- for (size_t i = 0; i < length; ++i) {
123
- size_t id = ids[seq.offset + i];
124
- int8_t is_last = id == ids.size() - 1;
125
- ubatch.output[ubatch.n_tokens + i] = is_last;
126
- if (is_last) { out_ids.push_back(id); }
127
- }
128
- }
129
- if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
130
- ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
131
- }
132
- ubatch.n_tokens += length;
133
- ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
134
- seq.offset += length;
135
- seq.length -= length;
136
- n_tokens -= length;
137
- LM_GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
138
- }
139
-
140
- llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
141
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
142
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
143
- ubatch.equal_seqs = false;
144
- if (!seq.empty()) {
145
- llama_sbatch_seq & s = seq[0];
146
- size_t length = s.length < n_ubatch ? s.length : n_ubatch;
147
- LM_GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
148
- add_seq_to_ubatch(ubatch, s, length);
149
- }
150
- return ubatch;
151
- }
152
-
153
- llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
154
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
155
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
156
- if (!seq.empty()) {
157
- size_t length = 0;
158
- size_t n_tokens_in_ubatch = 0;
159
- LM_GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
160
- // smallest first, because it's easier to split this way;
161
- // starting from the end to pop in constant time.
162
- for (size_t i = seq.size(); i-- > 0;) {
163
- llama_sbatch_seq & s = seq[i];
164
- LM_GGML_ASSERT(s.length > 0);
165
- if (length == 0) {
166
- length = s.length < n_ubatch ? s.length : n_ubatch;
167
- }
168
- add_seq_to_ubatch(ubatch, s, length);
169
- n_tokens_in_ubatch += length;
170
- // shared prompts can't be mixed with any of their sequences,
171
- // so it's safer to compute them in their own ubatch
172
- if (s.n_seq_id > 1) { break; }
173
- // stop when there isn't enough space for another sequence
174
- if (length + n_tokens_in_ubatch > n_ubatch) { break; }
175
- }
176
- }
177
- return ubatch;
178
- }
179
-
180
- llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
181
- n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
182
- llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
183
- if (!seq.empty()) {
184
- llama_sbatch_seq & s = seq[seq.size() - 1];
185
- size_t length = s.length < n_ubatch ? s.length : n_ubatch;
186
- LM_GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
187
- add_seq_to_ubatch(ubatch, s, length);
188
- }
189
- return ubatch;
190
- }
191
-
192
- void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
193
- LM_GGML_ASSERT(batch.n_tokens >= 0);
194
- this->batch = &batch;
195
- this->n_embd = n_embd;
196
- this->logits_all = logits_all;
197
-
198
- n_tokens = batch.n_tokens;
199
- ids.resize(n_tokens);
200
- out_ids.clear();
201
- // TODO: reserve out_ids and seq
202
-
203
- for (size_t i = 0; i < n_tokens; ++i) {
204
- ids[i] = i;
205
- }
206
- if (simple_split) {
207
- seq.resize(1);
208
- llama_sbatch_seq & s = seq[0];
209
- s.n_seq_id = 0;
210
- s.seq_id = nullptr;
211
- s.offset = 0;
212
- s.length = n_tokens;
213
- return;
214
- }
215
- std::sort(ids.begin(), ids.end(),
216
- [&batch](size_t a, size_t b) {
217
- int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
218
- int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
219
- // sort by seq_id, then by pos
220
- if (n_seq_a == n_seq_b) {
221
- if (batch.seq_id) {
222
- for (int32_t i = 0; i < n_seq_a; ++i) {
223
- llama_seq_id seq_id_a = batch.seq_id[a][i];
224
- llama_seq_id seq_id_b = batch.seq_id[b][i];
225
- // smaller seq_ids go first
226
- if (seq_id_a != seq_id_b) {
227
- return seq_id_a < seq_id_b;
228
- }
229
- }
230
- }
231
- // when all else is equal, sort by pos
232
- if (batch.pos) {
233
- return batch.pos[a] < batch.pos[b];
234
- }
235
- // no pos, sort by id
236
- return a < b;
237
- }
238
- // shared prompts go first
239
- return n_seq_a > n_seq_b;
240
- }
241
- );
242
- // init seq
243
- llama_sbatch_seq * last_seq = nullptr;
244
-
245
- for (size_t i = 0; i < n_tokens; ++i) {
246
- const size_t bi = ids[i];
247
- const int32_t n_seqs = batch.n_seq_id[bi];
248
- llama_seq_id * seq_ids = batch.seq_id[bi];
249
- if (last_seq != nullptr) {
250
- bool same = n_seqs == last_seq->n_seq_id;
251
- for (int32_t j = 0; same && j < n_seqs; ++j) {
252
- if (seq_ids[j] != last_seq->seq_id[j]) {
253
- same = false;
254
- }
255
- }
256
- if (same) {
257
- last_seq->length += 1;
258
- continue;
259
- }
260
- }
261
- llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
262
- seq.push_back(new_seq);
263
- last_seq = &seq.back();
264
- }
265
- // keep shared prompts first at the end, then sort by length descending.
266
- std::sort(seq.begin(), seq.end(),
267
- [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
268
- if (a.n_seq_id == b.n_seq_id) {
269
- return a.length > b.length;
270
- }
271
- return a.n_seq_id < b.n_seq_id;
272
- }
273
- );
274
- }
275
-
276
- llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
277
- batch = in_batch;
278
- LM_GGML_ASSERT(batch.n_tokens > 0);
279
- if (!batch.pos) {
280
- pos.resize(batch.n_tokens);
281
- for (int32_t i = 0; i < batch.n_tokens; i++) {
282
- pos[i] = i + p0;
283
- }
284
- batch.pos = pos.data();
285
- }
286
- if (!batch.n_seq_id) {
287
- n_seq_id.resize(batch.n_tokens);
288
- for (int32_t i = 0; i < batch.n_tokens; i++) {
289
- n_seq_id[i] = seq_id_0.size();
290
- }
291
- batch.n_seq_id = n_seq_id.data();
292
- }
293
- if (!batch.seq_id) {
294
- seq_id.resize(batch.n_tokens + 1);
295
- seq_id[batch.n_tokens] = NULL;
296
- for (int32_t i = 0; i < batch.n_tokens; i++) {
297
- seq_id[i] = seq_id_0.data();
298
- }
299
- batch.seq_id = seq_id.data();
300
- }
301
- if (!batch.logits) {
302
- logits.resize(batch.n_tokens);
303
- logits[logits.size() - 1] = true;
304
- batch.logits = logits.data();
305
- }
306
- }
307
-
308
- //
309
- // interface implementation
310
- //
311
-
312
- struct llama_batch llama_batch_get_one(
313
- llama_token * tokens,
314
- int32_t n_tokens) {
315
- return {
316
- /*n_tokens =*/ n_tokens,
317
- /*tokens =*/ tokens,
318
- /*embd =*/ nullptr,
319
- /*pos =*/ nullptr,
320
- /*n_seq_id =*/ nullptr,
321
- /*seq_id =*/ nullptr,
322
- /*logits =*/ nullptr,
323
- };
324
- }
325
-
326
- struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
327
- llama_batch batch = {
328
- /*n_tokens =*/ 0,
329
- /*tokens =*/ nullptr,
330
- /*embd =*/ nullptr,
331
- /*pos =*/ nullptr,
332
- /*n_seq_id =*/ nullptr,
333
- /*seq_id =*/ nullptr,
334
- /*logits =*/ nullptr,
335
- };
336
-
337
- if (embd) {
338
- batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
339
- } else {
340
- batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
341
- }
342
-
343
- batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
344
- batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
345
- batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
346
- for (int i = 0; i < n_tokens_alloc; ++i) {
347
- batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
348
- }
349
- batch.seq_id[n_tokens_alloc] = nullptr;
350
-
351
- batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
352
-
353
- return batch;
354
- }
355
-
356
- void llama_batch_free(struct llama_batch batch) {
357
- if (batch.token) free(batch.token);
358
- if (batch.embd) free(batch.embd);
359
- if (batch.pos) free(batch.pos);
360
- if (batch.n_seq_id) free(batch.n_seq_id);
361
- if (batch.seq_id) {
362
- for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
363
- free(batch.seq_id[i]);
364
- }
365
- free(batch.seq_id);
366
- }
367
- if (batch.logits) free(batch.logits);
368
- }
1
+ #include "llama-batch.h"
2
+
3
+ #include <cstring>
4
+ #include <algorithm>
5
+
6
+ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
7
+ // clear empty sequences
8
+ // the previous ubatch is assumed to be gone,
9
+ // so nothing should refer to values in these sequences anymore.
10
+ for (size_t i = seq.size(); i-- > 0;) {
11
+ if (seq[i].length == 0) {
12
+ seq.pop_back();
13
+ } else {
14
+ break;
15
+ }
16
+ }
17
+ ubatch_token.resize(!has_embd ? n_ubatch : 0);
18
+ ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
19
+ ubatch_pos.resize(n_ubatch);
20
+ ubatch_n_seq_id.resize(n_ubatch);
21
+ ubatch_seq_id.resize(n_ubatch);
22
+ ubatch_output.resize(n_ubatch);
23
+ llama_ubatch ubatch = {
24
+ /*equal_seqs =*/ true,
25
+ /*n_tokens =*/ 0,
26
+ /*n_seq_tokens =*/ 0,
27
+ /*n_seqs =*/ 0,
28
+ /*token =*/ !has_embd ? ubatch_token.data() : nullptr,
29
+ /*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
30
+ /*pos =*/ ubatch_pos.data(),
31
+ /*n_seq_id =*/ ubatch_n_seq_id.data(),
32
+ /*seq_id =*/ ubatch_seq_id.data(),
33
+ /*output =*/ ubatch_output.data(),
34
+ };
35
+ return ubatch;
36
+ }
37
+
38
+ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
39
+ LM_GGML_ASSERT(batch != nullptr);
40
+ LM_GGML_ASSERT(length <= seq.length);
41
+ // Can only add sequences of equal lengths to a batch,
42
+ // otherwise it isn't clear to which sequence a token belongs
43
+ LM_GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
44
+ LM_GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
45
+ // NOTE: loops are separated for cache-friendliness
46
+ if (batch->token) {
47
+ if (ubatch.equal_seqs) {
48
+ for (size_t i = 0; i < length; ++i) {
49
+ ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
50
+ }
51
+ } else {
52
+ // simple split
53
+ ubatch.token = batch->token + seq.offset;
54
+ }
55
+ } else {
56
+ ubatch.token = nullptr;
57
+ }
58
+ if (batch->embd) {
59
+ if (ubatch.equal_seqs) {
60
+ for (size_t i = 0; i < length; ++i) {
61
+ memcpy(
62
+ ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
63
+ batch->embd + (n_embd * ids[seq.offset + i]),
64
+ n_embd * sizeof(float)
65
+ );
66
+ }
67
+ } else {
68
+ // simple split
69
+ ubatch.embd = batch->embd + (n_embd * seq.offset);
70
+ }
71
+ } else {
72
+ ubatch.embd = nullptr;
73
+ }
74
+ if (ubatch.equal_seqs) {
75
+ for (size_t i = 0; i < length; ++i) {
76
+ ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
77
+ }
78
+ } else {
79
+ // simple split
80
+ ubatch.pos = batch->pos + seq.offset;
81
+ }
82
+ if (ubatch.equal_seqs) {
83
+ ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
84
+ if (seq.seq_id) {
85
+ ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
86
+ }
87
+ } else {
88
+ // simple split
89
+ if (batch->n_seq_id) {
90
+ ubatch.n_seq_id = batch->n_seq_id + seq.offset;
91
+ } else {
92
+ for (size_t i = 0; i < length; ++i) {
93
+ ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
94
+ }
95
+ }
96
+ if (batch->seq_id) {
97
+ ubatch.seq_id = batch->seq_id + seq.offset;
98
+ }
99
+ }
100
+ if (logits_all) {
101
+ for (size_t i = 0; i < length; ++i) {
102
+ ubatch.output[ubatch.n_tokens + i] = 1;
103
+ out_ids.push_back(ids[seq.offset + i]);
104
+ }
105
+ } else if (batch->logits) {
106
+ if (ubatch.equal_seqs) {
107
+ for (size_t i = 0; i < length; ++i) {
108
+ size_t id = ids[seq.offset + i];
109
+ int8_t is_output = batch->logits[id];
110
+ ubatch.output[ubatch.n_tokens + i] = is_output;
111
+ if (is_output) { out_ids.push_back(id); }
112
+ }
113
+ } else {
114
+ // simple split
115
+ ubatch.output = batch->logits + seq.offset;
116
+ for (size_t i = 0; i < length; ++i) {
117
+ if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
118
+ }
119
+ }
120
+ } else {
121
+ // only get last output
122
+ for (size_t i = 0; i < length; ++i) {
123
+ size_t id = ids[seq.offset + i];
124
+ int8_t is_last = id == ids.size() - 1;
125
+ ubatch.output[ubatch.n_tokens + i] = is_last;
126
+ if (is_last) { out_ids.push_back(id); }
127
+ }
128
+ }
129
+ if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
130
+ ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
131
+ }
132
+ ubatch.n_tokens += length;
133
+ ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
134
+ seq.offset += length;
135
+ seq.length -= length;
136
+ n_tokens -= length;
137
+ LM_GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
138
+ }
139
+
140
+ llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
141
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
142
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
143
+ ubatch.equal_seqs = false;
144
+ if (!seq.empty()) {
145
+ llama_sbatch_seq & s = seq[0];
146
+ size_t length = s.length < n_ubatch ? s.length : n_ubatch;
147
+ LM_GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
148
+ add_seq_to_ubatch(ubatch, s, length);
149
+ }
150
+ return ubatch;
151
+ }
152
+
153
+ llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
154
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
155
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
156
+ if (!seq.empty()) {
157
+ size_t length = 0;
158
+ size_t n_tokens_in_ubatch = 0;
159
+ LM_GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
160
+ // smallest first, because it's easier to split this way;
161
+ // starting from the end to pop in constant time.
162
+ for (size_t i = seq.size(); i-- > 0;) {
163
+ llama_sbatch_seq & s = seq[i];
164
+ LM_GGML_ASSERT(s.length > 0);
165
+ if (length == 0) {
166
+ length = s.length < n_ubatch ? s.length : n_ubatch;
167
+ }
168
+ add_seq_to_ubatch(ubatch, s, length);
169
+ n_tokens_in_ubatch += length;
170
+ // shared prompts can't be mixed with any of their sequences,
171
+ // so it's safer to compute them in their own ubatch
172
+ if (s.n_seq_id > 1) { break; }
173
+ // stop when there isn't enough space for another sequence
174
+ if (length + n_tokens_in_ubatch > n_ubatch) { break; }
175
+ }
176
+ }
177
+ return ubatch;
178
+ }
179
+
180
+ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
181
+ n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
182
+ llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
183
+ if (!seq.empty()) {
184
+ llama_sbatch_seq & s = seq[seq.size() - 1];
185
+ size_t length = s.length < n_ubatch ? s.length : n_ubatch;
186
+ LM_GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
187
+ add_seq_to_ubatch(ubatch, s, length);
188
+ }
189
+ return ubatch;
190
+ }
191
+
192
+ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
193
+ LM_GGML_ASSERT(batch.n_tokens >= 0);
194
+ this->batch = &batch;
195
+ this->n_embd = n_embd;
196
+ this->logits_all = logits_all;
197
+
198
+ n_tokens = batch.n_tokens;
199
+ ids.resize(n_tokens);
200
+ out_ids.clear();
201
+ // TODO: reserve out_ids and seq
202
+
203
+ for (size_t i = 0; i < n_tokens; ++i) {
204
+ ids[i] = i;
205
+ }
206
+ if (simple_split) {
207
+ seq.resize(1);
208
+ llama_sbatch_seq & s = seq[0];
209
+ s.n_seq_id = 0;
210
+ s.seq_id = nullptr;
211
+ s.offset = 0;
212
+ s.length = n_tokens;
213
+ return;
214
+ }
215
+ std::sort(ids.begin(), ids.end(),
216
+ [&batch](size_t a, size_t b) {
217
+ int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
218
+ int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
219
+ // sort by seq_id, then by pos
220
+ if (n_seq_a == n_seq_b) {
221
+ if (batch.seq_id) {
222
+ for (int32_t i = 0; i < n_seq_a; ++i) {
223
+ llama_seq_id seq_id_a = batch.seq_id[a][i];
224
+ llama_seq_id seq_id_b = batch.seq_id[b][i];
225
+ // smaller seq_ids go first
226
+ if (seq_id_a != seq_id_b) {
227
+ return seq_id_a < seq_id_b;
228
+ }
229
+ }
230
+ }
231
+ // when all else is equal, sort by pos
232
+ if (batch.pos) {
233
+ return batch.pos[a] < batch.pos[b];
234
+ }
235
+ // no pos, sort by id
236
+ return a < b;
237
+ }
238
+ // shared prompts go first
239
+ return n_seq_a > n_seq_b;
240
+ }
241
+ );
242
+ // init seq
243
+ llama_sbatch_seq * last_seq = nullptr;
244
+
245
+ for (size_t i = 0; i < n_tokens; ++i) {
246
+ const size_t bi = ids[i];
247
+ const int32_t n_seqs = batch.n_seq_id[bi];
248
+ llama_seq_id * seq_ids = batch.seq_id[bi];
249
+ if (last_seq != nullptr) {
250
+ bool same = n_seqs == last_seq->n_seq_id;
251
+ for (int32_t j = 0; same && j < n_seqs; ++j) {
252
+ if (seq_ids[j] != last_seq->seq_id[j]) {
253
+ same = false;
254
+ }
255
+ }
256
+ if (same) {
257
+ last_seq->length += 1;
258
+ continue;
259
+ }
260
+ }
261
+ llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
262
+ seq.push_back(new_seq);
263
+ last_seq = &seq.back();
264
+ }
265
+ // keep shared prompts first at the end, then sort by length descending.
266
+ std::sort(seq.begin(), seq.end(),
267
+ [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
268
+ if (a.n_seq_id == b.n_seq_id) {
269
+ return a.length > b.length;
270
+ }
271
+ return a.n_seq_id < b.n_seq_id;
272
+ }
273
+ );
274
+ }
275
+
276
+ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
277
+ batch = in_batch;
278
+ LM_GGML_ASSERT(batch.n_tokens > 0);
279
+ if (!batch.pos) {
280
+ pos.resize(batch.n_tokens);
281
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
282
+ pos[i] = i + p0;
283
+ }
284
+ batch.pos = pos.data();
285
+ }
286
+ if (!batch.n_seq_id) {
287
+ n_seq_id.resize(batch.n_tokens);
288
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
289
+ n_seq_id[i] = seq_id_0.size();
290
+ }
291
+ batch.n_seq_id = n_seq_id.data();
292
+ }
293
+ if (!batch.seq_id) {
294
+ seq_id.resize(batch.n_tokens + 1);
295
+ seq_id[batch.n_tokens] = NULL;
296
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
297
+ seq_id[i] = seq_id_0.data();
298
+ }
299
+ batch.seq_id = seq_id.data();
300
+ }
301
+ if (!batch.logits) {
302
+ logits.resize(batch.n_tokens);
303
+ logits[logits.size() - 1] = true;
304
+ batch.logits = logits.data();
305
+ }
306
+ }
307
+
308
+ //
309
+ // interface implementation
310
+ //
311
+
312
+ struct llama_batch llama_batch_get_one(
313
+ llama_token * tokens,
314
+ int32_t n_tokens) {
315
+ return {
316
+ /*n_tokens =*/ n_tokens,
317
+ /*tokens =*/ tokens,
318
+ /*embd =*/ nullptr,
319
+ /*pos =*/ nullptr,
320
+ /*n_seq_id =*/ nullptr,
321
+ /*seq_id =*/ nullptr,
322
+ /*logits =*/ nullptr,
323
+ };
324
+ }
325
+
326
+ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
327
+ llama_batch batch = {
328
+ /*n_tokens =*/ 0,
329
+ /*tokens =*/ nullptr,
330
+ /*embd =*/ nullptr,
331
+ /*pos =*/ nullptr,
332
+ /*n_seq_id =*/ nullptr,
333
+ /*seq_id =*/ nullptr,
334
+ /*logits =*/ nullptr,
335
+ };
336
+
337
+ if (embd) {
338
+ batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
339
+ } else {
340
+ batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
341
+ }
342
+
343
+ batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
344
+ batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
345
+ batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
346
+ for (int i = 0; i < n_tokens_alloc; ++i) {
347
+ batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
348
+ }
349
+ batch.seq_id[n_tokens_alloc] = nullptr;
350
+
351
+ batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
352
+
353
+ return batch;
354
+ }
355
+
356
+ void llama_batch_free(struct llama_batch batch) {
357
+ if (batch.token) free(batch.token);
358
+ if (batch.embd) free(batch.embd);
359
+ if (batch.pos) free(batch.pos);
360
+ if (batch.n_seq_id) free(batch.n_seq_id);
361
+ if (batch.seq_id) {
362
+ for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
363
+ free(batch.seq_id[i]);
364
+ }
365
+ free(batch.seq_id);
366
+ }
367
+ if (batch.logits) free(batch.logits);
368
+ }