@fugood/llama.node 1.0.2 → 1.0.4

This diff covers the publicly released contents of the two package versions as they appear in their public registry, and is provided for informational purposes only.
Files changed (50)
  1. package/package.json +14 -14
  2. package/src/llama.cpp/CMakeLists.txt +0 -1
  3. package/src/llama.cpp/common/CMakeLists.txt +4 -5
  4. package/src/llama.cpp/common/arg.cpp +44 -0
  5. package/src/llama.cpp/common/common.cpp +22 -6
  6. package/src/llama.cpp/common/common.h +15 -1
  7. package/src/llama.cpp/ggml/CMakeLists.txt +10 -2
  8. package/src/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  9. package/src/llama.cpp/ggml/include/ggml.h +104 -10
  10. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  11. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
  12. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +12 -1
  13. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  14. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +749 -163
  15. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +5 -0
  16. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  17. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +12 -9
  18. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +88 -9
  19. package/src/llama.cpp/include/llama.h +13 -47
  20. package/src/llama.cpp/src/llama-arch.cpp +298 -3
  21. package/src/llama.cpp/src/llama-arch.h +22 -1
  22. package/src/llama.cpp/src/llama-batch.cpp +103 -71
  23. package/src/llama.cpp/src/llama-batch.h +31 -18
  24. package/src/llama.cpp/src/llama-chat.cpp +59 -1
  25. package/src/llama.cpp/src/llama-chat.h +3 -0
  26. package/src/llama.cpp/src/llama-context.cpp +134 -95
  27. package/src/llama.cpp/src/llama-context.h +13 -16
  28. package/src/llama.cpp/src/llama-cparams.h +3 -2
  29. package/src/llama.cpp/src/llama-graph.cpp +279 -180
  30. package/src/llama.cpp/src/llama-graph.h +183 -122
  31. package/src/llama.cpp/src/llama-hparams.cpp +47 -1
  32. package/src/llama.cpp/src/llama-hparams.h +12 -1
  33. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
  34. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
  35. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
  36. package/src/llama.cpp/src/llama-kv-cache-unified.h +143 -47
  37. package/src/llama.cpp/src/llama-kv-cells.h +62 -10
  38. package/src/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
  39. package/src/llama.cpp/src/llama-memory-hybrid.h +3 -1
  40. package/src/llama.cpp/src/llama-memory-recurrent.cpp +21 -11
  41. package/src/llama.cpp/src/llama-memory.cpp +17 -0
  42. package/src/llama.cpp/src/llama-memory.h +3 -0
  43. package/src/llama.cpp/src/llama-model.cpp +3373 -743
  44. package/src/llama.cpp/src/llama-model.h +20 -4
  45. package/src/llama.cpp/src/llama-quant.cpp +2 -2
  46. package/src/llama.cpp/src/llama-vocab.cpp +376 -10
  47. package/src/llama.cpp/src/llama-vocab.h +43 -0
  48. package/src/llama.cpp/src/unicode.cpp +207 -0
  49. package/src/llama.cpp/src/unicode.h +2 -0
  50. package/src/llama.cpp/ggml/include/ggml-kompute.h +0 -50
package/src/llama.cpp/src/llama-batch.cpp

@@ -27,6 +27,7 @@ bool llama_batch_allocr::init(
         const llama_vocab & vocab,
         const llama_memory_i * memory,
         uint32_t n_embd,
+        uint32_t n_seq_max,
         bool output_all) {
     clear();
 
@@ -40,6 +41,11 @@ bool llama_batch_allocr::init(
     // validate input batch
     //
 
+    if (n_seq_max > LLAMA_MAX_SEQ) {
+        LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
+        return false;
+    }
+
     if (batch.token) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
@@ -52,8 +58,8 @@ bool llama_batch_allocr::init(
     if (batch.seq_id) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
-                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
+                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
                     return false;
                 }
             }
@@ -86,7 +92,7 @@ bool llama_batch_allocr::init(
 
     // initialize the starting position for each sequence based on the positions in the memory
     llama_pos p0[LLAMA_MAX_SEQ];
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (!memory) {
             // if no memory -> start from 0
             p0[s] = 0;
@@ -143,13 +149,16 @@ bool llama_batch_allocr::init(
     // compute stats
     //
 
-    this->n_embd = n_embd;
+    this->n_embd    = n_embd;
+    this->n_seq_max = n_seq_max;
 
     // count the outputs in this batch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
        n_outputs += batch.logits[i] != 0;
     }
 
+    has_cpl = false;
+
     // determine coupled sequences
     // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -166,6 +175,8 @@ bool llama_batch_allocr::init(
 
                 // note: tracking the other way around is not necessary for now
                 //seq_cpl[s0][s1] = true;
+
+                has_cpl = true;
             }
         }
     }
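
The has_cpl flag added above caches whether the batch contains coupled sequences, i.e. pairs of sequence ids that share at least one token. A toy standalone sketch of that detection (illustrative data and names, not the llama.cpp API):

    #include <cstdio>
    #include <vector>

    int main() {
        // seq ids attached to each token in a hypothetical batch
        const std::vector<std::vector<int>> token_seq_ids = { {0}, {0, 1}, {2} };

        // a token assigned to more than one sequence couples those sequences
        bool has_cpl = false;
        for (const auto & ids : token_seq_ids) {
            has_cpl = has_cpl || ids.size() > 1;
        }

        std::printf("has_cpl = %d\n", has_cpl); // token {0, 1} couples sequences 0 and 1 -> prints 1
        return 0;
    }
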
@@ -187,7 +198,7 @@ bool llama_batch_allocr::init(
         seq_set_map[cur].push_back(i);
     }
 
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_set_unq.test(s)) {
             seq_idx[s] = seq_id_unq.size();
             seq_id_unq.push_back(s);
@@ -199,7 +210,7 @@ bool llama_batch_allocr::init(
         LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
 
         llama_ubatch ubatch {
-            /*.equal_seqs   =*/ false,
+            /*.b_equal_seqs =*/ false,
             /*.n_tokens     =*/ (uint32_t) batch.n_tokens,
             /*.n_seq_tokens =*/ (uint32_t) 1,
             /*.n_seqs       =*/ (uint32_t) batch.n_tokens,
@@ -212,6 +223,7 @@ bool llama_batch_allocr::init(
             /*.seq_id_unq   =*/ this->seq_id_unq.data(),
             /*.seq_idx      =*/ this->seq_idx.data(),
             /*.output       =*/ batch.logits,
+            /*.data         =*/ {},
         };
 
         ubatch_print(ubatch, debug);
@@ -239,7 +251,7 @@ bool llama_batch_allocr::init(
     // consistency checks
     //
 
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_pos[s].empty()) {
             continue;
         }
@@ -282,8 +294,8 @@ bool llama_batch_allocr::init(
     }
 
     if (memory) {
-        for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
-            for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
+        for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
+            for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
                 if (seq_cpl[s0][s1]) {
                     if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
                         memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
@@ -314,12 +326,12 @@ bool llama_batch_allocr::init(
     //
     {
         seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
             cur_seq_set[s].set();
         }
 
         llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
             cur_seq_pos[s] = -1;
         }
 
@@ -355,39 +367,38 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
     clear();
     split_reset();
 
-    ubatches.emplace_back();
-
-    auto & ubatch = ubatches.back();
+    auto udata = std::make_shared<llama_ubatch::data_t>();
 
-    ubatch.token     .resize(n_tokens);
-    ubatch.embd      .clear();
-    ubatch.pos       .resize(n_tokens);
-    ubatch.n_seq_id  .resize(n_tokens);
-    ubatch.seq_id    .resize(n_tokens);
-    ubatch.seq_id_unq.resize(0);
-    ubatch.seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    ubatch.output    .resize(n_tokens);
+    udata->token     .resize(n_tokens);
+    udata->embd      .clear();
+    udata->pos       .resize(n_tokens);
+    udata->n_seq_id  .resize(n_tokens);
+    udata->seq_id    .resize(n_tokens);
+    udata->seq_id_unq.resize(0);
+    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
+    udata->output    .resize(n_tokens);
 
     for (uint32_t s = 0; s < n_seqs; ++s) {
-        ubatch.seq_idx[s] = s;
-        ubatch.seq_id_unq.push_back(s);
+        udata->seq_idx[s] = s;
+        udata->seq_id_unq.push_back(s);
     }
 
     llama_ubatch res {
-        /*.equal_seqs   =*/ true,
+        /*.b_equal_seqs =*/ true,
         /*.n_tokens     =*/ n_tokens,
         /*.n_seq_tokens =*/ n_seq_tokens,
         /*.n_seqs       =*/ n_seqs,
         /*.n_seqs_unq   =*/ n_seqs,
 
-        /*.token        =*/ ubatch.token.data(),
+        /*.token        =*/ udata->token.data(),
         /*.embd         =*/ nullptr,
-        /*.pos          =*/ ubatch.pos.data(),
-        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
-        /*.seq_id       =*/ ubatch.seq_id.data(),
-        /*.seq_id_unq   =*/ ubatch.seq_id_unq.data(),
-        /*.seq_idx      =*/ ubatch.seq_idx.data(),
-        /*.output       =*/ ubatch.output.data(),
+        /*.pos          =*/ udata->pos.data(),
+        /*.n_seq_id     =*/ udata->n_seq_id.data(),
+        /*.seq_id       =*/ udata->seq_id.data(),
+        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
+        /*.seq_idx      =*/ udata->seq_idx.data(),
+        /*.output       =*/ udata->output.data(),
+        /*.data         =*/ std::move(udata),
     };
 
     return res;
@@ -405,6 +416,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
     return n_outputs;
 }
 
+uint32_t llama_batch_allocr::get_n_used() const {
+    return n_used;
+}
+
 std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
     return out_ids;
 }
@@ -420,10 +435,10 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
 void llama_batch_allocr::split_reset() {
     out_ids.clear();
 
+    n_used = 0;
+
     used.clear();
     used.resize(get_n_tokens(), false);
-
-    ubatches.clear();
 }
 
 llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
@@ -444,6 +459,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
         idxs.push_back(cur_idx);
 
         used[cur_idx] = true;
+        ++n_used;
 
         ++cur_idx;
 
@@ -459,9 +475,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
     return ubatch_add(idxs, idxs.size(), false);
 }
 
-llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
+llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
+    if (sequential && has_cpl) {
+        LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
+
+        return {};
+    }
+
     std::vector<seq_set_t> cur_seq_set;
 
+    llama_seq_id last_seq_id = -1;
+
     // determine the non-overlapping sequence sets participating in this ubatch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         if (used[i]) {
@@ -478,9 +502,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
             }
         }
 
+        // accept only increasing sequence ids
+        if (sequential) {
+            add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
+        }
+
         if (add) {
             cur_seq_set.push_back(seq_set[i]);
 
+            last_seq_id = batch.seq_id[i][0];
+
             if (cur_seq_set.size() > n_ubatch) {
                 break;
             }
@@ -529,6 +560,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
            idxs_per_seq[s].push_back(idx);
 
            used[idx] = true;
+           ++n_used;
 
            ++cur_idx[s];
        }
@@ -570,6 +602,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
        idxs.push_back(cur_idx);
 
        used[cur_idx] = true;
+       ++n_used;
 
        if (idxs.size() >= n_ubatch) {
            break;
@@ -620,78 +653,77 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
 
     assert(n_tokens%n_seqs == 0);
 
-    ubatches.emplace_back();
-
-    auto & ubatch = ubatches.back();
+    auto udata = std::make_shared<llama_ubatch::data_t>();
 
     const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
 
     const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
     const int64_t n_pos_all  =              (int64_t) n_tokens*n_pos_cur;
 
-    ubatch.token     .resize(n_tokens);
-    ubatch.embd      .resize(n_embd_all);
-    ubatch.pos       .resize(n_pos_all);
-    ubatch.n_seq_id  .resize(n_tokens);
-    ubatch.seq_id    .resize(n_tokens);
-    ubatch.seq_id_unq.resize(0);
-    ubatch.seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    ubatch.output    .resize(n_tokens);
+    udata->token     .resize(n_tokens);
+    udata->embd      .resize(n_embd_all);
+    udata->pos       .resize(n_pos_all);
+    udata->n_seq_id  .resize(n_tokens);
+    udata->seq_id    .resize(n_tokens);
+    udata->seq_id_unq.resize(0);
+    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
+    udata->output    .resize(n_tokens);
 
     seq_set_t seq_set_unq;
 
     for (size_t i = 0; i < idxs.size(); ++i) {
         if (batch.token) {
-            ubatch.token[i] = batch.token[idxs[i]];
+            udata->token[i] = batch.token[idxs[i]];
         }
 
         if (batch.embd) {
-            memcpy(ubatch.embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
+            memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
         }
 
         for (int j = 0; j < n_pos_cur; ++j) {
-            ubatch.pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
+            udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
         }
 
-        ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]];
-        ubatch.seq_id[i]   = batch.seq_id[idxs[i]];
-        ubatch.output[i]   = batch.logits[idxs[i]];
+        udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
+        udata->seq_id[i]   = batch.seq_id[idxs[i]];
+        udata->output[i]   = batch.logits[idxs[i]];
 
-        for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
-            seq_set_unq.set(ubatch.seq_id[i][s]);
+        for (int s = 0; s < udata->n_seq_id[i]; ++s) {
+            seq_set_unq.set(udata->seq_id[i][s]);
        }
 
-        if (ubatch.output[i]) {
+        if (udata->output[i]) {
            out_ids.push_back(idxs[i]);
        }
    }
 
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
        if (seq_set_unq.test(s)) {
-            ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
-            ubatch.seq_id_unq.push_back(s);
+            udata->seq_idx[s] = udata->seq_id_unq.size();
+            udata->seq_id_unq.push_back(s);
        }
    }
 
    llama_ubatch res {
-        /*.equal_seqs   =*/ equal_seqs,
+        /*.b_equal_seqs =*/ equal_seqs,
        /*.n_tokens     =*/ n_tokens,
        /*.n_seq_tokens =*/ n_tokens/n_seqs,
        /*.n_seqs       =*/ n_seqs,
-        /*.n_seqs_unq   =*/ (uint32_t) ubatch.seq_id_unq.size(),
-
-        /*.token        =*/ batch.token ? ubatch.token.data() : nullptr,
-        /*.embd         =*/ batch.embd ? ubatch.embd.data() : nullptr,
-        /*.pos          =*/ ubatch.pos.data(),
-        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
-        /*.seq_id       =*/ ubatch.seq_id.data(),
-        /*.seq_id_unq   =*/ ubatch.seq_id_unq.data(),
-        /*.seq_idx      =*/ ubatch.seq_idx.data(),
-        /*.output       =*/ ubatch.output.data(),
+        /*.n_seqs_unq   =*/ (uint32_t) udata->seq_id_unq.size(),
+
+        /*.token        =*/ batch.token ? udata->token.data() : nullptr,
+        /*.embd         =*/ batch.embd ? udata->embd.data() : nullptr,
+        /*.pos          =*/ udata->pos.data(),
+        /*.n_seq_id     =*/ udata->n_seq_id.data(),
+        /*.seq_id       =*/ udata->seq_id.data(),
+        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
+        /*.seq_idx      =*/ udata->seq_idx.data(),
+        /*.output       =*/ udata->output.data(),
+        /*.data         =*/ std::move(udata),
    };
 
    if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: added ubatch %d to split:\n", __func__, (int) ubatches.size() - 1);
+        LLAMA_LOG_DEBUG("%s: added ubatch to split:\n", __func__);
 
        ubatch_print(res, debug);
    }
@@ -701,7 +733,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
 
 void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: equal_seqs   = %d\n", __func__, ubatch.equal_seqs);
+        LLAMA_LOG_DEBUG("%s: equal_seqs   = %d\n", __func__, ubatch.equal_seqs());
         LLAMA_LOG_DEBUG("%s: n_tokens     = %d\n", __func__, ubatch.n_tokens);
         LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
         LLAMA_LOG_DEBUG("%s: n_seqs       = %d\n", __func__, ubatch.n_seqs);
package/src/llama.cpp/src/llama-batch.h

@@ -8,12 +8,17 @@
 #include <vector>
 #include <set>
 #include <bitset>
+#include <memory>
 #include <unordered_map>
 
 // keep this struct lightweight
-// it points to data in `llama_batch_allocr`
 struct llama_ubatch {
-    bool equal_seqs;
+    bool equal_seqs() const {
+        return b_equal_seqs != 0;
+    }
+
+    uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
+                           //       otherwise address sanitizer complains
     // TODO: whole_seqs for embeddings?
 
     uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
@@ -34,6 +39,20 @@ struct llama_ubatch {
     llama_seq_id *  seq_id_unq; // [n_seqs_unq]    | s | seq_id
     int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ] | - | seq_idx
     int8_t       *  output;     // [n_tokens]      | i | -
+
+    struct data_t {
+        std::vector<llama_token>    token;
+        std::vector<float>          embd;
+        std::vector<llama_pos>      pos;
+        std::vector<int32_t>        n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<llama_seq_id>   seq_id_unq;
+        std::vector<int32_t>        seq_idx;
+        std::vector<int8_t>         output;
+    };
+
+    // the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data
+    std::shared_ptr<data_t> data;
 };
 
 // a helper for sanitizing, fulfilling and splitting a batch
@@ -48,12 +67,14 @@ public:
             const llama_vocab & vocab,
             const llama_memory_i * memory,
             uint32_t n_embd,
+            uint32_t n_seq_max,
             bool output_all);
 
     const llama_batch & get_batch() const;
 
     uint32_t get_n_tokens()  const;
     uint32_t get_n_outputs() const;
+    uint32_t get_n_used() const;
 
     // the array of output indices in the order they were encountered during the ubatch splitting
     std::vector<int32_t> & get_out_ids();
@@ -69,7 +90,8 @@ public:
     llama_ubatch split_simple(uint32_t n_ubatch);
 
     // make ubatches of equal-length sequences sets
-    llama_ubatch split_equal(uint32_t n_ubatch);
+    // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
+    llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
 
     // sequence-set-wise split - each ubatch contains a single sequence-set
     llama_ubatch split_seq(uint32_t n_ubatch);
@@ -98,6 +120,7 @@ private:
     const uint32_t n_pos_per_embd;
 
     uint32_t n_embd;
+    uint32_t n_seq_max;
     uint32_t n_outputs;
 
     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
@@ -112,6 +135,9 @@ private:
     using pos_set_t = std::set<llama_pos>;
     using seq_cpl_t = std::vector<bool>;
 
+    // helper flag to quickly determine if there are any coupled sequences in the batch
+    bool has_cpl = false;
+
     std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
     std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
 
@@ -125,23 +151,10 @@ private:
     // batch indices of the output
     std::vector<int32_t> out_ids;
 
+    uint32_t n_used;
+
     // used[i] indicates if token i has already been used in a previous ubatch
     std::vector<bool> used;
 
-    // llama_ubatch points to this data:
-    struct ubatch {
-        std::vector<llama_token>    token;
-        std::vector<float>          embd;
-        std::vector<llama_pos>      pos;
-        std::vector<int32_t>        n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
-        std::vector<llama_seq_id>   seq_id_unq;
-        std::vector<int32_t>        seq_idx;
-        std::vector<int8_t>         output;
-    };
-
-    // current splitting state:
-    std::vector<ubatch> ubatches;
-
     int debug;
 };
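
The header-side effect is that llama_ubatch becomes self-owning: its raw pointers can reference the shared data_t block, so a ubatch no longer dangles once llama_batch_allocr::split_reset() drops the splitting state (previously the std::vector<ubatch> ubatches member owned the storage). A minimal sketch of this view-plus-owner pattern, with illustrative names:

    #include <cstdio>
    #include <memory>
    #include <utility>
    #include <vector>

    struct view_t {
        const int * tokens;                      // non-owning pointer, like llama_ubatch::token
        std::shared_ptr<std::vector<int>> data;  // owner, like llama_ubatch::data (may be null)
    };

    view_t make_view() {
        auto data = std::make_shared<std::vector<int>>(std::vector<int>{1, 2, 3});
        const int * p = data->data();
        return view_t{ p, std::move(data) };  // the pointer stays valid: the view owns the storage
    }

    int main() {
        view_t v = make_view();  // storage outlives the factory scope that allocated it
        std::printf("%d %d %d\n", v.tokens[0], v.tokens[1], v.tokens[2]);
        return 0;
    }
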
package/src/llama.cpp/src/llama-chat.cpp

@@ -56,6 +56,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE },
     { "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
     { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
+    { "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 },
     { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
     { "granite", LLM_CHAT_TEMPLATE_GRANITE },
     { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
@@ -64,6 +65,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "bailing", LLM_CHAT_TEMPLATE_BAILING },
     { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
     { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
+    { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
+    { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
 };
 
 llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -166,10 +169,13 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
     } else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) {
         return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
     } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
+        if (tmpl_contains("[|tool|]")) {
+            return LLM_CHAT_TEMPLATE_EXAONE_4;
+        }
         // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
         // EXAONE-3.0-7.8B-Instruct
         return LLM_CHAT_TEMPLATE_EXAONE_3;
-    } else if (tmpl_contains("rwkv-world")) {
+    } else if (tmpl_contains("rwkv-world") || tmpl_contains("{{- 'User: ' + message['content']|trim + '\\n\\n' -}}")) {
         return LLM_CHAT_TEMPLATE_RWKV_WORLD;
     } else if (tmpl_contains("<|start_of_role|>")) {
         return LLM_CHAT_TEMPLATE_GRANITE;
@@ -185,6 +191,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_LLAMA4;
     } else if (tmpl_contains("<|endofuserprompt|>")) {
         return LLM_CHAT_TEMPLATE_DOTS1;
+    } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
+        return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
+    } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
+        return LLM_CHAT_TEMPLATE_KIMI_K2;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -526,6 +536,22 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "[|assistant|]";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_4) {
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n";
+            } else if (role == "user") {
+                ss << "[|user|]" << trim(message->content) << "\n";
+            } else if (role == "assistant") {
+                ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n";
+            } else if (role == "tool") {
+                ss << "[|tool|]" << trim(message->content) << "[|endofturn|]\n";
+            }
+        }
+        if (add_ass) {
+            ss << "[|assistant|]";
+        }
     } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
         // this template requires the model to have "\n\n" as EOT token
         for (size_t i = 0; i < chat.size(); i++) {
@@ -665,6 +691,38 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "<|response|>";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) {
+        // tencent/Hunyuan-A13B-Instruct
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "<|startoftext|>" << message->content << "<|extra_4|>";
+            } else if (role == "assistant") {
+                ss << "<|startoftext|>" << message->content << "<|eos|>";
+            } else {
+                ss << "<|startoftext|>" << message->content << "<|extra_0|>";
+            }
+        }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) {
+        // moonshotai/Kimi-K2-Instruct
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "<|im_system|>system<|im_middle|>";
+            } else if (role == "user") {
+                ss << "<|im_user|>user<|im_middle|>";
+            } else if (role == "assistant") {
+                ss << "<|im_assistant|>assistant<|im_middle|>";
+            } else if (role == "tool") {
+                ss << "<|im_system|>tool<|im_middle|>";
+            }
+
+            ss << message->content << "<|im_end|>";
+
+            if (add_ass) {
+                ss << "<|im_assistant|>assistant<|im_middle|>";
+            }
+        }
     } else {
         // template not supported
         return -1;
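
For reference, here is a standalone rendering of the Kimi-K2 turn format that the new LLM_CHAT_TEMPLATE_KIMI_K2 branch emits, simplified to role/content pairs. Note that the hunk above appends the add_ass suffix inside the message loop; this sketch appends it once after the final message, matching how the other templates in this file behave:

    #include <iostream>
    #include <sstream>
    #include <string>
    #include <utility>
    #include <vector>

    std::string render_kimi_k2(const std::vector<std::pair<std::string, std::string>> & chat,
                               bool add_ass) {
        std::ostringstream ss;
        for (const auto & [role, content] : chat) {
            if (role == "system") {
                ss << "<|im_system|>system<|im_middle|>";
            } else if (role == "user") {
                ss << "<|im_user|>user<|im_middle|>";
            } else if (role == "assistant") {
                ss << "<|im_assistant|>assistant<|im_middle|>";
            } else if (role == "tool") {
                ss << "<|im_system|>tool<|im_middle|>";
            }
            ss << content << "<|im_end|>";
        }
        if (add_ass) {
            ss << "<|im_assistant|>assistant<|im_middle|>"; // prompt the next assistant turn
        }
        return ss.str();
    }

    int main() {
        // prints: <|im_system|>system<|im_middle|>You are helpful.<|im_end|>
        //         <|im_user|>user<|im_middle|>Hi!<|im_end|><|im_assistant|>assistant<|im_middle|>
        std::cout << render_kimi_k2({ {"system", "You are helpful."},
                                      {"user",   "Hi!"} }, /*add_ass=*/true) << "\n";
        return 0;
    }
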
package/src/llama.cpp/src/llama-chat.h

@@ -35,6 +35,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_GLMEDGE,
     LLM_CHAT_TEMPLATE_MINICPM,
     LLM_CHAT_TEMPLATE_EXAONE_3,
+    LLM_CHAT_TEMPLATE_EXAONE_4,
     LLM_CHAT_TEMPLATE_RWKV_WORLD,
     LLM_CHAT_TEMPLATE_GRANITE,
     LLM_CHAT_TEMPLATE_GIGACHAT,
@@ -44,6 +45,8 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_LLAMA4,
     LLM_CHAT_TEMPLATE_SMOLVLM,
     LLM_CHAT_TEMPLATE_DOTS1,
+    LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
+    LLM_CHAT_TEMPLATE_KIMI_K2,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };