llama-cpp-capacitor 0.0.6 → 0.0.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (149)
  1. package/cpp/LICENSE +21 -0
  2. package/cpp/README.md +4 -0
  3. package/cpp/anyascii.c +22223 -0
  4. package/cpp/anyascii.h +42 -0
  5. package/cpp/chat-parser.cpp +393 -0
  6. package/cpp/chat-parser.h +120 -0
  7. package/cpp/chat.cpp +2315 -0
  8. package/cpp/chat.h +221 -0
  9. package/cpp/common.cpp +1619 -0
  10. package/cpp/common.h +744 -0
  11. package/cpp/ggml-alloc.c +1028 -0
  12. package/cpp/ggml-alloc.h +76 -0
  13. package/cpp/ggml-backend-impl.h +255 -0
  14. package/cpp/ggml-backend-reg.cpp +600 -0
  15. package/cpp/ggml-backend.cpp +2118 -0
  16. package/cpp/ggml-backend.h +354 -0
  17. package/cpp/ggml-common.h +1878 -0
  18. package/cpp/ggml-cpp.h +39 -0
  19. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  20. package/cpp/ggml-cpu/amx/amx.h +8 -0
  21. package/cpp/ggml-cpu/amx/common.h +91 -0
  22. package/cpp/ggml-cpu/amx/mmq.cpp +2512 -0
  23. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  24. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  25. package/cpp/ggml-cpu/arch/arm/quants.c +3650 -0
  26. package/cpp/ggml-cpu/arch/arm/repack.cpp +1891 -0
  27. package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  28. package/cpp/ggml-cpu/arch/x86/quants.c +3820 -0
  29. package/cpp/ggml-cpu/arch/x86/repack.cpp +6307 -0
  30. package/cpp/ggml-cpu/arch-fallback.h +215 -0
  31. package/cpp/ggml-cpu/binary-ops.cpp +158 -0
  32. package/cpp/ggml-cpu/binary-ops.h +16 -0
  33. package/cpp/ggml-cpu/common.h +73 -0
  34. package/cpp/ggml-cpu/ggml-cpu-impl.h +525 -0
  35. package/cpp/ggml-cpu/ggml-cpu.c +3578 -0
  36. package/cpp/ggml-cpu/ggml-cpu.cpp +672 -0
  37. package/cpp/ggml-cpu/ops.cpp +10587 -0
  38. package/cpp/ggml-cpu/ops.h +114 -0
  39. package/cpp/ggml-cpu/quants.c +1193 -0
  40. package/cpp/ggml-cpu/quants.h +97 -0
  41. package/cpp/ggml-cpu/repack.cpp +1982 -0
  42. package/cpp/ggml-cpu/repack.h +120 -0
  43. package/cpp/ggml-cpu/simd-mappings.h +1184 -0
  44. package/cpp/ggml-cpu/traits.cpp +36 -0
  45. package/cpp/ggml-cpu/traits.h +38 -0
  46. package/cpp/ggml-cpu/unary-ops.cpp +186 -0
  47. package/cpp/ggml-cpu/unary-ops.h +28 -0
  48. package/cpp/ggml-cpu/vec.cpp +348 -0
  49. package/cpp/ggml-cpu/vec.h +1121 -0
  50. package/cpp/ggml-cpu.h +145 -0
  51. package/cpp/ggml-impl.h +622 -0
  52. package/cpp/ggml-metal-impl.h +688 -0
  53. package/cpp/ggml-metal.h +66 -0
  54. package/cpp/ggml-metal.m +6833 -0
  55. package/cpp/ggml-opt.cpp +1093 -0
  56. package/cpp/ggml-opt.h +256 -0
  57. package/cpp/ggml-quants.c +5324 -0
  58. package/cpp/ggml-quants.h +106 -0
  59. package/cpp/ggml-threading.cpp +12 -0
  60. package/cpp/ggml-threading.h +14 -0
  61. package/cpp/ggml.c +7108 -0
  62. package/cpp/ggml.h +2492 -0
  63. package/cpp/gguf.cpp +1358 -0
  64. package/cpp/gguf.h +202 -0
  65. package/cpp/json-partial.cpp +256 -0
  66. package/cpp/json-partial.h +38 -0
  67. package/cpp/json-schema-to-grammar.cpp +985 -0
  68. package/cpp/json-schema-to-grammar.h +21 -0
  69. package/cpp/llama-adapter.cpp +388 -0
  70. package/cpp/llama-adapter.h +76 -0
  71. package/cpp/llama-arch.cpp +2355 -0
  72. package/cpp/llama-arch.h +499 -0
  73. package/cpp/llama-batch.cpp +875 -0
  74. package/cpp/llama-batch.h +160 -0
  75. package/cpp/llama-chat.cpp +783 -0
  76. package/cpp/llama-chat.h +65 -0
  77. package/cpp/llama-context.cpp +2748 -0
  78. package/cpp/llama-context.h +306 -0
  79. package/cpp/llama-cparams.cpp +5 -0
  80. package/cpp/llama-cparams.h +41 -0
  81. package/cpp/llama-cpp.h +30 -0
  82. package/cpp/llama-grammar.cpp +1229 -0
  83. package/cpp/llama-grammar.h +173 -0
  84. package/cpp/llama-graph.cpp +1891 -0
  85. package/cpp/llama-graph.h +810 -0
  86. package/cpp/llama-hparams.cpp +180 -0
  87. package/cpp/llama-hparams.h +233 -0
  88. package/cpp/llama-impl.cpp +167 -0
  89. package/cpp/llama-impl.h +61 -0
  90. package/cpp/llama-io.cpp +15 -0
  91. package/cpp/llama-io.h +35 -0
  92. package/cpp/llama-kv-cache-iswa.cpp +318 -0
  93. package/cpp/llama-kv-cache-iswa.h +135 -0
  94. package/cpp/llama-kv-cache.cpp +2059 -0
  95. package/cpp/llama-kv-cache.h +374 -0
  96. package/cpp/llama-kv-cells.h +491 -0
  97. package/cpp/llama-memory-hybrid.cpp +258 -0
  98. package/cpp/llama-memory-hybrid.h +137 -0
  99. package/cpp/llama-memory-recurrent.cpp +1146 -0
  100. package/cpp/llama-memory-recurrent.h +179 -0
  101. package/cpp/llama-memory.cpp +59 -0
  102. package/cpp/llama-memory.h +119 -0
  103. package/cpp/llama-mmap.cpp +600 -0
  104. package/cpp/llama-mmap.h +68 -0
  105. package/cpp/llama-model-loader.cpp +1164 -0
  106. package/cpp/llama-model-loader.h +170 -0
  107. package/cpp/llama-model-saver.cpp +282 -0
  108. package/cpp/llama-model-saver.h +37 -0
  109. package/cpp/llama-model.cpp +19042 -0
  110. package/cpp/llama-model.h +491 -0
  111. package/cpp/llama-sampling.cpp +2575 -0
  112. package/cpp/llama-sampling.h +32 -0
  113. package/cpp/llama-vocab.cpp +3792 -0
  114. package/cpp/llama-vocab.h +176 -0
  115. package/cpp/llama.cpp +358 -0
  116. package/cpp/llama.h +1373 -0
  117. package/cpp/log.cpp +427 -0
  118. package/cpp/log.h +103 -0
  119. package/cpp/minja/chat-template.hpp +550 -0
  120. package/cpp/minja/minja.hpp +3009 -0
  121. package/cpp/nlohmann/json.hpp +25526 -0
  122. package/cpp/nlohmann/json_fwd.hpp +187 -0
  123. package/cpp/regex-partial.cpp +204 -0
  124. package/cpp/regex-partial.h +56 -0
  125. package/cpp/rn-completion.cpp +681 -0
  126. package/cpp/rn-completion.h +116 -0
  127. package/cpp/rn-llama.cpp +345 -0
  128. package/cpp/rn-llama.h +149 -0
  129. package/cpp/rn-mtmd.hpp +602 -0
  130. package/cpp/rn-tts.cpp +591 -0
  131. package/cpp/rn-tts.h +59 -0
  132. package/cpp/sampling.cpp +579 -0
  133. package/cpp/sampling.h +107 -0
  134. package/cpp/tools/mtmd/clip-impl.h +473 -0
  135. package/cpp/tools/mtmd/clip.cpp +4322 -0
  136. package/cpp/tools/mtmd/clip.h +106 -0
  137. package/cpp/tools/mtmd/miniaudio/miniaudio.h +93468 -0
  138. package/cpp/tools/mtmd/mtmd-audio.cpp +769 -0
  139. package/cpp/tools/mtmd/mtmd-audio.h +47 -0
  140. package/cpp/tools/mtmd/mtmd-helper.cpp +460 -0
  141. package/cpp/tools/mtmd/mtmd-helper.h +91 -0
  142. package/cpp/tools/mtmd/mtmd.cpp +1066 -0
  143. package/cpp/tools/mtmd/mtmd.h +298 -0
  144. package/cpp/tools/mtmd/stb/stb_image.h +7988 -0
  145. package/cpp/unicode-data.cpp +7034 -0
  146. package/cpp/unicode-data.h +20 -0
  147. package/cpp/unicode.cpp +1061 -0
  148. package/cpp/unicode.h +68 -0
  149. package/package.json +2 -1
@@ -0,0 +1,875 @@
1
+ #include "llama-batch.h"
2
+
3
+ #include "llama-impl.h"
4
+ #include "llama-vocab.h"
5
+ #include "llama-memory.h"
6
+
7
+ #include <cassert>
8
+ #include <cstring>
9
+ #include <algorithm>
10
+ #include <sstream>
11
+
12
+ llama_batch_allocr::llama_batch_allocr(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {
13
+ const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
14
+ debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
15
+
16
+ seq_pos.resize(LLAMA_MAX_SEQ);
17
+ seq_cpl.resize(LLAMA_MAX_SEQ);
18
+ for (auto & cur : seq_cpl) {
19
+ cur.resize(LLAMA_MAX_SEQ);
20
+ }
21
+
22
+ seq_idx.resize(LLAMA_MAX_SEQ, -1);
23
+ }
24
+
25
+ bool llama_batch_allocr::init(
26
+ const llama_batch & batch_inp,
27
+ const llama_vocab & vocab,
28
+ const llama_memory_i * memory,
29
+ uint32_t n_embd,
30
+ uint32_t n_seq_max,
31
+ bool output_all) {
32
+ clear();
33
+
34
+ batch = batch_inp;
35
+
36
+ this->vocab = &vocab;
37
+
38
+ LM_GGML_ASSERT(batch.n_tokens > 0);
39
+
40
+ //
41
+ // validate input batch
42
+ //
43
+
44
+ if (n_seq_max > LLAMA_MAX_SEQ) {
45
+ LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
46
+ return false;
47
+ }
48
+
49
+ if (batch.token) {
50
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
51
+ if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
52
+ LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
53
+ return false;
54
+ }
55
+ }
56
+ }
57
+
58
+ if (batch.seq_id) {
59
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
60
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
61
+ if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
62
+ LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d >= %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
63
+ return false;
64
+ }
65
+ }
66
+ }
67
+ }
68
+
69
+ //
70
+ // auto-generate missing fields
71
+ //
72
+
73
+ if (!batch.n_seq_id) {
74
+ n_seq_id.resize(batch.n_tokens);
75
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
76
+ n_seq_id[i] = seq_id_0.size();
77
+ }
78
+ batch.n_seq_id = n_seq_id.data();
79
+ }
80
+
81
+ if (!batch.seq_id) {
82
+ seq_id.resize(batch.n_tokens + 1);
83
+ seq_id[batch.n_tokens] = NULL;
84
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
85
+ seq_id[i] = seq_id_0.data();
86
+ }
87
+ batch.seq_id = seq_id.data();
88
+ }
89
+
90
+ if (!batch.pos) {
91
+ pos.resize(batch.n_tokens);
92
+
93
+ // initialize the starting position for each sequence based on the positions in the memory
94
+ llama_pos p0[LLAMA_MAX_SEQ];
95
+ for (uint32_t s = 0; s < n_seq_max; ++s) {
96
+ if (!memory) {
97
+ // if no memory -> start from 0
98
+ p0[s] = 0;
99
+ } else {
100
+ p0[s] = memory->seq_pos_max(s) + 1;
101
+ }
102
+ }
103
+
104
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
105
+ const llama_seq_id seq_id = batch.seq_id[i][0];
106
+
107
+ pos[i] = p0[seq_id];
108
+
109
+ // update the starting position for all sequences that are assigned to the this token
110
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
111
+ const llama_seq_id seq_id = batch.seq_id[i][s];
112
+
113
+ p0[seq_id] = pos[i] + 1;
114
+ }
115
+ }
116
+
117
+ batch.pos = pos.data();
118
+ }
119
+
120
+ if (!batch.logits) {
121
+ if (output_all) {
122
+ // return the output for all tokens
123
+ output.resize(batch.n_tokens, true);
124
+ } else {
125
+ // return the output only for the last token
126
+ output.resize(batch.n_tokens, false);
127
+ output[output.size() - 1] = true;
128
+ }
129
+
130
+ batch.logits = output.data();
131
+ } else if (output_all) {
132
+ bool warn = false;
133
+
134
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
135
+ if (batch.logits[i] == 0) {
136
+ warn = true;
137
+ }
138
+ }
139
+
140
+ if (warn) {
141
+ LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
142
+
143
+ output.resize(batch.n_tokens, true);
144
+ batch.logits = output.data();
145
+ }
146
+ }
147
+
148
+ //
149
+ // compute stats
150
+ //
151
+
152
+ this->n_embd = n_embd;
153
+ this->n_seq_max = n_seq_max;
154
+
155
+ // count the outputs in this batch
156
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
157
+ n_outputs += batch.logits[i] != 0;
158
+ }
159
+
160
+ has_cpl = false;
161
+
162
+ // determine coupled sequences
163
+ // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
164
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
165
+ const llama_seq_id s0 = batch.seq_id[i][0];
166
+
167
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
168
+ const llama_seq_id s1 = batch.seq_id[i][s];
169
+
170
+ seq_pos[s1].insert(batch.pos[i]);
171
+
172
+ if (s > 0) {
173
+ // mark that sequence s1 is coupled to s0
174
+ seq_cpl[s1][s0] = true;
175
+
176
+ // note: tracking the other way around is not necessary for now
177
+ //seq_cpl[s0][s1] = true;
178
+
179
+ has_cpl = true;
180
+ }
181
+ }
182
+ }
183
+
184
+ // precompute the sequence sets for each token and determine the unique sequence ids that participate in the batch
185
+ {
186
+ seq_set_t seq_set_unq;
187
+
188
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
189
+ seq_set_t cur;
190
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
191
+ const llama_seq_id seq_id = batch.seq_id[i][s];
192
+
193
+ cur .set(seq_id);
194
+ seq_set_unq.set(seq_id);
195
+ }
196
+
197
+ seq_set.push_back(cur);
198
+ seq_set_map[cur].push_back(i);
199
+ }
200
+
201
+ for (uint32_t s = 0; s < n_seq_max; ++s) {
202
+ if (seq_set_unq.test(s)) {
203
+ seq_idx[s] = seq_id_unq.size();
204
+ seq_id_unq.push_back(s);
205
+ }
206
+ }
207
+ }
208
+
209
+ if (debug > 0) {
210
+ LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
211
+
212
+ llama_ubatch ubatch {
213
+ /*.b_equal_seqs =*/ false,
214
+ /*.n_tokens =*/ (uint32_t) batch.n_tokens,
215
+ /*.n_seq_tokens =*/ (uint32_t) 1,
216
+ /*.n_seqs =*/ (uint32_t) batch.n_tokens,
217
+ /*.n_seqs_unq =*/ (uint32_t) this->seq_id_unq.size(),
218
+ /*.token =*/ batch.token,
219
+ /*.embd =*/ batch.embd,
220
+ /*.pos =*/ batch.pos,
221
+ /*.n_seq_id =*/ batch.n_seq_id,
222
+ /*.seq_id =*/ batch.seq_id,
223
+ /*.seq_id_unq =*/ this->seq_id_unq.data(),
224
+ /*.seq_idx =*/ this->seq_idx.data(),
225
+ /*.output =*/ batch.logits,
226
+ /*.data =*/ {},
227
+ };
228
+
229
+ ubatch_print(ubatch, debug);
230
+
231
+ LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
232
+ for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
233
+ if (seq_pos[s0].empty()) {
234
+ continue;
235
+ }
236
+
237
+ std::stringstream ss;
238
+ for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
239
+ if (seq_cpl[s0][s1]) {
240
+ ss << s1 << " ";
241
+ }
242
+ }
243
+
244
+ LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
245
+ __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
246
+ }
247
+ LLAMA_LOG_DEBUG("%s: ]\n", __func__);
248
+ }
249
+
250
+ //
251
+ // consistency checks
252
+ //
253
+
254
+ for (uint32_t s = 0; s < n_seq_max; ++s) {
255
+ if (seq_pos[s].empty()) {
256
+ continue;
257
+ }
258
+
259
+ const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
260
+
261
+ if (p0 >= 0) {
262
+ bool ok = true;
263
+
264
+ if (batch.token) {
265
+ if (seq_pos_min(s) != p0 + 1) {
266
+ ok = false;
267
+ }
268
+ } else {
269
+ assert(batch.embd);
270
+
271
+ // for embeddings (typically used as vision input), we allow them to have repeating positions
272
+ // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
273
+ if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
274
+ ok = false;
275
+ }
276
+ }
277
+
278
+ if (!ok) {
279
+ LLAMA_LOG_ERROR(
280
+ "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
281
+ " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
282
+ " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
283
+ " it is required that the sequence positions remain consecutive: Y = X + 1\n",
284
+ __func__, s, s, p0, s, seq_pos_min(s));
285
+
286
+ return false;
287
+ }
288
+ }
289
+
290
+ if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
291
+ LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
292
+ return false;
293
+ }
294
+ }
295
+
296
+ if (memory) {
297
+ for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
298
+ for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
299
+ if (seq_cpl[s0][s1]) {
300
+ if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
301
+ memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
302
+ LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1);
303
+ return false;
304
+ }
305
+ }
306
+ }
307
+ }
308
+ }
309
+
310
+ // disallow partial sequence sub-sets:
311
+ //
312
+ // invalid: x
313
+ // i: 0 1 2 ...
314
+ // ---------------------------------------
315
+ // seq_id[i][0]: 0 0 1
316
+ // seq_id[i][1]: 1 1 2
317
+ // seq_id[i][2]: 2
318
+ //
319
+ // disallow decreasing sequence positions:
320
+ //
321
+ // invalid: x
322
+ // i: 0 1 2 3 4 5 6 ...
323
+ // ---------------------------------------
324
+ // pos[i]: 4 5 0 1 6 2 3
325
+ // seq_id[i][0]: 0 0 1 1 0 1 0
326
+ //
327
+ {
328
+ seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
329
+ for (uint32_t s = 0; s < n_seq_max; ++s) {
330
+ cur_seq_set[s].set();
331
+ }
332
+
333
+ llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
334
+ for (uint32_t s = 0; s < n_seq_max; ++s) {
335
+ cur_seq_pos[s] = -1;
336
+ }
337
+
338
+ for (int32_t i = 0; i < batch.n_tokens; ++i) {
339
+ const llama_pos pos = batch.pos[i];
340
+
341
+ for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
342
+ const llama_seq_id seq_id = batch.seq_id[i][s];
343
+
344
+ cur_seq_set[seq_id] &= seq_set[i];
345
+
346
+ if (cur_seq_set[seq_id].none()) {
347
+ LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets (not allowed)\n", __func__, seq_id);
348
+ return false;
349
+ }
350
+
351
+ if (pos < cur_seq_pos[seq_id]) {
352
+ LLAMA_LOG_ERROR("%s: sequence %d positions are decreasing (not allowed)\n", __func__, seq_id);
353
+ return false;
354
+ }
355
+ }
356
+ }
357
+ }
358
+
359
+ split_reset();
360
+
361
+ return true;
362
+ }
363
+
364
// Build a placeholder ubatch of n_seq_tokens*n_seqs tokens with default-valued
// buffers, used to reserve compute-graph memory. The returned llama_ubatch
// shares ownership of its buffers via the `data` shared_ptr.
llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs) {
    const uint32_t n_tokens = n_seq_tokens*n_seqs;

    clear();
    split_reset();

    auto udata = std::make_shared<llama_ubatch::data_t>();

    udata->token .resize(n_tokens);
    udata->embd .clear();
    udata->pos .resize(n_tokens);
    udata->n_seq_id .resize(n_tokens);
    udata->seq_id .resize(n_tokens);
    udata->seq_id_unq.resize(0);
    udata->seq_idx .resize(LLAMA_MAX_SEQ, -1);
    udata->output .resize(n_tokens);

    // sequences 0..n_seqs-1 are all considered present
    for (uint32_t s = 0; s < n_seqs; ++s) {
        udata->seq_idx[s] = s;
        udata->seq_id_unq.push_back(s);
    }

    // note: all .data() pointers are taken before udata is moved into `res`
    llama_ubatch res {
        /*.b_equal_seqs =*/ true,
        /*.n_tokens =*/ n_tokens,
        /*.n_seq_tokens =*/ n_seq_tokens,
        /*.n_seqs =*/ n_seqs,
        /*.n_seqs_unq =*/ n_seqs,

        /*.token =*/ udata->token.data(),
        /*.embd =*/ nullptr,
        /*.pos =*/ udata->pos.data(),
        /*.n_seq_id =*/ udata->n_seq_id.data(),
        /*.seq_id =*/ udata->seq_id.data(),
        /*.seq_id_unq =*/ udata->seq_id_unq.data(),
        /*.seq_idx =*/ udata->seq_idx.data(),
        /*.output =*/ udata->output.data(),
        /*.data =*/ std::move(udata),
    };

    return res;
}
406
+
407
// Access the normalized input batch (valid after a successful init()).
const llama_batch & llama_batch_allocr::get_batch() const {
    return batch;
}
410
+
411
// Number of tokens in the current batch.
uint32_t llama_batch_allocr::get_n_tokens() const {
    return batch.n_tokens;
}
414
+
415
// Number of tokens marked as outputs (counted during init()).
uint32_t llama_batch_allocr::get_n_outputs() const {
    return n_outputs;
}
418
+
419
// Number of tokens already consumed by the split_* methods since the last split_reset().
uint32_t llama_batch_allocr::get_n_used() const {
    return n_used;
}
422
+
423
// Batch indices of the output tokens, in the order they were added to ubatches.
std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
    return out_ids;
}
426
+
427
// Smallest position of seq_id in the batch, or -1 if the sequence has no tokens here.
llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
}
430
+
431
// Largest position of seq_id in the batch, or -1 if the sequence has no tokens here.
llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
}
434
+
435
+ void llama_batch_allocr::split_reset() {
436
+ out_ids.clear();
437
+
438
+ n_used = 0;
439
+
440
+ used.clear();
441
+ used.resize(get_n_tokens(), false);
442
+ }
443
+
444
+ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
445
+ // find the first unused token
446
+ uint32_t cur_idx = 0;
447
+ while (cur_idx < used.size() && used[cur_idx]) {
448
+ ++cur_idx;
449
+ }
450
+
451
+ // we are done
452
+ if (cur_idx >= used.size()) {
453
+ return {};
454
+ }
455
+
456
+ std::vector<int32_t> idxs;
457
+
458
+ while (true) {
459
+ idxs.push_back(cur_idx);
460
+
461
+ used[cur_idx] = true;
462
+ ++n_used;
463
+
464
+ ++cur_idx;
465
+
466
+ if (cur_idx >= used.size()) {
467
+ break;
468
+ }
469
+
470
+ if (idxs.size() >= n_ubatch) {
471
+ break;
472
+ }
473
+ }
474
+
475
+ return ubatch_add(idxs, idxs.size(), false);
476
+ }
477
+
478
// Produce the next ubatch where every participating sequence set contributes
// the same number of tokens. With `sequential`, only sequences with strictly
// increasing ids may participate (unsupported when sequences are coupled).
// Returns an empty ubatch when nothing is left, or on error.
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
    if (sequential && has_cpl) {
        LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch (you may need to use the -kvu flag)\n", __func__);

        return {};
    }

    std::vector<seq_set_t> cur_seq_set;

    llama_seq_id last_seq_id = -1;

    // determine the non-overlapping sequence sets participating in this ubatch
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        if (used[i]) {
            continue;
        }

        bool add = true;

        for (uint32_t s = 0; s < cur_seq_set.size(); ++s) {
            // no overlap with existing sequence sets:
            if (!(cur_seq_set[s] & seq_set[i]).none()) {
                add = false;
                break;
            }
        }

        // accept only increasing sequence ids
        if (sequential) {
            add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
        }

        if (add) {
            cur_seq_set.push_back(seq_set[i]);

            last_seq_id = batch.seq_id[i][0];

            if (cur_seq_set.size() > n_ubatch) {
                break;
            }
        }
    }

    const uint32_t n_seqs = cur_seq_set.size();

    // we are done
    if (n_seqs == 0) {
        return {};
    }

    // the current batch index of each sequence set
    std::vector<int32_t> cur_idx(n_seqs, 0);

    // fast-forward each sequence set past its already-used tokens
    for (uint32_t s = 0; s < n_seqs; ++s) {
        while (used[seq_set_map[cur_seq_set[s]][cur_idx[s]]]) {
            ++cur_idx[s];
        }
    }

    // the list of batch indices for each sequence set
    // at the end we will concat these to get the final ubatch
    std::vector<idx_vec_t> idxs_per_seq(n_seqs);

    while (true) {
        // we can only add new n_seq_tokens tokens if all the sequence sets have at least one more unused token and
        // if we haven't reached n_ubatch
        bool can_expand = true;

        for (uint32_t s = 0; s < n_seqs; ++s) {
            if (cur_idx[s] >= (int32_t) seq_set_map[cur_seq_set[s]].size()) {
                can_expand = false;
                break;
            }
        }

        if (!can_expand) {
            break;
        }

        // take one more token from each participating sequence set
        for (uint32_t s = 0; s < n_seqs; ++s) {
            const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];

            idxs_per_seq[s].push_back(idx);

            used[idx] = true;
            ++n_used;

            ++cur_idx[s];
        }

        // stop before the next round would exceed the ubatch size
        if ((idxs_per_seq[0].size() + 1)*n_seqs > n_ubatch) {
            break;
        }
    }

    // concat the per-sequence-set lists
    std::vector<int32_t> idxs;

    for (uint32_t s = 0; s < n_seqs; ++s) {
        idxs.insert(idxs.end(), idxs_per_seq[s].begin(), idxs_per_seq[s].end());
    }

    return ubatch_add(idxs, n_seqs, true);
}
582
+
583
// Produce the next ubatch of up to n_ubatch tokens drawn from a single
// "sequence-set chain": each subsequent token's sequence set must be a
// subset of the current one. Returns an empty ubatch when done.
llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
    // find the first unused token
    uint32_t cur_idx = 0;
    while (cur_idx < used.size() && used[cur_idx]) {
        ++cur_idx;
    }

    // we are done
    if (cur_idx >= used.size()) {
        return {};
    }

    // this is the starting sequence set
    // we allow adding tokens only if their sequence set is a subset of the current sequence set
    auto cur_seq_set = seq_set[cur_idx];

    std::vector<int32_t> idxs;

    while (true) {
        idxs.push_back(cur_idx);

        used[cur_idx] = true;
        ++n_used;

        if (idxs.size() >= n_ubatch) {
            break;
        }

        // advance to the next unused token whose sequence set is a subset of ours
        do {
            ++cur_idx;
        } while (cur_idx < get_n_tokens() && (used[cur_idx] || ((cur_seq_set & seq_set[cur_idx]) != seq_set[cur_idx])));

        if (cur_idx == get_n_tokens()) {
            break;
        }

        // narrow the working set to the newly accepted token's sequence set
        cur_seq_set = seq_set[cur_idx];
    }

    return ubatch_add(idxs, 1, true);
}
624
+
625
+ void llama_batch_allocr::clear() {
626
+ n_outputs = 0;
627
+
628
+ batch = {};
629
+
630
+ pos .clear();
631
+ n_seq_id .clear();
632
+ seq_id .clear();
633
+ seq_id_unq.clear();
634
+ output .clear();
635
+
636
+ for (auto & cur : seq_pos) {
637
+ cur.clear();
638
+ }
639
+
640
+ for (auto & cur : seq_cpl) {
641
+ std::fill(cur.begin(), cur.end(), false);
642
+ }
643
+
644
+ seq_set.clear();
645
+
646
+ seq_set_map.clear();
647
+
648
+ std::fill(seq_idx.begin(), seq_idx.end(), -1);
649
+ }
650
+
651
// Assemble a ubatch from the given batch indices. `n_seqs` is the number of
// (equal-length) sequence groups; `equal_seqs` records whether the groups are
// guaranteed equal-sized. Token/embd/pos data are copied; seq_id pointers
// alias the input batch, so the batch must outlive the returned ubatch.
llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {
    const uint32_t n_tokens = idxs.size();

    assert(n_tokens%n_seqs == 0);

    auto udata = std::make_shared<llama_ubatch::data_t>();

    // embeddings may carry multiple positions per token (e.g. M-RoPE)
    const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;

    const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
    const int64_t n_pos_all = (int64_t) n_tokens*n_pos_cur;

    udata->token .resize(n_tokens);
    udata->embd .resize(n_embd_all);
    udata->pos .resize(n_pos_all);
    udata->n_seq_id .resize(n_tokens);
    udata->seq_id .resize(n_tokens);
    udata->seq_id_unq.resize(0);
    udata->seq_idx .resize(LLAMA_MAX_SEQ, -1);
    udata->output .resize(n_tokens);

    seq_set_t seq_set_unq;

    for (size_t i = 0; i < idxs.size(); ++i) {
        if (batch.token) {
            udata->token[i] = batch.token[idxs[i]];
        }

        if (batch.embd) {
            memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
        }

        // positions are stored planar: one plane of n_tokens per position dimension
        for (int j = 0; j < n_pos_cur; ++j) {
            udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
        }

        udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
        udata->seq_id[i] = batch.seq_id[idxs[i]];
        udata->output[i] = batch.logits[idxs[i]];

        for (int s = 0; s < udata->n_seq_id[i]; ++s) {
            seq_set_unq.set(udata->seq_id[i][s]);
        }

        if (udata->output[i]) {
            out_ids.push_back(idxs[i]);
        }
    }

    // build the unique-sequence list and the seq_id -> index mapping
    for (uint32_t s = 0; s < n_seq_max; ++s) {
        if (seq_set_unq.test(s)) {
            udata->seq_idx[s] = udata->seq_id_unq.size();
            udata->seq_id_unq.push_back(s);
        }
    }

    // note: all .data() pointers are taken before udata is moved into `res`
    llama_ubatch res {
        /*.b_equal_seqs =*/ equal_seqs,
        /*.n_tokens =*/ n_tokens,
        /*.n_seq_tokens =*/ n_tokens/n_seqs,
        /*.n_seqs =*/ n_seqs,
        /*.n_seqs_unq =*/ (uint32_t) udata->seq_id_unq.size(),

        /*.token =*/ batch.token ? udata->token.data() : nullptr,
        /*.embd =*/ batch.embd ? udata->embd.data() : nullptr,
        /*.pos =*/ udata->pos.data(),
        /*.n_seq_id =*/ udata->n_seq_id.data(),
        /*.seq_id =*/ udata->seq_id.data(),
        /*.seq_id_unq =*/ udata->seq_id_unq.data(),
        /*.seq_idx =*/ udata->seq_idx.data(),
        /*.output =*/ udata->output.data(),
        /*.data =*/ std::move(udata),
    };

    if (debug > 0) {
        LLAMA_LOG_DEBUG("%s: added ubatch to split:\n", __func__);

        ubatch_print(res, debug);
    }

    return res;
}
733
+
734
+ void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
735
+ if (debug > 0) {
736
+ LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs());
737
+ LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, ubatch.n_tokens);
738
+ LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
739
+ LLAMA_LOG_DEBUG("%s: n_seqs = %d\n", __func__, ubatch.n_seqs);
740
+ LLAMA_LOG_DEBUG("%s: n_seqs_unq = %d\n", __func__, ubatch.n_seqs_unq);
741
+
742
+ std::stringstream ss_seq_id_unq;
743
+ std::stringstream ss_seq_idx;
744
+
745
+ ss_seq_id_unq << "[ ";
746
+ ss_seq_idx << "[";
747
+
748
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
749
+ ss_seq_id_unq << ubatch.seq_id_unq[s] << " ";
750
+ }
751
+
752
+ for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
753
+ if (ubatch.seq_idx[s] >= 0) {
754
+ ss_seq_idx << ubatch.seq_idx[s]%10;
755
+ } else {
756
+ ss_seq_idx << ".";
757
+ }
758
+ }
759
+
760
+ ss_seq_id_unq << "]";
761
+ ss_seq_idx << "]";
762
+
763
+ LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) ubatch.token);
764
+ LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) ubatch.embd);
765
+ LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) ubatch.pos);
766
+ LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) ubatch.n_seq_id);
767
+ LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) ubatch.seq_id);
768
+ LLAMA_LOG_DEBUG("%s: seq_id_unq = %s\n", __func__, ss_seq_id_unq.str().c_str());
769
+ LLAMA_LOG_DEBUG("%s: seq_idx = %s\n", __func__, ss_seq_idx.str().c_str());
770
+ LLAMA_LOG_DEBUG("%s: output = %p\n", __func__, (void *) ubatch.output);
771
+ LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
772
+
773
+ if (debug > 1) {
774
+ int seq_id_max = 0;
775
+ for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
776
+ for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
777
+ for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
778
+ seq_id_max = std::max(seq_id_max, ubatch.seq_id[i][s]);
779
+ }
780
+ }
781
+ }
782
+ ++seq_id_max;
783
+
784
+ LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
785
+ for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
786
+ std::vector<int8_t> seq_id(seq_id_max);
787
+
788
+ for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
789
+ seq_id[ubatch.seq_id[i][s]] = 1;
790
+ }
791
+
792
+ std::stringstream ss;
793
+ for (int s = 0; s < seq_id_max; ++s) {
794
+ if (seq_id[s]) {
795
+ ss << s%10;
796
+ } else {
797
+ ss << ".";
798
+ }
799
+ }
800
+
801
+ if (ubatch.token) {
802
+ LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
803
+ __func__, i, ubatch.token[i], vocab->token_to_piece(ubatch.token[i]).c_str(),
804
+ ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
805
+ } else {
806
+ LLAMA_LOG_DEBUG("%s: %4d: [embd], pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
807
+ __func__, i, ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
808
+ }
809
+ }
810
+ LLAMA_LOG_DEBUG("%s: ]\n", __func__);
811
+ }
812
+ }
813
+ }
814
+
815
+ //
816
+ // interface implementation
817
+ //
818
+
819
// Convenience constructor: wrap an existing token array in a llama_batch.
// All optional fields are left null and will be auto-generated by
// llama_batch_allocr::init(). The caller retains ownership of `tokens`.
struct llama_batch llama_batch_get_one(
        llama_token * tokens,
        int32_t n_tokens) {
    return {
        /*n_tokens =*/ n_tokens,
        /*tokens =*/ tokens,
        /*embd =*/ nullptr,
        /*pos =*/ nullptr,
        /*n_seq_id =*/ nullptr,
        /*seq_id =*/ nullptr,
        /*logits =*/ nullptr,
    };
}
832
+
833
// Allocate a llama_batch with room for n_tokens_alloc tokens.
// If embd != 0, an embedding buffer of n_tokens_alloc*embd floats is
// allocated instead of a token buffer. Each token gets room for up to
// n_seq_max sequence ids. Must be freed with llama_batch_free().
// NOTE(review): malloc results are not checked; an OOM here leads to a
// null-pointer dereference in the caller — confirm this is the intended policy.
struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
    llama_batch batch = {
        /*n_tokens =*/ 0,
        /*tokens =*/ nullptr,
        /*embd =*/ nullptr,
        /*pos =*/ nullptr,
        /*n_seq_id =*/ nullptr,
        /*seq_id =*/ nullptr,
        /*logits =*/ nullptr,
    };

    // token ids and embeddings are mutually exclusive
    if (embd) {
        batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
    } else {
        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
    }

    batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
    batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
    // the +1 slot holds a null sentinel used by llama_batch_free() to find the end
    batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
    for (int i = 0; i < n_tokens_alloc; ++i) {
        batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
    }
    batch.seq_id[n_tokens_alloc] = nullptr;

    batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);

    return batch;
}
862
+
863
+ void llama_batch_free(struct llama_batch batch) {
864
+ if (batch.token) free(batch.token);
865
+ if (batch.embd) free(batch.embd);
866
+ if (batch.pos) free(batch.pos);
867
+ if (batch.n_seq_id) free(batch.n_seq_id);
868
+ if (batch.seq_id) {
869
+ for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
870
+ free(batch.seq_id[i]);
871
+ }
872
+ free(batch.seq_id);
873
+ }
874
+ if (batch.logits) free(batch.logits);
875
+ }