@fugood/llama.node 1.0.2 → 1.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +14 -14
- package/src/llama.cpp/CMakeLists.txt +0 -1
- package/src/llama.cpp/common/CMakeLists.txt +4 -5
- package/src/llama.cpp/common/arg.cpp +44 -0
- package/src/llama.cpp/common/common.cpp +22 -6
- package/src/llama.cpp/common/common.h +15 -1
- package/src/llama.cpp/ggml/CMakeLists.txt +10 -2
- package/src/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
- package/src/llama.cpp/ggml/include/ggml.h +104 -10
- package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +12 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +749 -163
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +5 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +12 -9
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +88 -9
- package/src/llama.cpp/include/llama.h +13 -47
- package/src/llama.cpp/src/llama-arch.cpp +298 -3
- package/src/llama.cpp/src/llama-arch.h +22 -1
- package/src/llama.cpp/src/llama-batch.cpp +103 -71
- package/src/llama.cpp/src/llama-batch.h +31 -18
- package/src/llama.cpp/src/llama-chat.cpp +59 -1
- package/src/llama.cpp/src/llama-chat.h +3 -0
- package/src/llama.cpp/src/llama-context.cpp +134 -95
- package/src/llama.cpp/src/llama-context.h +13 -16
- package/src/llama.cpp/src/llama-cparams.h +3 -2
- package/src/llama.cpp/src/llama-graph.cpp +279 -180
- package/src/llama.cpp/src/llama-graph.h +183 -122
- package/src/llama.cpp/src/llama-hparams.cpp +47 -1
- package/src/llama.cpp/src/llama-hparams.h +12 -1
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
- package/src/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
- package/src/llama.cpp/src/llama-kv-cache-unified.h +143 -47
- package/src/llama.cpp/src/llama-kv-cells.h +62 -10
- package/src/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
- package/src/llama.cpp/src/llama-memory-hybrid.h +3 -1
- package/src/llama.cpp/src/llama-memory-recurrent.cpp +21 -11
- package/src/llama.cpp/src/llama-memory.cpp +17 -0
- package/src/llama.cpp/src/llama-memory.h +3 -0
- package/src/llama.cpp/src/llama-model.cpp +3373 -743
- package/src/llama.cpp/src/llama-model.h +20 -4
- package/src/llama.cpp/src/llama-quant.cpp +2 -2
- package/src/llama.cpp/src/llama-vocab.cpp +376 -10
- package/src/llama.cpp/src/llama-vocab.h +43 -0
- package/src/llama.cpp/src/unicode.cpp +207 -0
- package/src/llama.cpp/src/unicode.h +2 -0
- package/src/llama.cpp/ggml/include/ggml-kompute.h +0 -50
package/src/llama.cpp/src/llama-batch.cpp

@@ -27,6 +27,7 @@ bool llama_batch_allocr::init(
         const llama_vocab & vocab,
         const llama_memory_i * memory,
         uint32_t n_embd,
+        uint32_t n_seq_max,
         bool output_all) {
     clear();
 
@@ -40,6 +41,11 @@ bool llama_batch_allocr::init(
     // validate input batch
     //
 
+    if (n_seq_max > LLAMA_MAX_SEQ) {
+        LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
+        return false;
+    }
+
     if (batch.token) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
@@ -52,8 +58,8 @@ bool llama_batch_allocr::init(
     if (batch.seq_id) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >=
-                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s],
+                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
                     return false;
                 }
             }
@@ -86,7 +92,7 @@ bool llama_batch_allocr::init(
 
     // initialize the starting position for each sequence based on the positions in the memory
     llama_pos p0[LLAMA_MAX_SEQ];
-    for (
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (!memory) {
             // if no memory -> start from 0
             p0[s] = 0;
@@ -143,13 +149,16 @@ bool llama_batch_allocr::init(
     // compute stats
     //
 
-    this->n_embd
+    this->n_embd    = n_embd;
+    this->n_seq_max = n_seq_max;
 
     // count the outputs in this batch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         n_outputs += batch.logits[i] != 0;
     }
 
+    has_cpl = false;
+
     // determine coupled sequences
     // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -166,6 +175,8 @@ bool llama_batch_allocr::init(
 
                 // note: tracking the other way around is not necessary for now
                 //seq_cpl[s0][s1] = true;
+
+                has_cpl = true;
             }
         }
     }
@@ -187,7 +198,7 @@ bool llama_batch_allocr::init(
         seq_set_map[cur].push_back(i);
     }
 
-    for (
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_set_unq.test(s)) {
             seq_idx[s] = seq_id_unq.size();
             seq_id_unq.push_back(s);
@@ -199,7 +210,7 @@ bool llama_batch_allocr::init(
         LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
 
         llama_ubatch ubatch {
-            /*.
+            /*.b_equal_seqs =*/ false,
             /*.n_tokens     =*/ (uint32_t) batch.n_tokens,
             /*.n_seq_tokens =*/ (uint32_t) 1,
             /*.n_seqs       =*/ (uint32_t) batch.n_tokens,
@@ -212,6 +223,7 @@ bool llama_batch_allocr::init(
             /*.seq_id_unq   =*/ this->seq_id_unq.data(),
             /*.seq_idx      =*/ this->seq_idx.data(),
             /*.output       =*/ batch.logits,
+            /*.data         =*/ {},
         };
 
         ubatch_print(ubatch, debug);
@@ -239,7 +251,7 @@ bool llama_batch_allocr::init(
     // consistency checks
     //
 
-    for (
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_pos[s].empty()) {
             continue;
         }
@@ -282,8 +294,8 @@ bool llama_batch_allocr::init(
     }
 
     if (memory) {
-        for (
-            for (
+        for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
+            for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
                 if (seq_cpl[s0][s1]) {
                     if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
                         memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
@@ -314,12 +326,12 @@ bool llama_batch_allocr::init(
     //
     {
         seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
-        for (
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
            cur_seq_set[s].set();
        }
 
        llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
-        for (
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
            cur_seq_pos[s] = -1;
        }
 
@@ -355,39 +367,38 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
     clear();
     split_reset();
 
-
-
-    auto & ubatch = ubatches.back();
+    auto udata = std::make_shared<llama_ubatch::data_t>();
 
-
-
-
-
-
-
-
-
+    udata->token     .resize(n_tokens);
+    udata->embd      .clear();
+    udata->pos       .resize(n_tokens);
+    udata->n_seq_id  .resize(n_tokens);
+    udata->seq_id    .resize(n_tokens);
+    udata->seq_id_unq.resize(0);
+    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
+    udata->output    .resize(n_tokens);
 
     for (uint32_t s = 0; s < n_seqs; ++s) {
-
-
+        udata->seq_idx[s] = s;
+        udata->seq_id_unq.push_back(s);
     }
 
     llama_ubatch res {
-        /*.
+        /*.b_equal_seqs =*/ true,
         /*.n_tokens     =*/ n_tokens,
         /*.n_seq_tokens =*/ n_seq_tokens,
         /*.n_seqs       =*/ n_seqs,
         /*.n_seqs_unq   =*/ n_seqs,
 
-        /*.token        =*/
+        /*.token        =*/ udata->token.data(),
         /*.embd         =*/ nullptr,
-        /*.pos          =*/
-        /*.n_seq_id     =*/
-        /*.seq_id       =*/
-        /*.seq_id_unq   =*/
-        /*.seq_idx      =*/
-        /*.output       =*/
+        /*.pos          =*/ udata->pos.data(),
+        /*.n_seq_id     =*/ udata->n_seq_id.data(),
+        /*.seq_id       =*/ udata->seq_id.data(),
+        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
+        /*.seq_idx      =*/ udata->seq_idx.data(),
+        /*.output       =*/ udata->output.data(),
+        /*.data         =*/ std::move(udata),
     };
 
     return res;
@@ -405,6 +416,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
     return n_outputs;
 }
 
+uint32_t llama_batch_allocr::get_n_used() const {
+    return n_used;
+}
+
 std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
     return out_ids;
 }
@@ -420,10 +435,10 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
 void llama_batch_allocr::split_reset() {
     out_ids.clear();
 
+    n_used = 0;
+
     used.clear();
     used.resize(get_n_tokens(), false);
-
-    ubatches.clear();
 }
 
 llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
@@ -444,6 +459,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
         idxs.push_back(cur_idx);
 
         used[cur_idx] = true;
+        ++n_used;
 
         ++cur_idx;
 
@@ -459,9 +475,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
     return ubatch_add(idxs, idxs.size(), false);
 }
 
-llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
+llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
+    if (sequential && has_cpl) {
+        LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
+
+        return {};
+    }
+
     std::vector<seq_set_t> cur_seq_set;
 
+    llama_seq_id last_seq_id = -1;
+
     // determine the non-overlapping sequence sets participating in this ubatch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         if (used[i]) {
@@ -478,9 +502,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
             }
         }
 
+        // accept only increasing sequence ids
+        if (sequential) {
+            add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
+        }
+
         if (add) {
             cur_seq_set.push_back(seq_set[i]);
 
+            last_seq_id = batch.seq_id[i][0];
+
             if (cur_seq_set.size() > n_ubatch) {
                 break;
             }
@@ -529,6 +560,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
             idxs_per_seq[s].push_back(idx);
 
             used[idx] = true;
+            ++n_used;
 
             ++cur_idx[s];
         }
@@ -570,6 +602,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
         idxs.push_back(cur_idx);
 
         used[cur_idx] = true;
+        ++n_used;
 
         if (idxs.size() >= n_ubatch) {
             break;
@@ -620,78 +653,77 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
 
     assert(n_tokens%n_seqs == 0);
 
-
-
-    auto & ubatch = ubatches.back();
+    auto udata = std::make_shared<llama_ubatch::data_t>();
 
     const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
 
     const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
     const int64_t n_pos_all = (int64_t) n_tokens*n_pos_cur;
 
-
-
-
-
-
-
-
-
+    udata->token     .resize(n_tokens);
+    udata->embd      .resize(n_embd_all);
+    udata->pos       .resize(n_pos_all);
+    udata->n_seq_id  .resize(n_tokens);
+    udata->seq_id    .resize(n_tokens);
+    udata->seq_id_unq.resize(0);
+    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
+    udata->output    .resize(n_tokens);
 
     seq_set_t seq_set_unq;
 
     for (size_t i = 0; i < idxs.size(); ++i) {
         if (batch.token) {
-
+            udata->token[i] = batch.token[idxs[i]];
         }
 
         if (batch.embd) {
-            memcpy(
+            memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
        }
 
        for (int j = 0; j < n_pos_cur; ++j) {
-
+            udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
        }
 
-
-
-
+        udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
+        udata->seq_id[i]   = batch.seq_id[idxs[i]];
+        udata->output[i]   = batch.logits[idxs[i]];
 
-        for (int s = 0; s <
-            seq_set_unq.set(
+        for (int s = 0; s < udata->n_seq_id[i]; ++s) {
+            seq_set_unq.set(udata->seq_id[i][s]);
        }
 
-        if (
+        if (udata->output[i]) {
            out_ids.push_back(idxs[i]);
        }
    }
 
-    for (
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
        if (seq_set_unq.test(s)) {
-
-
+            udata->seq_idx[s] = udata->seq_id_unq.size();
+            udata->seq_id_unq.push_back(s);
        }
    }
 
    llama_ubatch res {
-        /*.
+        /*.b_equal_seqs =*/ equal_seqs,
        /*.n_tokens     =*/ n_tokens,
        /*.n_seq_tokens =*/ n_tokens/n_seqs,
        /*.n_seqs       =*/ n_seqs,
-        /*.n_seqs_unq   =*/ (uint32_t)
-
-        /*.token        =*/ batch.token ?
-        /*.embd         =*/ batch.embd ?
-        /*.pos          =*/
-        /*.n_seq_id     =*/
-        /*.seq_id       =*/
-        /*.seq_id_unq   =*/
-        /*.seq_idx      =*/
-        /*.output       =*/
+        /*.n_seqs_unq   =*/ (uint32_t) udata->seq_id_unq.size(),
+
+        /*.token        =*/ batch.token ? udata->token.data() : nullptr,
+        /*.embd         =*/ batch.embd ? udata->embd.data() : nullptr,
+        /*.pos          =*/ udata->pos.data(),
+        /*.n_seq_id     =*/ udata->n_seq_id.data(),
+        /*.seq_id       =*/ udata->seq_id.data(),
+        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
+        /*.seq_idx      =*/ udata->seq_idx.data(),
+        /*.output       =*/ udata->output.data(),
+        /*.data         =*/ std::move(udata),
    };
 
    if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: added ubatch
+        LLAMA_LOG_DEBUG("%s: added ubatch to split:\n", __func__);
 
        ubatch_print(res, debug);
    }
@@ -701,7 +733,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
 
 void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs);
+        LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs());
         LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, ubatch.n_tokens);
         LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
         LLAMA_LOG_DEBUG("%s: n_seqs = %d\n", __func__, ubatch.n_seqs);
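
The hunks above thread a new `n_seq_max` bound through `llama_batch_allocr` and extend the splitting API: `split_equal()` gains a `sequential` flag and `get_n_used()` exposes how many tokens have been consumed. The sketch below is not part of this diff; it is a hypothetical driver loop showing how these pieces could fit together, with the surrounding call sites and the empty-ubatch termination condition assumed.

```cpp
// Hypothetical driver loop (an assumption, not code from this package).
#include "llama-batch.h"   // internal headers from the sources above
#include "llama-memory.h"
#include "llama-vocab.h"

bool process_batch(llama_batch_allocr & balloc,
                   const llama_batch  & batch,
                   const llama_vocab  & vocab,
                   const llama_memory_i * memory,
                   uint32_t n_embd, uint32_t n_seq_max, uint32_t n_ubatch) {
    // init() now receives the context's n_seq_max and rejects out-of-range seq_ids
    if (!balloc.init(batch, vocab, memory, n_embd, n_seq_max, /*output_all=*/false)) {
        return false;
    }

    balloc.split_reset(); // also resets the new n_used counter

    while (true) {
        // sequential == true only accepts ubatches whose sequence ids increase by one;
        // per the diff it returns an empty ubatch if the batch has coupled sequences
        llama_ubatch ubatch = balloc.split_equal(n_ubatch, /*sequential=*/true);
        if (ubatch.n_tokens == 0) {
            break; // assumed termination: no tokens left to split
        }

        // ... run the compute graph on `ubatch` ...
    }

    // the new counter lets the caller check that no tokens were left behind
    return balloc.get_n_used() == balloc.get_n_tokens();
}
```
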
package/src/llama.cpp/src/llama-batch.h

@@ -8,12 +8,17 @@
 #include <vector>
 #include <set>
 #include <bitset>
+#include <memory>
 #include <unordered_map>
 
 // keep this struct lightweight
-// it points to data in `llama_batch_allocr`
 struct llama_ubatch {
-    bool equal_seqs
+    bool equal_seqs() const {
+        return b_equal_seqs != 0;
+    }
+
+    uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
+                           //       otherwise address sanitizer complains
     // TODO: whole_seqs for embeddings?
 
     uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
@@ -34,6 +39,20 @@ struct llama_ubatch {
     llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
     int32_t * seq_idx; // [LLAMA_MAX_SEQ] | - | seq_idx
     int8_t * output; // [n_tokens] | i | -
+
+    struct data_t {
+        std::vector<llama_token> token;
+        std::vector<float> embd;
+        std::vector<llama_pos> pos;
+        std::vector<int32_t> n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<llama_seq_id> seq_id_unq;
+        std::vector<int32_t> seq_idx;
+        std::vector<int8_t> output;
+    };
+
+    // the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data
+    std::shared_ptr<data_t> data;
 };
 
 // a helper for sanitizing, fulfilling and splitting a batch
@@ -48,12 +67,14 @@ public:
             const llama_vocab & vocab,
             const llama_memory_i * memory,
             uint32_t n_embd,
+            uint32_t n_seq_max,
             bool output_all);
 
     const llama_batch & get_batch() const;
 
     uint32_t get_n_tokens() const;
     uint32_t get_n_outputs() const;
+    uint32_t get_n_used() const;
 
     // the array of output indices in the order they were encountered during the ubatch splitting
     std::vector<int32_t> & get_out_ids();
@@ -69,7 +90,8 @@ public:
     llama_ubatch split_simple(uint32_t n_ubatch);
 
     // make ubatches of equal-length sequences sets
-    llama_ubatch split_equal(uint32_t n_ubatch);
+    // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
+    llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
 
     // sequence-set-wise split - each ubatch contains a single sequence-set
     llama_ubatch split_seq(uint32_t n_ubatch);
@@ -98,6 +120,7 @@ private:
     const uint32_t n_pos_per_embd;
 
     uint32_t n_embd;
+    uint32_t n_seq_max;
     uint32_t n_outputs;
 
     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
@@ -112,6 +135,9 @@ private:
     using pos_set_t = std::set<llama_pos>;
     using seq_cpl_t = std::vector<bool>;
 
+    // helper flag to quickly determine if there are any coupled sequences in the batch
+    bool has_cpl = false;
+
     std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
     std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
 
@@ -125,23 +151,10 @@ private:
     // batch indices of the output
     std::vector<int32_t> out_ids;
 
+    uint32_t n_used;
+
     // used[i] indicates if token i has already been used in a previous ubatch
     std::vector<bool> used;
 
-    // llama_ubatch points to this data:
-    struct ubatch {
-        std::vector<llama_token> token;
-        std::vector<float> embd;
-        std::vector<llama_pos> pos;
-        std::vector<int32_t> n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
-        std::vector<llama_seq_id> seq_id_unq;
-        std::vector<int32_t> seq_idx;
-        std::vector<int8_t> output;
-    };
-
-    // current splitting state:
-    std::vector<ubatch> ubatches;
-
     int debug;
 };
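
The struct change above makes `llama_ubatch` able to carry its own storage: the raw pointers can now refer to a `shared_ptr<data_t>` held in the same object instead of into vectors owned by `llama_batch_allocr`. A minimal sketch of the consequence, assuming the allocator populates `data` the way the new `ubatch_add()` shown earlier does:

```cpp
// Sketch only (assumption, not code from this package): the splitting functions
// move a shared_ptr<llama_ubatch::data_t> into the returned ubatch, so a copy of
// the ubatch keeps the token/pos/output storage alive even after further splits
// or a split_reset().
#include "llama-batch.h"

llama_ubatch take_first_ubatch(llama_batch_allocr & balloc, uint32_t n_ubatch) {
    llama_ubatch ub = balloc.split_simple(n_ubatch);

    // ub.token, ub.pos, ub.output, ... point into ub.data's vectors, so returning
    // by value is safe. Note that ub.seq_id still stores llama_seq_id* entries that
    // point into the caller's original llama_batch (data_t keeps a
    // std::vector<llama_seq_id *>), so the input batch must outlive the ubatch.
    return ub;
}
```
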
package/src/llama.cpp/src/llama-chat.cpp

@@ -56,6 +56,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE },
     { "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
     { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
+    { "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 },
     { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
     { "granite", LLM_CHAT_TEMPLATE_GRANITE },
     { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
@@ -64,6 +65,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "bailing", LLM_CHAT_TEMPLATE_BAILING },
     { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
     { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
+    { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
+    { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
 };
 
 llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -166,10 +169,13 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
     } else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) {
         return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
     } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
+        if (tmpl_contains("[|tool|]")) {
+            return LLM_CHAT_TEMPLATE_EXAONE_4;
+        }
         // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
         // EXAONE-3.0-7.8B-Instruct
         return LLM_CHAT_TEMPLATE_EXAONE_3;
-    } else if (tmpl_contains("rwkv-world")) {
+    } else if (tmpl_contains("rwkv-world") || tmpl_contains("{{- 'User: ' + message['content']|trim + '\\n\\n' -}}")) {
         return LLM_CHAT_TEMPLATE_RWKV_WORLD;
     } else if (tmpl_contains("<|start_of_role|>")) {
         return LLM_CHAT_TEMPLATE_GRANITE;
@@ -185,6 +191,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_LLAMA4;
     } else if (tmpl_contains("<|endofuserprompt|>")) {
         return LLM_CHAT_TEMPLATE_DOTS1;
+    } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
+        return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
+    } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
+        return LLM_CHAT_TEMPLATE_KIMI_K2;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -526,6 +536,22 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "[|assistant|]";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_4) {
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n";
+            } else if (role == "user") {
+                ss << "[|user|]" << trim(message->content) << "\n";
+            } else if (role == "assistant") {
+                ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n";
+            } else if (role == "tool") {
+                ss << "[|tool|]" << trim(message->content) << "[|endofturn|]\n";
+            }
+        }
+        if (add_ass) {
+            ss << "[|assistant|]";
+        }
     } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
         // this template requires the model to have "\n\n" as EOT token
         for (size_t i = 0; i < chat.size(); i++) {
@@ -665,6 +691,38 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "<|response|>";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) {
+        // tencent/Hunyuan-A13B-Instruct
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "<|startoftext|>" << message->content << "<|extra_4|>";
+            } else if (role == "assistant") {
+                ss << "<|startoftext|>" << message->content << "<|eos|>";
+            } else {
+                ss << "<|startoftext|>" << message->content << "<|extra_0|>";
+            }
+        }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) {
+        // moonshotai/Kimi-K2-Instruct
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "<|im_system|>system<|im_middle|>";
+            } else if (role == "user") {
+                ss << "<|im_user|>user<|im_middle|>";
+            } else if (role == "assistant") {
+                ss << "<|im_assistant|>assistant<|im_middle|>";
+            } else if (role == "tool") {
+                ss << "<|im_system|>tool<|im_middle|>";
+            }
+
+            ss << message->content << "<|im_end|>";
+
+            if (add_ass) {
+                ss << "<|im_assistant|>assistant<|im_middle|>";
+            }
+        }
     } else {
         // template not supported
         return -1;
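
The branches above register built-in `exaone4`, `hunyuan-moe` and `kimi-k2` templates and teach the detector to recognize them. As a hedged illustration (not part of the diff), the public `llama_chat_apply_template()` API should render the EXAONE 4 format added above roughly as shown below; the message contents, buffer handling, and the assumption that the short name "exaone4" is resolved through the LLM_CHAT_TEMPLATES map are all mine.

```cpp
// Illustrative sketch only: exercising the newly registered "exaone4" template
// through the public llama.h API.
#include <cstdio>
#include <string>
#include <vector>

#include "llama.h"

int main() {
    std::vector<llama_chat_message> chat = {
        { "system", "You are a helpful assistant." },
        { "user",   "Hello!"                       },
    };

    std::string buf(1024, '\0');

    const int32_t n = llama_chat_apply_template(
        "exaone4", chat.data(), chat.size(), /*add_ass=*/true,
        buf.data(), (int32_t) buf.size());

    if (n > 0 && n <= (int32_t) buf.size()) {
        buf.resize(n);
        // expected output per the EXAONE_4 branch above:
        //   [|system|]You are a helpful assistant.[|endofturn|]
        //   [|user|]Hello!
        //   [|assistant|]
        printf("%s\n", buf.c_str());
    }

    return 0;
}
```
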
package/src/llama.cpp/src/llama-chat.h

@@ -35,6 +35,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_GLMEDGE,
     LLM_CHAT_TEMPLATE_MINICPM,
     LLM_CHAT_TEMPLATE_EXAONE_3,
+    LLM_CHAT_TEMPLATE_EXAONE_4,
     LLM_CHAT_TEMPLATE_RWKV_WORLD,
     LLM_CHAT_TEMPLATE_GRANITE,
     LLM_CHAT_TEMPLATE_GIGACHAT,
@@ -44,6 +45,8 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_LLAMA4,
     LLM_CHAT_TEMPLATE_SMOLVLM,
     LLM_CHAT_TEMPLATE_DOTS1,
+    LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
+    LLM_CHAT_TEMPLATE_KIMI_K2,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };