@fugood/llama.node 1.0.3 → 1.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. package/package.json +14 -14
  2. package/src/llama.cpp/common/CMakeLists.txt +4 -5
  3. package/src/llama.cpp/common/arg.cpp +37 -0
  4. package/src/llama.cpp/common/common.cpp +22 -6
  5. package/src/llama.cpp/common/common.h +14 -1
  6. package/src/llama.cpp/ggml/CMakeLists.txt +3 -0
  7. package/src/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  8. package/src/llama.cpp/ggml/include/ggml.h +13 -0
  9. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -0
  10. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  11. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +23 -8
  12. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +3 -0
  13. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +39 -0
  14. package/src/llama.cpp/include/llama.h +13 -48
  15. package/src/llama.cpp/src/llama-arch.cpp +222 -15
  16. package/src/llama.cpp/src/llama-arch.h +16 -1
  17. package/src/llama.cpp/src/llama-batch.cpp +76 -70
  18. package/src/llama.cpp/src/llama-batch.h +24 -18
  19. package/src/llama.cpp/src/llama-chat.cpp +44 -1
  20. package/src/llama.cpp/src/llama-chat.h +2 -0
  21. package/src/llama.cpp/src/llama-context.cpp +134 -95
  22. package/src/llama.cpp/src/llama-context.h +13 -16
  23. package/src/llama.cpp/src/llama-cparams.h +3 -2
  24. package/src/llama.cpp/src/llama-graph.cpp +239 -154
  25. package/src/llama.cpp/src/llama-graph.h +162 -126
  26. package/src/llama.cpp/src/llama-hparams.cpp +45 -0
  27. package/src/llama.cpp/src/llama-hparams.h +11 -1
  28. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +11 -5
  29. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +3 -0
  30. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +698 -302
  31. package/src/llama.cpp/src/llama-kv-cache-unified.h +89 -31
  32. package/src/llama.cpp/src/llama-memory-hybrid.cpp +1 -0
  33. package/src/llama.cpp/src/llama-memory-recurrent.cpp +6 -9
  34. package/src/llama.cpp/src/llama-model.cpp +2309 -665
  35. package/src/llama.cpp/src/llama-model.h +18 -4
  36. package/src/llama.cpp/src/llama-quant.cpp +2 -2
  37. package/src/llama.cpp/src/llama-vocab.cpp +368 -9
  38. package/src/llama.cpp/src/llama-vocab.h +43 -0
  39. package/src/llama.cpp/src/unicode.cpp +207 -0
  40. package/src/llama.cpp/src/unicode.h +2 -0
--- a/package/src/llama.cpp/src/llama-batch.cpp
+++ b/package/src/llama.cpp/src/llama-batch.cpp
@@ -27,6 +27,7 @@ bool llama_batch_allocr::init(
         const llama_vocab & vocab,
         const llama_memory_i * memory,
         uint32_t n_embd,
+        uint32_t n_seq_max,
         bool output_all) {
     clear();
 
@@ -40,6 +41,11 @@ bool llama_batch_allocr::init(
     // validate input batch
     //
 
+    if (n_seq_max > LLAMA_MAX_SEQ) {
+        LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
+        return false;
+    }
+
     if (batch.token) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
@@ -52,8 +58,8 @@ bool llama_batch_allocr::init(
     if (batch.seq_id) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
-                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
+                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
                     return false;
                 }
             }
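
The two hunks above move batch validation from the compile-time `LLAMA_MAX_SEQ` cap to a per-context `n_seq_max` limit (judging from the `llama-context.cpp` changes elsewhere in this diff, the value threaded in here is `cparams.n_seq_max`). A minimal caller-side sketch of the new behavior using the public llama.h API; the model path and token values are hypothetical:

    // Sketch only: with n_seq_max = 2 on the context, a token tagged with
    // seq_id >= 2 should now fail validation inside llama_decode(), even
    // though the compile-time LLAMA_MAX_SEQ cap is far larger.
    #include "llama.h"

    int main() {
        llama_model_params   mparams = llama_model_default_params();
        llama_context_params cparams = llama_context_default_params();
        cparams.n_seq_max = 2; // this context accepts seq ids 0 and 1 only

        llama_model   * model = llama_model_load_from_file("model.gguf", mparams); // hypothetical path
        llama_context * ctx   = llama_init_from_model(model, cparams);

        llama_batch batch = llama_batch_init(/*n_tokens*/ 1, /*embd*/ 0, /*n_seq_max*/ 1);
        batch.n_tokens     = 1;
        batch.token[0]     = 1;   // some valid token id
        batch.pos[0]       = 0;
        batch.n_seq_id[0]  = 1;
        batch.seq_id[0][0] = 2;   // >= cparams.n_seq_max -> expect llama_decode to fail
        batch.logits[0]    = 1;

        const int32_t ret = llama_decode(ctx, batch); // expected: non-zero error
        (void) ret;

        llama_batch_free(batch);
        llama_free(ctx);
        llama_model_free(model);
    }
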
@@ -86,7 +92,7 @@ bool llama_batch_allocr::init(
 
     // initialize the starting position for each sequence based on the positions in the memory
     llama_pos p0[LLAMA_MAX_SEQ];
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (!memory) {
             // if no memory -> start from 0
             p0[s] = 0;
@@ -143,13 +149,16 @@ bool llama_batch_allocr::init(
     // compute stats
     //
 
-    this->n_embd = n_embd;
+    this->n_embd    = n_embd;
+    this->n_seq_max = n_seq_max;
 
     // count the outputs in this batch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         n_outputs += batch.logits[i] != 0;
     }
 
+    has_cpl = false;
+
     // determine coupled sequences
     // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -189,7 +198,7 @@ bool llama_batch_allocr::init(
         seq_set_map[cur].push_back(i);
     }
 
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_set_unq.test(s)) {
             seq_idx[s] = seq_id_unq.size();
             seq_id_unq.push_back(s);
@@ -201,7 +210,7 @@ bool llama_batch_allocr::init(
         LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
 
         llama_ubatch ubatch {
-            /*.equal_seqs   =*/ false,
+            /*.b_equal_seqs =*/ false,
             /*.n_tokens     =*/ (uint32_t) batch.n_tokens,
             /*.n_seq_tokens =*/ (uint32_t) 1,
             /*.n_seqs       =*/ (uint32_t) batch.n_tokens,
@@ -214,6 +223,7 @@ bool llama_batch_allocr::init(
             /*.seq_id_unq   =*/ this->seq_id_unq.data(),
             /*.seq_idx      =*/ this->seq_idx.data(),
             /*.output       =*/ batch.logits,
+            /*.data         =*/ {},
         };
 
         ubatch_print(ubatch, debug);
@@ -241,7 +251,7 @@ bool llama_batch_allocr::init(
     // consistency checks
     //
 
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_pos[s].empty()) {
             continue;
         }
@@ -284,8 +294,8 @@ bool llama_batch_allocr::init(
     }
 
     if (memory) {
-        for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
-            for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
+        for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
+            for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
                 if (seq_cpl[s0][s1]) {
                     if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
                         memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
@@ -316,12 +326,12 @@ bool llama_batch_allocr::init(
     //
     {
         seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
            cur_seq_set[s].set();
        }
 
        llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
-        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
            cur_seq_pos[s] = -1;
        }
 
@@ -357,39 +367,38 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
     clear();
     split_reset();
 
-    ubatches.emplace_back();
+    auto udata = std::make_shared<llama_ubatch::data_t>();
 
-    auto & ubatch = ubatches.back();
-
-    ubatch.token     .resize(n_tokens);
-    ubatch.embd      .clear();
-    ubatch.pos       .resize(n_tokens);
-    ubatch.n_seq_id  .resize(n_tokens);
-    ubatch.seq_id    .resize(n_tokens);
-    ubatch.seq_id_unq.resize(0);
-    ubatch.seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    ubatch.output    .resize(n_tokens);
+    udata->token     .resize(n_tokens);
+    udata->embd      .clear();
+    udata->pos       .resize(n_tokens);
+    udata->n_seq_id  .resize(n_tokens);
+    udata->seq_id    .resize(n_tokens);
+    udata->seq_id_unq.resize(0);
+    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
+    udata->output    .resize(n_tokens);
 
     for (uint32_t s = 0; s < n_seqs; ++s) {
-        ubatch.seq_idx[s] = s;
-        ubatch.seq_id_unq.push_back(s);
+        udata->seq_idx[s] = s;
+        udata->seq_id_unq.push_back(s);
     }
 
     llama_ubatch res {
-        /*.equal_seqs   =*/ true,
+        /*.b_equal_seqs =*/ true,
         /*.n_tokens     =*/ n_tokens,
         /*.n_seq_tokens =*/ n_seq_tokens,
         /*.n_seqs       =*/ n_seqs,
         /*.n_seqs_unq   =*/ n_seqs,
 
-        /*.token        =*/ ubatch.token.data(),
+        /*.token        =*/ udata->token.data(),
         /*.embd         =*/ nullptr,
-        /*.pos          =*/ ubatch.pos.data(),
-        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
-        /*.seq_id       =*/ ubatch.seq_id.data(),
-        /*.seq_id_unq   =*/ ubatch.seq_id_unq.data(),
-        /*.seq_idx      =*/ ubatch.seq_idx.data(),
-        /*.output       =*/ ubatch.output.data(),
+        /*.pos          =*/ udata->pos.data(),
+        /*.n_seq_id     =*/ udata->n_seq_id.data(),
+        /*.seq_id       =*/ udata->seq_id.data(),
+        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
+        /*.seq_idx      =*/ udata->seq_idx.data(),
+        /*.output       =*/ udata->output.data(),
+        /*.data         =*/ std::move(udata),
     };
 
     return res;
@@ -430,8 +439,6 @@ void llama_batch_allocr::split_reset() {
 
     used.clear();
     used.resize(get_n_tokens(), false);
-
-    ubatches.clear();
 }
 
 llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
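
The thread running through the hunks above is an ownership change: a `llama_ubatch` now carries a `std::shared_ptr` to its own backing storage (`llama_ubatch::data_t`) instead of pointing into a `ubatches` vector owned by the allocator, which is why `split_reset()` no longer calls `ubatches.clear()`. A ubatch handed to a caller stays valid regardless of the allocator's subsequent splitting state. A minimal sketch of the pattern with hypothetical names (`view_t`, `make_view`), not the actual llama.cpp types:

    // Self-owning view pattern (sketch; names are hypothetical).
    // Raw pointers give cheap access on the hot path; the shared_ptr keeps
    // the payload alive for as long as any copy of the view exists.
    #include <cstdio>
    #include <memory>
    #include <vector>

    struct view_t {
        const int * token; // points into *data

        struct payload_t {
            std::vector<int> token;
        };
        std::shared_ptr<payload_t> data; // owns the storage `token` refers to
    };

    static view_t make_view(std::vector<int> tokens) {
        auto payload = std::make_shared<view_t::payload_t>();
        payload->token = std::move(tokens);
        return view_t {
            /*.token =*/ payload->token.data(), // evaluated before the move below
            /*.data  =*/ std::move(payload),
        };
    }

    int main() {
        view_t v = make_view({1, 2, 3});
        // the producer can now discard all of its internal state;
        // v.token remains valid because v.data keeps the payload alive
        std::printf("%d\n", v.token[1]); // prints 2
    }

The raw pointers keep reads free of `shared_ptr` dereferences, while the single `shared_ptr` per ubatch makes lifetime management automatic.
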
@@ -646,78 +653,77 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
 
     assert(n_tokens%n_seqs == 0);
 
-    ubatches.emplace_back();
-
-    auto & ubatch = ubatches.back();
+    auto udata = std::make_shared<llama_ubatch::data_t>();
 
     const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
 
     const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
     const int64_t n_pos_all  =              (int64_t) n_tokens*n_pos_cur;
 
-    ubatch.token     .resize(n_tokens);
-    ubatch.embd      .resize(n_embd_all);
-    ubatch.pos       .resize(n_pos_all);
-    ubatch.n_seq_id  .resize(n_tokens);
-    ubatch.seq_id    .resize(n_tokens);
-    ubatch.seq_id_unq.resize(0);
-    ubatch.seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    ubatch.output    .resize(n_tokens);
+    udata->token     .resize(n_tokens);
+    udata->embd      .resize(n_embd_all);
+    udata->pos       .resize(n_pos_all);
+    udata->n_seq_id  .resize(n_tokens);
+    udata->seq_id    .resize(n_tokens);
+    udata->seq_id_unq.resize(0);
+    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
+    udata->output    .resize(n_tokens);
 
     seq_set_t seq_set_unq;
 
     for (size_t i = 0; i < idxs.size(); ++i) {
         if (batch.token) {
-            ubatch.token[i] = batch.token[idxs[i]];
+            udata->token[i] = batch.token[idxs[i]];
         }
 
         if (batch.embd) {
-            memcpy(ubatch.embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
+            memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
         }
 
         for (int j = 0; j < n_pos_cur; ++j) {
-            ubatch.pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
+            udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
         }
 
-        ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]];
-        ubatch.seq_id[i]   = batch.seq_id[idxs[i]];
-        ubatch.output[i]   = batch.logits[idxs[i]];
+        udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
+        udata->seq_id[i]   = batch.seq_id[idxs[i]];
+        udata->output[i]   = batch.logits[idxs[i]];
 
-        for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
-            seq_set_unq.set(ubatch.seq_id[i][s]);
+        for (int s = 0; s < udata->n_seq_id[i]; ++s) {
+            seq_set_unq.set(udata->seq_id[i][s]);
         }
 
-        if (ubatch.output[i]) {
+        if (udata->output[i]) {
             out_ids.push_back(idxs[i]);
         }
     }
 
-    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < n_seq_max; ++s) {
         if (seq_set_unq.test(s)) {
-            ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
-            ubatch.seq_id_unq.push_back(s);
+            udata->seq_idx[s] = udata->seq_id_unq.size();
+            udata->seq_id_unq.push_back(s);
         }
     }
 
     llama_ubatch res {
-        /*.equal_seqs   =*/ equal_seqs,
+        /*.b_equal_seqs =*/ equal_seqs,
         /*.n_tokens     =*/ n_tokens,
         /*.n_seq_tokens =*/ n_tokens/n_seqs,
         /*.n_seqs       =*/ n_seqs,
-        /*.n_seqs_unq   =*/ (uint32_t) ubatch.seq_id_unq.size(),
-
-        /*.token        =*/ batch.token ? ubatch.token.data() : nullptr,
-        /*.embd         =*/ batch.embd ? ubatch.embd.data() : nullptr,
-        /*.pos          =*/ ubatch.pos.data(),
-        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
-        /*.seq_id       =*/ ubatch.seq_id.data(),
-        /*.seq_id_unq   =*/ ubatch.seq_id_unq.data(),
-        /*.seq_idx      =*/ ubatch.seq_idx.data(),
-        /*.output       =*/ ubatch.output.data(),
+        /*.n_seqs_unq   =*/ (uint32_t) udata->seq_id_unq.size(),
+
+        /*.token        =*/ batch.token ? udata->token.data() : nullptr,
+        /*.embd         =*/ batch.embd ? udata->embd.data() : nullptr,
+        /*.pos          =*/ udata->pos.data(),
+        /*.n_seq_id     =*/ udata->n_seq_id.data(),
+        /*.seq_id       =*/ udata->seq_id.data(),
+        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
+        /*.seq_idx      =*/ udata->seq_idx.data(),
+        /*.output       =*/ udata->output.data(),
+        /*.data         =*/ std::move(udata),
     };
 
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: added ubatch %d to split:\n", __func__, (int) ubatches.size() - 1);
+        LLAMA_LOG_DEBUG("%s: added ubatch to split:\n", __func__);
 
         ubatch_print(res, debug);
     }
@@ -727,7 +733,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
 
 void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: equal_seqs   = %d\n", __func__, ubatch.equal_seqs);
+        LLAMA_LOG_DEBUG("%s: equal_seqs   = %d\n", __func__, ubatch.equal_seqs());
         LLAMA_LOG_DEBUG("%s: n_tokens     = %d\n", __func__, ubatch.n_tokens);
         LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
         LLAMA_LOG_DEBUG("%s: n_seqs       = %d\n", __func__, ubatch.n_seqs);
--- a/package/src/llama.cpp/src/llama-batch.h
+++ b/package/src/llama.cpp/src/llama-batch.h
@@ -8,12 +8,17 @@
 #include <vector>
 #include <set>
 #include <bitset>
+#include <memory>
 #include <unordered_map>
 
 // keep this struct lightweight
-// it points to data in `llama_batch_allocr`
 struct llama_ubatch {
-    bool equal_seqs;
+    bool equal_seqs() const {
+        return b_equal_seqs != 0;
+    }
+
+    uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
+                           //       otherwise address sanitizer complains
     // TODO: whole_seqs for embeddings?
 
     uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
@@ -34,6 +39,20 @@ struct llama_ubatch {
     llama_seq_id * seq_id_unq; // [n_seqs_unq]    | s | seq_id
     int32_t      * seq_idx;    // [LLAMA_MAX_SEQ] | - | seq_idx
     int8_t       * output;     // [n_tokens]      | i | -
+
+    struct data_t {
+        std::vector<llama_token>    token;
+        std::vector<float>          embd;
+        std::vector<llama_pos>      pos;
+        std::vector<int32_t>        n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<llama_seq_id>   seq_id_unq;
+        std::vector<int32_t>        seq_idx;
+        std::vector<int8_t>         output;
+    };
+
+    // the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data
+    std::shared_ptr<data_t> data;
 };
 
 // a helper for sanitizing, fulfilling and splitting a batch
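
Two consequences of this header change ripple through the rest of the diff (llama-graph.cpp, llama-kv-cache-unified.cpp, and the other consumers in the files-changed list): `equal_seqs` becomes an accessor call, and copying a `llama_ubatch` now shares ownership of the buffers instead of aliasing allocator state. A hedged caller-side illustration; `process` is a made-up function, not from the diff:

    // Sketch of the caller-side change (the `process` function is hypothetical).
    // #include "llama-batch.h" // assumed; brings in llama_ubatch

    void process(const llama_ubatch & ubatch) {
        // before 1.0.4:  if (ubatch.equal_seqs) { ... }
        if (ubatch.equal_seqs()) {
            // all sequences in this ubatch have the same number of tokens
        }

        // copying is cheap and safe: the copy shares ownership of data_t,
        // so the raw pointers inside `kept` stay valid after the allocator
        // moves on to the next split
        llama_ubatch kept = ubatch;
        (void) kept;
    }
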
@@ -48,6 +67,7 @@ public:
             const llama_vocab & vocab,
             const llama_memory_i * memory,
             uint32_t n_embd,
+            uint32_t n_seq_max,
             bool output_all);
 
     const llama_batch & get_batch() const;
@@ -100,6 +120,7 @@ private:
     const uint32_t n_pos_per_embd;
 
     uint32_t n_embd;
+    uint32_t n_seq_max;
     uint32_t n_outputs;
 
     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
@@ -115,7 +136,7 @@ private:
     using seq_cpl_t = std::vector<bool>;
 
     // helper flag to quickly determine if there are any coupled sequences in the batch
-    bool has_cpl;
+    bool has_cpl = false;
 
     std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
     std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
@@ -135,20 +156,5 @@ private:
     // used[i] indicates if token i has already been used in a previous ubatch
     std::vector<bool> used;
 
-    // llama_ubatch points to this data:
-    struct ubatch {
-        std::vector<llama_token>    token;
-        std::vector<float>          embd;
-        std::vector<llama_pos>      pos;
-        std::vector<int32_t>        n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
-        std::vector<llama_seq_id>   seq_id_unq;
-        std::vector<int32_t>        seq_idx;
-        std::vector<int8_t>         output;
-    };
-
-    // current splitting state:
-    std::vector<ubatch> ubatches;
-
     int debug;
 };
--- a/package/src/llama.cpp/src/llama-chat.cpp
+++ b/package/src/llama.cpp/src/llama-chat.cpp
@@ -56,6 +56,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "glmedge",     LLM_CHAT_TEMPLATE_GLMEDGE },
     { "minicpm",     LLM_CHAT_TEMPLATE_MINICPM },
     { "exaone3",     LLM_CHAT_TEMPLATE_EXAONE_3 },
+    { "exaone4",     LLM_CHAT_TEMPLATE_EXAONE_4 },
     { "rwkv-world",  LLM_CHAT_TEMPLATE_RWKV_WORLD },
     { "granite",     LLM_CHAT_TEMPLATE_GRANITE },
     { "gigachat",    LLM_CHAT_TEMPLATE_GIGACHAT },
@@ -65,6 +66,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "llama4",      LLM_CHAT_TEMPLATE_LLAMA4 },
     { "smolvlm",     LLM_CHAT_TEMPLATE_SMOLVLM },
     { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
+    { "kimi-k2",     LLM_CHAT_TEMPLATE_KIMI_K2 },
 };
 
 llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -167,10 +169,13 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
     } else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) {
         return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
     } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
+        if (tmpl_contains("[|tool|]")) {
+            return LLM_CHAT_TEMPLATE_EXAONE_4;
+        }
         // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
         // EXAONE-3.0-7.8B-Instruct
         return LLM_CHAT_TEMPLATE_EXAONE_3;
-    } else if (tmpl_contains("rwkv-world")) {
+    } else if (tmpl_contains("rwkv-world") || tmpl_contains("{{- 'User: ' + message['content']|trim + '\\n\\n' -}}")) {
         return LLM_CHAT_TEMPLATE_RWKV_WORLD;
     } else if (tmpl_contains("<|start_of_role|>")) {
         return LLM_CHAT_TEMPLATE_GRANITE;
@@ -188,6 +193,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_DOTS1;
     } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
         return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
+    } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
+        return LLM_CHAT_TEMPLATE_KIMI_K2;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -529,6 +536,22 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "[|assistant|]";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_4) {
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n";
+            } else if (role == "user") {
+                ss << "[|user|]" << trim(message->content) << "\n";
+            } else if (role == "assistant") {
+                ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n";
+            } else if (role == "tool") {
+                ss << "[|tool|]" << trim(message->content) << "[|endofturn|]\n";
+            }
+        }
+        if (add_ass) {
+            ss << "[|assistant|]";
+        }
     } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
         // this template requires the model to have "\n\n" as EOT token
         for (size_t i = 0; i < chat.size(); i++) {
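
For reference, here is roughly what the new EXAONE 4.0 branch renders when driven through the public `llama_chat_apply_template` entry point (which resolves the "exaone4" name via the `LLM_CHAT_TEMPLATES` map above). The conversation is invented for illustration:

    // Hedged sketch: rendering a chat with the new "exaone4" template.
    #include <cstdio>
    #include <vector>
    #include "llama.h"

    int main() {
        const std::vector<llama_chat_message> chat = {
            { "system",    "You are helpful." },
            { "user",      "Hi."              },
            { "assistant", "Hello!"           },
        };

        char buf[1024];
        const int32_t n = llama_chat_apply_template(
            "exaone4", chat.data(), chat.size(), /*add_ass*/ true, buf, sizeof(buf));
        if (n > 0 && n < (int32_t) sizeof(buf)) {
            std::printf("%.*s\n", n, buf);
        }
        // Expected output, reading the branch above:
        // [|system|]You are helpful.[|endofturn|]
        // [|user|]Hi.
        // [|assistant|]Hello![|endofturn|]
        // [|assistant|]
    }
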
@@ -680,6 +703,26 @@ int32_t llm_chat_apply_template(
                 ss << "<|startoftext|>" << message->content << "<|extra_0|>";
             }
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) {
+        // moonshotai/Kimi-K2-Instruct
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "<|im_system|>system<|im_middle|>";
+            } else if (role == "user") {
+                ss << "<|im_user|>user<|im_middle|>";
+            } else if (role == "assistant") {
+                ss << "<|im_assistant|>assistant<|im_middle|>";
+            } else if (role == "tool") {
+                ss << "<|im_system|>tool<|im_middle|>";
+            }
+
+            ss << message->content << "<|im_end|>";
+
+            if (add_ass) {
+                ss << "<|im_assistant|>assistant<|im_middle|>";
+            }
+        }
     } else {
         // template not supported
         return -1;
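
A matching sketch for the Kimi-K2 branch. Note that in this version of the code the `add_ass` check sits inside the per-message loop, so the assistant header is appended after every message rather than once at the end; with a single message the two placements coincide. The message content is invented:

    // Hedged sketch: what the Kimi-K2 branch above produces for a
    // one-message chat with add_ass = true.
    #include <cstdio>
    #include "llama.h"

    int main() {
        const llama_chat_message chat[] = {
            { "user", "Hi." },
        };

        char buf[512];
        const int32_t n = llama_chat_apply_template(
            "kimi-k2", chat, 1, /*add_ass*/ true, buf, sizeof(buf));
        if (n > 0 && n < (int32_t) sizeof(buf)) {
            std::printf("%.*s\n", n, buf);
        }
        // Expected, reading the branch above:
        // <|im_user|>user<|im_middle|>Hi.<|im_end|><|im_assistant|>assistant<|im_middle|>
    }
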
--- a/package/src/llama.cpp/src/llama-chat.h
+++ b/package/src/llama.cpp/src/llama-chat.h
@@ -35,6 +35,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_GLMEDGE,
     LLM_CHAT_TEMPLATE_MINICPM,
     LLM_CHAT_TEMPLATE_EXAONE_3,
+    LLM_CHAT_TEMPLATE_EXAONE_4,
     LLM_CHAT_TEMPLATE_RWKV_WORLD,
     LLM_CHAT_TEMPLATE_GRANITE,
     LLM_CHAT_TEMPLATE_GIGACHAT,
@@ -45,6 +46,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_SMOLVLM,
     LLM_CHAT_TEMPLATE_DOTS1,
     LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
+    LLM_CHAT_TEMPLATE_KIMI_K2,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };