@fugood/llama.node 1.0.3 → 1.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +14 -14
- package/src/llama.cpp/common/CMakeLists.txt +4 -5
- package/src/llama.cpp/common/arg.cpp +37 -0
- package/src/llama.cpp/common/common.cpp +22 -6
- package/src/llama.cpp/common/common.h +14 -1
- package/src/llama.cpp/ggml/CMakeLists.txt +3 -0
- package/src/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
- package/src/llama.cpp/ggml/include/ggml.h +13 -0
- package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +23 -8
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +39 -0
- package/src/llama.cpp/include/llama.h +13 -48
- package/src/llama.cpp/src/llama-arch.cpp +222 -15
- package/src/llama.cpp/src/llama-arch.h +16 -1
- package/src/llama.cpp/src/llama-batch.cpp +76 -70
- package/src/llama.cpp/src/llama-batch.h +24 -18
- package/src/llama.cpp/src/llama-chat.cpp +44 -1
- package/src/llama.cpp/src/llama-chat.h +2 -0
- package/src/llama.cpp/src/llama-context.cpp +134 -95
- package/src/llama.cpp/src/llama-context.h +13 -16
- package/src/llama.cpp/src/llama-cparams.h +3 -2
- package/src/llama.cpp/src/llama-graph.cpp +239 -154
- package/src/llama.cpp/src/llama-graph.h +162 -126
- package/src/llama.cpp/src/llama-hparams.cpp +45 -0
- package/src/llama.cpp/src/llama-hparams.h +11 -1
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +11 -5
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +3 -0
- package/src/llama.cpp/src/llama-kv-cache-unified.cpp +698 -302
- package/src/llama.cpp/src/llama-kv-cache-unified.h +89 -31
- package/src/llama.cpp/src/llama-memory-hybrid.cpp +1 -0
- package/src/llama.cpp/src/llama-memory-recurrent.cpp +6 -9
- package/src/llama.cpp/src/llama-model.cpp +2309 -665
- package/src/llama.cpp/src/llama-model.h +18 -4
- package/src/llama.cpp/src/llama-quant.cpp +2 -2
- package/src/llama.cpp/src/llama-vocab.cpp +368 -9
- package/src/llama.cpp/src/llama-vocab.h +43 -0
- package/src/llama.cpp/src/unicode.cpp +207 -0
- package/src/llama.cpp/src/unicode.h +2 -0

--- package/src/llama.cpp/src/llama-batch.cpp (1.0.3)
+++ package/src/llama.cpp/src/llama-batch.cpp (1.0.4)
@@ -27,6 +27,7 @@ bool llama_batch_allocr::init(
         const llama_vocab & vocab,
         const llama_memory_i * memory,
         uint32_t n_embd,
+        uint32_t n_seq_max,
         bool output_all) {
     clear();

@@ -40,6 +41,11 @@ bool llama_batch_allocr::init(
     // validate input batch
     //

+    if (n_seq_max > LLAMA_MAX_SEQ) {
+        LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
+        return false;
+    }
+
     if (batch.token) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
@@ -52,8 +58,8 @@ bool llama_batch_allocr::init(
     if (batch.seq_id) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >=
-                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s],
+                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
                     return false;
                 }
             }
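
The hunks above thread a per-context sequence cap (`n_seq_max`) through `llama_batch_allocr::init` and validate every `seq_id` against it, rather than only against the compile-time `LLAMA_MAX_SEQ`. A minimal standalone sketch of the same bounds check follows; `toy_batch`, `validate_seq_ids`, and the cap values are illustrative stand-ins, not part of the library:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the per-token sequence ids carried by a batch.
struct toy_batch {
    std::vector<std::vector<int32_t>> seq_ids; // seq_ids[i] = sequences token i belongs to
};

// Mirrors the shape of the added validation: n_seq_max may not exceed the
// compile-time cap, and every seq_id must lie in [0, n_seq_max).
static bool validate_seq_ids(const toy_batch & batch, uint32_t n_seq_max, uint32_t hard_cap) {
    if (n_seq_max > hard_cap) {
        std::fprintf(stderr, "n_seq_max = %u > %u\n", n_seq_max, hard_cap);
        return false;
    }
    for (size_t i = 0; i < batch.seq_ids.size(); ++i) {
        for (int32_t s : batch.seq_ids[i]) {
            if (s < 0 || (uint32_t) s >= n_seq_max) {
                std::fprintf(stderr, "invalid seq_id[%zu] = %d >= %u\n", i, s, n_seq_max);
                return false;
            }
        }
    }
    return true;
}

int main() {
    const toy_batch batch { {{0}, {0, 1}, {3}} };
    // prints "valid: 0" -- seq_id 3 is out of range for n_seq_max = 2
    std::printf("valid: %d\n", validate_seq_ids(batch, /*n_seq_max=*/ 2, /*hard_cap=*/ 64));
}
```
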
@@ -86,7 +92,7 @@ bool llama_batch_allocr::init(

     // initialize the starting position for each sequence based on the positions in the memory
     llama_pos p0[LLAMA_MAX_SEQ];
-    for (
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (!memory) {
             // if no memory -> start from 0
             p0[s] = 0;
@@ -143,13 +149,16 @@ bool llama_batch_allocr::init(
     // compute stats
     //

-    this->n_embd
+    this->n_embd = n_embd;
+    this->n_seq_max = n_seq_max;

     // count the outputs in this batch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         n_outputs += batch.logits[i] != 0;
     }

+    has_cpl = false;
+
     // determine coupled sequences
     // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -189,7 +198,7 @@ bool llama_batch_allocr::init(
         seq_set_map[cur].push_back(i);
     }

-    for (
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_set_unq.test(s)) {
             seq_idx[s] = seq_id_unq.size();
             seq_id_unq.push_back(s);
@@ -201,7 +210,7 @@ bool llama_batch_allocr::init(
     LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);

     llama_ubatch ubatch {
-        /*.
+        /*.b_equal_seqs =*/ false,
         /*.n_tokens =*/ (uint32_t) batch.n_tokens,
         /*.n_seq_tokens =*/ (uint32_t) 1,
         /*.n_seqs =*/ (uint32_t) batch.n_tokens,
@@ -214,6 +223,7 @@ bool llama_batch_allocr::init(
         /*.seq_id_unq =*/ this->seq_id_unq.data(),
         /*.seq_idx =*/ this->seq_idx.data(),
         /*.output =*/ batch.logits,
+        /*.data =*/ {},
     };

     ubatch_print(ubatch, debug);
@@ -241,7 +251,7 @@ bool llama_batch_allocr::init(
     // consistency checks
     //

-    for (
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_pos[s].empty()) {
             continue;
         }
@@ -284,8 +294,8 @@ bool llama_batch_allocr::init(
     }

     if (memory) {
-        for (
-            for (
+        for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
+            for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
                 if (seq_cpl[s0][s1]) {
                     if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
                         memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
@@ -316,12 +326,12 @@ bool llama_batch_allocr::init(
     //
     {
         seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
-        for (
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
             cur_seq_set[s].set();
         }

         llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
-        for (
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
             cur_seq_pos[s] = -1;
         }

@@ -357,39 +367,38 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
     clear();
     split_reset();

-
+    auto udata = std::make_shared<llama_ubatch::data_t>();

-
-
-
-
-
-
-
-
-    ubatch.seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    ubatch.output    .resize(n_tokens);
+    udata->token     .resize(n_tokens);
+    udata->embd      .clear();
+    udata->pos       .resize(n_tokens);
+    udata->n_seq_id  .resize(n_tokens);
+    udata->seq_id    .resize(n_tokens);
+    udata->seq_id_unq.resize(0);
+    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
+    udata->output    .resize(n_tokens);

     for (uint32_t s = 0; s < n_seqs; ++s) {
-
-
+        udata->seq_idx[s] = s;
+        udata->seq_id_unq.push_back(s);
     }

     llama_ubatch res {
-        /*.
+        /*.b_equal_seqs =*/ true,
         /*.n_tokens =*/ n_tokens,
         /*.n_seq_tokens =*/ n_seq_tokens,
         /*.n_seqs =*/ n_seqs,
         /*.n_seqs_unq =*/ n_seqs,

-        /*.token =*/
+        /*.token =*/ udata->token.data(),
         /*.embd =*/ nullptr,
-        /*.pos =*/
-        /*.n_seq_id =*/
-        /*.seq_id =*/
-        /*.seq_id_unq =*/
-        /*.seq_idx =*/
-        /*.output =*/
+        /*.pos =*/ udata->pos.data(),
+        /*.n_seq_id =*/ udata->n_seq_id.data(),
+        /*.seq_id =*/ udata->seq_id.data(),
+        /*.seq_id_unq =*/ udata->seq_id_unq.data(),
+        /*.seq_idx =*/ udata->seq_idx.data(),
+        /*.output =*/ udata->output.data(),
+        /*.data =*/ std::move(udata),
     };

     return res;
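
The `ubatch_reserve` rewrite above reflects the broader change in this file: instead of the allocator holding the backing storage that returned views point into, each `llama_ubatch` now carries a `std::shared_ptr<llama_ubatch::data_t>` that owns its buffers. A small sketch of that pointer-plus-keep-alive pattern, using made-up `toy_view`/`make_view` names rather than the real types:

```cpp
#include <cstdint>
#include <memory>
#include <vector>

// The view exposes raw pointers for hot-path consumers; `data` keeps the
// owning buffers alive for as long as any copy of the view exists.
struct toy_view {
    uint32_t        n_tokens;
    const int32_t * token; // points into data->token

    struct data_t {
        std::vector<int32_t> token;
    };
    std::shared_ptr<data_t> data; // null for views that merely borrow external memory
};

static toy_view make_view(uint32_t n_tokens) {
    auto udata = std::make_shared<toy_view::data_t>();
    udata->token.resize(n_tokens, 0);

    // braced-init elements are evaluated left to right, so the raw pointer is
    // taken before the shared_ptr is moved into the view
    toy_view res {
        /*.n_tokens =*/ n_tokens,
        /*.token    =*/ udata->token.data(),
        /*.data     =*/ std::move(udata),
    };
    return res; // stays valid even after the producer is reset or destroyed
}

int main() {
    toy_view v = make_view(8);
    return v.token[0]; // safe: v.data still owns the buffer
}
```

This is presumably also why `split_reset` can drop the old `ubatches` vector in the next hunk: the allocator no longer needs to outlive the views it hands out.
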
@@ -430,8 +439,6 @@ void llama_batch_allocr::split_reset() {

     used.clear();
     used.resize(get_n_tokens(), false);
-
-    ubatches.clear();
 }

 llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
@@ -646,78 +653,77 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u

     assert(n_tokens%n_seqs == 0);

-
-
-    auto & ubatch = ubatches.back();
+    auto udata = std::make_shared<llama_ubatch::data_t>();

     const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;

     const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
     const int64_t n_pos_all = (int64_t) n_tokens*n_pos_cur;

-
-
-
-
-
-
-
-
+    udata->token     .resize(n_tokens);
+    udata->embd      .resize(n_embd_all);
+    udata->pos       .resize(n_pos_all);
+    udata->n_seq_id  .resize(n_tokens);
+    udata->seq_id    .resize(n_tokens);
+    udata->seq_id_unq.resize(0);
+    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
+    udata->output    .resize(n_tokens);

     seq_set_t seq_set_unq;

     for (size_t i = 0; i < idxs.size(); ++i) {
         if (batch.token) {
-
+            udata->token[i] = batch.token[idxs[i]];
         }

         if (batch.embd) {
-            memcpy(
+            memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
         }

         for (int j = 0; j < n_pos_cur; ++j) {
-
+            udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
         }

-
-
-
+        udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
+        udata->seq_id[i] = batch.seq_id[idxs[i]];
+        udata->output[i] = batch.logits[idxs[i]];

-        for (int s = 0; s <
-            seq_set_unq.set(
+        for (int s = 0; s < udata->n_seq_id[i]; ++s) {
+            seq_set_unq.set(udata->seq_id[i][s]);
         }

-        if (
+        if (udata->output[i]) {
             out_ids.push_back(idxs[i]);
         }
     }

-    for (
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_set_unq.test(s)) {
-
-
+            udata->seq_idx[s] = udata->seq_id_unq.size();
+            udata->seq_id_unq.push_back(s);
         }
     }

     llama_ubatch res {
-        /*.
+        /*.b_equal_seqs =*/ equal_seqs,
         /*.n_tokens =*/ n_tokens,
         /*.n_seq_tokens =*/ n_tokens/n_seqs,
         /*.n_seqs =*/ n_seqs,
-        /*.n_seqs_unq =*/ (uint32_t)
-
-        /*.token =*/ batch.token ?
-        /*.embd =*/ batch.embd ?
-        /*.pos =*/
-        /*.n_seq_id =*/
-        /*.seq_id =*/
-        /*.seq_id_unq =*/
-        /*.seq_idx =*/
-        /*.output =*/
+        /*.n_seqs_unq =*/ (uint32_t) udata->seq_id_unq.size(),
+
+        /*.token =*/ batch.token ? udata->token.data() : nullptr,
+        /*.embd =*/ batch.embd ? udata->embd.data() : nullptr,
+        /*.pos =*/ udata->pos.data(),
+        /*.n_seq_id =*/ udata->n_seq_id.data(),
+        /*.seq_id =*/ udata->seq_id.data(),
+        /*.seq_id_unq =*/ udata->seq_id_unq.data(),
+        /*.seq_idx =*/ udata->seq_idx.data(),
+        /*.output =*/ udata->output.data(),
+        /*.data =*/ std::move(udata),
     };

     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: added ubatch
+        LLAMA_LOG_DEBUG("%s: added ubatch to split:\n", __func__);

         ubatch_print(res, debug);
     }
@@ -727,7 +733,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u

 void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs);
+        LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs());
         LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, ubatch.n_tokens);
         LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
         LLAMA_LOG_DEBUG("%s: n_seqs = %d\n", __func__, ubatch.n_seqs);

--- package/src/llama.cpp/src/llama-batch.h (1.0.3)
+++ package/src/llama.cpp/src/llama-batch.h (1.0.4)
@@ -8,12 +8,17 @@
 #include <vector>
 #include <set>
 #include <bitset>
+#include <memory>
 #include <unordered_map>

 // keep this struct lightweight
-// it points to data in `llama_batch_allocr`
 struct llama_ubatch {
-    bool equal_seqs
+    bool equal_seqs() const {
+        return b_equal_seqs != 0;
+    }
+
+    uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
+                           //       otherwise address sanitizer complains
     // TODO: whole_seqs for embeddings?

     uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
@@ -34,6 +39,20 @@ struct llama_ubatch {
     llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
     int32_t * seq_idx; // [LLAMA_MAX_SEQ] | - | seq_idx
     int8_t * output; // [n_tokens] | i | -
+
+    struct data_t {
+        std::vector<llama_token> token;
+        std::vector<float> embd;
+        std::vector<llama_pos> pos;
+        std::vector<int32_t> n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<llama_seq_id> seq_id_unq;
+        std::vector<int32_t> seq_idx;
+        std::vector<int8_t> output;
+    };
+
+    // the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data
+    std::shared_ptr<data_t> data;
 };

 // a helper for sanitizing, fulfilling and splitting a batch
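
The `equal_seqs` flag also changes representation here: it is stored as a `uint32_t` member (`b_equal_seqs`, per the comment kept that way so the aggregate's layout does not trip the address sanitizer) and read through a small accessor, which is why call sites such as `ubatch_print` switch from `ubatch.equal_seqs` to `ubatch.equal_seqs()`. A tiny sketch of the same idiom with a hypothetical `flags_view` type:

```cpp
#include <cstdint>

// Store the flag as a fixed-width integer member of the aggregate and expose a
// bool-returning accessor, mirroring b_equal_seqs / equal_seqs() above.
struct flags_view {
    bool equal_seqs() const { return b_equal_seqs != 0; }

    uint32_t b_equal_seqs; // 0 or 1
    uint32_t n_tokens;
};

int main() {
    const flags_view v { /*.b_equal_seqs =*/ 1, /*.n_tokens =*/ 32 };
    return v.equal_seqs() ? 0 : 1;
}
```
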
@@ -48,6 +67,7 @@ public:
             const llama_vocab & vocab,
             const llama_memory_i * memory,
             uint32_t n_embd,
+            uint32_t n_seq_max,
             bool output_all);

     const llama_batch & get_batch() const;
@@ -100,6 +120,7 @@ private:
     const uint32_t n_pos_per_embd;

     uint32_t n_embd;
+    uint32_t n_seq_max;
     uint32_t n_outputs;

     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
@@ -115,7 +136,7 @@ private:
     using seq_cpl_t = std::vector<bool>;

     // helper flag to quickly determine if there are any coupled sequences in the batch
-    bool has_cpl;
+    bool has_cpl = false;

     std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
     std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
@@ -135,20 +156,5 @@ private:
     // used[i] indicates if token i has already been used in a previous ubatch
     std::vector<bool> used;

-    // llama_ubatch points to this data:
-    struct ubatch {
-        std::vector<llama_token> token;
-        std::vector<float> embd;
-        std::vector<llama_pos> pos;
-        std::vector<int32_t> n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
-        std::vector<llama_seq_id> seq_id_unq;
-        std::vector<int32_t> seq_idx;
-        std::vector<int8_t> output;
-    };
-
-    // current splitting state:
-    std::vector<ubatch> ubatches;
-
     int debug;
 };

--- package/src/llama.cpp/src/llama-chat.cpp (1.0.3)
+++ package/src/llama.cpp/src/llama-chat.cpp (1.0.4)
@@ -56,6 +56,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE },
     { "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
     { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
+    { "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 },
     { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
     { "granite", LLM_CHAT_TEMPLATE_GRANITE },
     { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
@@ -65,6 +66,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
     { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
     { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
+    { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
 };

 llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -167,10 +169,13 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
     } else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) {
         return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
     } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
+        if (tmpl_contains("[|tool|]")) {
+            return LLM_CHAT_TEMPLATE_EXAONE_4;
+        }
         // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
         // EXAONE-3.0-7.8B-Instruct
         return LLM_CHAT_TEMPLATE_EXAONE_3;
-    } else if (tmpl_contains("rwkv-world")) {
+    } else if (tmpl_contains("rwkv-world") || tmpl_contains("{{- 'User: ' + message['content']|trim + '\\n\\n' -}}")) {
         return LLM_CHAT_TEMPLATE_RWKV_WORLD;
     } else if (tmpl_contains("<|start_of_role|>")) {
         return LLM_CHAT_TEMPLATE_GRANITE;
@@ -188,6 +193,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_DOTS1;
     } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
         return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
+    } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
+        return LLM_CHAT_TEMPLATE_KIMI_K2;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -529,6 +536,22 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "[|assistant|]";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_4) {
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n";
+            } else if (role == "user") {
+                ss << "[|user|]" << trim(message->content) << "\n";
+            } else if (role == "assistant") {
+                ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n";
+            } else if (role == "tool") {
+                ss << "[|tool|]" << trim(message->content) << "[|endofturn|]\n";
+            }
+        }
+        if (add_ass) {
+            ss << "[|assistant|]";
+        }
     } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
         // this template requires the model to have "\n\n" as EOT token
         for (size_t i = 0; i < chat.size(); i++) {
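
As a worked example of the EXAONE-4 branch added above: a system message "You are a helpful assistant.", a user message "Hi", and `add_ass = true` would produce the following (contents pass through `trim`, the user turn ends with a bare newline, and the trailing generation prompt has no newline after it):

```
[|system|]You are a helpful assistant.[|endofturn|]
[|user|]Hi
[|assistant|]
```
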
@@ -680,6 +703,26 @@
             ss << "<|startoftext|>" << message->content << "<|extra_0|>";
             }
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) {
+        // moonshotai/Kimi-K2-Instruct
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "<|im_system|>system<|im_middle|>";
+            } else if (role == "user") {
+                ss << "<|im_user|>user<|im_middle|>";
+            } else if (role == "assistant") {
+                ss << "<|im_assistant|>assistant<|im_middle|>";
+            } else if (role == "tool") {
+                ss << "<|im_system|>tool<|im_middle|>";
+            }
+
+            ss << message->content << "<|im_end|>";
+
+            if (add_ass) {
+                ss << "<|im_assistant|>assistant<|im_middle|>";
+            }
+        }
     } else {
         // template not supported
         return -1;
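
Similarly for the Kimi-K2 branch: a single user message "Hello" with `add_ass = true` renders as one run of special tokens with the untrimmed content in between, roughly:

```
<|im_user|>user<|im_middle|>Hello<|im_end|><|im_assistant|>assistant<|im_middle|>
```

The detection hunk keys on the `<|im_assistant|>assistant<|im_middle|>` marker, and the short name `"kimi-k2"` is registered in `LLM_CHAT_TEMPLATES`, so either the template text or that alias should resolve to this renderer.
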

--- package/src/llama.cpp/src/llama-chat.h (1.0.3)
+++ package/src/llama.cpp/src/llama-chat.h (1.0.4)
@@ -35,6 +35,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_GLMEDGE,
     LLM_CHAT_TEMPLATE_MINICPM,
     LLM_CHAT_TEMPLATE_EXAONE_3,
+    LLM_CHAT_TEMPLATE_EXAONE_4,
     LLM_CHAT_TEMPLATE_RWKV_WORLD,
     LLM_CHAT_TEMPLATE_GRANITE,
     LLM_CHAT_TEMPLATE_GIGACHAT,
@@ -45,6 +46,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_SMOLVLM,
     LLM_CHAT_TEMPLATE_DOTS1,
     LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
+    LLM_CHAT_TEMPLATE_KIMI_K2,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
