llama_cpp 0.7.1 → 0.8.0

@@ -75,6 +75,7 @@
 #include <thread>
 #include <unordered_map>
 #include <set>
+#include <forward_list>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -1178,6 +1179,8 @@ struct llama_vocab {
     std::unordered_map<token, id> token_to_id;
     std::vector<token_data> id_to_token;
 
+    std::unordered_map<token, id> special_tokens_cache;
+
     std::map<std::pair<std::string, std::string>, int> bpe_ranks;
 
     // default LLaMA special tokens
@@ -1442,7 +1445,10 @@ static bool llama_kv_cache_find_slot(
 
     for (uint32_t i = 0; i < n_tokens; i++) {
         cache.cells[cache.head + i].pos = batch.pos[i];
-        cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i]);
+
+        for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
+            cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
+        }
     }
 
     return true;
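The new inner loop lets a single KV cell belong to several sequences at once. A minimal, self-contained sketch of that bookkeeping, using hypothetical stand-in types rather than llama.cpp's own:

#include <set>
#include <vector>

// Hypothetical stand-ins for the cache cell and the per-token batch fields.
struct kv_cell { int pos = -1; std::set<int> seq_id; };

int main() {
    std::vector<kv_cell> cells(4);

    // One token at position 0, shared by sequences 0 and 1 (n_seq_id == 2).
    const int pos[1]        = { 0 };
    const int n_seq_id[1]   = { 2 };
    const int seq_ids[1][2] = { { 0, 1 } };

    for (int i = 0; i < 1; i++) {
        cells[i].pos = pos[i];
        for (int j = 0; j < n_seq_id[i]; j++) {
            cells[i].seq_id.insert(seq_ids[i][j]);
        }
    }
    // cells[0].seq_id now holds {0, 1}: both sequences reuse the same KV entry.
}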
@@ -1522,6 +1528,9 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
             cache.cells[i].pos = -1;
             cache.cells[i].seq_id.clear();
             if (new_head == cache.size) new_head = i;
+        } else {
+            cache.cells[i].seq_id.clear();
+            cache.cells[i].seq_id.insert(seq_id);
         }
     }
 
@@ -2120,7 +2129,7 @@ static void llm_load_hparams(
 }
 
 // TODO: This should probably be in llama.h
-static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos);
+static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special = false);
 static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch);
 
 static void llm_load_vocab(
@@ -2236,6 +2245,101 @@ static void llm_load_vocab(
     GGUF_GET_KEY(ctx, vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID));
     GGUF_GET_KEY(ctx, vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID));
     GGUF_GET_KEY(ctx, vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID));
+
+    // build special tokens cache
+    {
+        // TODO: It is unclear (to me) at this point, whether special tokens are guaranteed to be of a deterministic type,
+        //  and will always be correctly labeled in 'added_tokens.json', etc.
+        // The assumption is, since special tokens aren't meant to be exposed to the end user, they are designed
+        //  to be unmatchable by the tokenizer; therefore tokens from the vocab which are unmatchable by the tokenizer
+        //  are special tokens.
+        // From testing, this appears to correlate 1:1 with special tokens.
+        //
+
+        // Counting special tokens and verifying in only one direction
+        //  is sufficient to detect a difference in those two sets.
+        //
+        uint32_t special_tokens_count_by_type = 0;
+        uint32_t special_tokens_count_from_verification = 0;
+
+        bool special_tokens_definition_mismatch = false;
+
+        for (const auto & t : vocab.token_to_id) {
+            const auto & token = t.first;
+            const auto & id    = t.second;
+
+            // Count all non-normal tokens in the vocab while iterating
+            if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) {
+                special_tokens_count_by_type++;
+            }
+
+            // Skip single character tokens
+            if (token.length() > 1) {
+                bool is_tokenizable = false;
+
+                // Split the token string representation in two, in all possible ways,
+                //  and check if both halves can be matched to a valid token
+                for (unsigned i = 1; i < token.length();) {
+                    const auto left  = token.substr(0, i);
+                    const auto right = token.substr(i);
+
+                    // check that we didn't partition in the middle of a utf sequence
+                    auto utf = utf8_len(left.at(left.length() - 1));
+
+                    if (utf == 1) {
+                        if (vocab.token_to_id.find(left)  != vocab.token_to_id.end() &&
+                            vocab.token_to_id.find(right) != vocab.token_to_id.end()) {
+                            is_tokenizable = true;
+                            break;
+                        }
+                        i++;
+                    } else {
+                        // skip over the rest of the multibyte utf sequence
+                        i += utf - 1;
+                    }
+                }
+
+                if (!is_tokenizable) {
+                    // Some tokens are multibyte, but they are utf sequences with an equivalent text length of 1;
+                    //  it's faster to re-filter them here, since there are far fewer candidates now
+
+                    // Calculate the total "utf" length of a token string representation
+                    size_t utf8_str_len = 0;
+                    for (unsigned i = 0; i < token.length();) {
+                        utf8_str_len++;
+                        i += utf8_len(token.at(i));
+                    }
+
+                    // And skip the ones which are one character
+                    if (utf8_str_len > 1) {
+                        // At this point what we have left are special tokens only
+                        vocab.special_tokens_cache[token] = id;
+
+                        // Count manually found special tokens
+                        special_tokens_count_from_verification++;
+
+                        // If this manually found special token is not marked as such, flag a mismatch
+                        if (vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL) {
+                            special_tokens_definition_mismatch = true;
+                        }
+                    }
+                }
+            }
+        }
+
+        if (special_tokens_definition_mismatch || special_tokens_count_from_verification != special_tokens_count_by_type) {
+            LLAMA_LOG_WARN("%s: mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n",
+                __func__,
+                special_tokens_count_from_verification, vocab.id_to_token.size(),
+                special_tokens_count_by_type, vocab.id_to_token.size()
+            );
+        } else {
+            LLAMA_LOG_INFO("%s: special tokens definition check successful ( %u/%zu ).\n",
+                __func__,
+                special_tokens_count_from_verification, vocab.id_to_token.size()
+            );
+        }
+    }
 }
 
 static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
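For intuition, the detection above reduces to the following standalone sketch: a token longer than one character that cannot be split into two strings that are both in the vocab is treated as special. The toy vocab, the ASCII-only simplification, and the is_tokenizable name here are illustrative only; the real code above also handles multibyte UTF-8 and re-filters single-codepoint tokens.

#include <string>
#include <unordered_map>

// Toy check: can `token` be split into two substrings that are both vocab entries?
static bool is_tokenizable(const std::string & token,
                           const std::unordered_map<std::string, int> & vocab) {
    for (size_t i = 1; i < token.length(); i++) {
        if (vocab.count(token.substr(0, i)) && vocab.count(token.substr(i))) {
            return true;
        }
    }
    return false;
}

// With a vocab of { "hel", "lo", "<", ">" }, "hello" splits into "hel" + "lo" and is
// tokenizable, while "<|endoftext|>" admits no such split and would be cached as special.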
@@ -2834,8 +2938,8 @@ static void llm_load_tensors(
                     auto & layer = model.layers[i];
 
                     layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
-                    layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3*n_embd}, backend_split);
-                    layer.wo   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
+                    layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
+                    layer.wo   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
 
                     layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
 
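The new QKV shape reflects grouped-query attention: the K and V projections only have n_head_kv heads, so each contributes n_embd_gqa = (n_embd / n_head) * n_head_kv columns instead of n_embd. A worked example with illustrative hyperparameters, not taken from any particular model:

// Illustrative numbers only.
const int n_embd     = 4096;
const int n_head     = 32;
const int n_head_kv  = 8;
const int n_embd_gqa = n_embd / n_head * n_head_kv;   // 128 * 8 = 1024
const int wqkv_cols  = n_embd + 2*n_embd_gqa;         // 4096 + 2048 = 6144
// With n_head_kv == n_head (no GQA/MQA), n_embd_gqa == n_embd and the
// expression reduces to the old 3*n_embd.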
@@ -3075,7 +3179,7 @@ static struct ggml_cgraph * llm_build_llama(
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
                 const llama_pos pos = batch.pos[j];
-                const llama_seq_id seq_id = batch.seq_id[j];
+                const llama_seq_id seq_id = batch.seq_id[j][0];
 
                 for (int i = 0; i < n_kv; ++i) {
                     if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -3461,7 +3565,7 @@ static struct ggml_cgraph * llm_build_baichaun(
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
                 const llama_pos pos = batch.pos[j];
-                const llama_seq_id seq_id = batch.seq_id[j];
+                const llama_seq_id seq_id = batch.seq_id[j][0];
 
                 for (int i = 0; i < n_kv; ++i) {
                     if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -3860,7 +3964,7 @@ static struct ggml_cgraph * llm_build_refact(
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
                 const llama_pos pos = batch.pos[j];
-                const llama_seq_id seq_id = batch.seq_id[j];
+                const llama_seq_id seq_id = batch.seq_id[j][0];
 
                 for (int i = 0; i < n_kv; ++i) {
                     if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -4212,7 +4316,7 @@ static struct ggml_cgraph * llm_build_falcon(
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
                 const llama_pos pos = batch.pos[j];
-                const llama_seq_id seq_id = batch.seq_id[j];
+                const llama_seq_id seq_id = batch.seq_id[j][0];
 
                 for (int i = 0; i < n_kv; ++i) {
                     if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -4564,7 +4668,7 @@ static struct ggml_cgraph * llm_build_starcoder(
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
                 const llama_pos pos = batch.pos[j];
-                const llama_seq_id seq_id = batch.seq_id[j];
+                const llama_seq_id seq_id = batch.seq_id[j][0];
 
                 for (int i = 0; i < n_kv; ++i) {
                     if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -4795,7 +4899,7 @@ static struct ggml_cgraph * llm_build_persimmon(
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
                 const llama_pos pos = batch.pos[j];
-                const llama_seq_id seq_id = batch.seq_id[j];
+                const llama_seq_id seq_id = batch.seq_id[j][0];
                 for (int i = 0; i < n_kv; ++i) {
                     if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
                         data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
@@ -5193,7 +5297,7 @@ static struct ggml_cgraph * llm_build_bloom(
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
                 const llama_pos pos = batch.pos[j];
-                const llama_seq_id seq_id = batch.seq_id[j];
+                const llama_seq_id seq_id = batch.seq_id[j][0];
 
                 for (int i = 0; i < n_kv; ++i) {
                     if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -5363,7 +5467,7 @@ static struct ggml_cgraph * llm_build_mpt(
     const int64_t n_layer = hparams.n_layer;
     const int64_t n_ctx = cparams.n_ctx;
     const int64_t n_head = hparams.n_head;
-    const int64_t n_head_kv = hparams.n_head_kv; // == n_head for MPT, as there's no MQA/GQA
+    const int64_t n_head_kv = hparams.n_head_kv;
     const int64_t n_embd_head = hparams.n_embd_head();
     const int64_t n_embd_gqa = hparams.n_embd_gqa();
 
@@ -5461,7 +5565,7 @@ static struct ggml_cgraph * llm_build_mpt(
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
                 const llama_pos pos = batch.pos[j];
-                const llama_seq_id seq_id = batch.seq_id[j];
+                const llama_seq_id seq_id = batch.seq_id[j][0];
 
                 for (int i = 0; i < n_kv; ++i) {
                     if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -5761,8 +5865,11 @@ static int llama_decode_internal(
 
     // helpers for smoother batch API transition
     // after deprecating the llama_eval calls, these will be removed
-    std::vector<llama_pos> pos;
-    std::vector<llama_seq_id> seq_id;
+    std::vector<llama_pos> pos;
+
+    std::vector<int32_t>                   n_seq_id;
+    std::vector<llama_seq_id *>            seq_id_arr;
+    std::vector<std::vector<llama_seq_id>> seq_id;
 
     if (batch.pos == nullptr) {
         pos.resize(n_tokens);
@@ -5774,12 +5881,18 @@ static int llama_decode_internal(
     }
 
     if (batch.seq_id == nullptr) {
+        n_seq_id.resize(n_tokens);
         seq_id.resize(n_tokens);
+        seq_id_arr.resize(n_tokens);
         for (uint32_t i = 0; i < n_tokens; i++) {
-            seq_id[i] = batch.all_seq_id;
+            n_seq_id[i] = 1;
+            seq_id[i].resize(1);
+            seq_id[i][0] = batch.all_seq_id;
+            seq_id_arr[i] = seq_id[i].data();
         }
 
-        batch.seq_id = seq_id.data();
+        batch.n_seq_id = n_seq_id.data();
+        batch.seq_id = seq_id_arr.data();
     }
 
     if (!llama_kv_cache_find_slot(kv_self, batch)) {
@@ -5800,6 +5913,13 @@ static int llama_decode_internal(
 
     ggml_allocr_alloc_graph(lctx.alloc, gf);
 
+    struct ggml_tensor * res        = gf->nodes[gf->n_nodes - 1];
+    struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
+
+    GGML_ASSERT(strcmp(res->name, "result_output") == 0);
+    GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
+
+
 #ifdef GGML_USE_CUBLAS
     for (int i = 0; i < gf->n_leafs; i++) {
         ggml_tensor * node = gf->leafs[i];
@@ -5817,6 +5937,12 @@ static int llama_decode_internal(
     }
 
     ggml_cuda_set_mul_mat_q(cparams.mul_mat_q);
+
+    // HACK: ggml-alloc may change the tensor backend when reusing a parent, so force output to be on the CPU here if needed
+    if (!lctx.embedding.empty()) {
+        embeddings->backend = GGML_BACKEND_CPU;
+    }
+    res->backend = GGML_BACKEND_CPU;
 #endif
 
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
@@ -5841,12 +5967,6 @@ static int llama_decode_internal(
         n_threads = 1;
     }
 
-    struct ggml_tensor * res        = gf->nodes[gf->n_nodes - 1];
-    struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
-
-    GGML_ASSERT(strcmp(res->name, "result_output") == 0);
-    GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
-
 #if GGML_USE_MPI
     const int64_t n_layer = hparams.n_layer;
     ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer);
@@ -6199,7 +6319,6 @@ struct llm_tokenizer_bpe {
             llm_symbol sym;
             size_t char_len = std::min(word.size() - offset, (size_t) ::utf8_len(word[offset]));
             sym.text = word.c_str() + offset;
-            sym.n = 1;
             sym.n = char_len;
             offset += sym.n;
             sym.prev = index - 1;
@@ -6459,7 +6578,137 @@ private:
     llm_bigram_bpe::queue work_queue;
 };
 
-static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos) {
+typedef enum FRAGMENT_BUFFER_VARIANT_TYPE{
+    FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN,
+    FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT
+} FRAGMENT_BUFFER_VARIANT_TYPE;
+
+struct fragment_buffer_variant{
+    fragment_buffer_variant(llama_vocab::id _token)
+    :
+        type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN),
+        token(_token),
+        raw_text(_dummy),
+        offset(0),
+        length(0){}
+    fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length)
+    :
+        type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT),
+        token((llama_vocab::id)-1),
+        raw_text(_raw_text),
+        offset(_offset),
+        length(_length){
+            GGML_ASSERT( _offset >= 0 );
+            GGML_ASSERT( _length >= 1 );
+            GGML_ASSERT( offset + length <= raw_text.length() );
+        }
+
+    const FRAGMENT_BUFFER_VARIANT_TYPE type;
+    const llama_vocab::id token;
+    const std::string _dummy;
+    const std::string & raw_text;
+    const uint64_t offset;
+    const uint64_t length;
+};
+
+// #define PRETOKENIZERDEBUG
+
+static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer)
+{
+    // for each special token
+    for (const auto & st: vocab.special_tokens_cache) {
+        const auto & special_token = st.first;
+        const auto & special_id = st.second;
+
+        // for each text fragment
+        std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
+        while (it != buffer.end()) {
+            auto & fragment = (*it);
+
+            // if a fragment is text ( not yet processed )
+            if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                auto * raw_text = &(fragment.raw_text);
+
+                auto raw_text_base_offset = fragment.offset;
+                auto raw_text_base_length = fragment.length;
+
+                // loop over the text
+                while (true) {
+                    // find the first occurrence of a given special token in this fragment
+                    //  passing the offset argument only limits the "search area", but match coordinates
+                    //  are still relative to the source full raw_text
+                    auto match = raw_text->find(special_token, raw_text_base_offset);
+
+                    // no occurrences found, stop processing this fragment for a given special token
+                    if (match == std::string::npos) break;
+
+                    // check if match is within bounds of offset <-> length
+                    if (match + special_token.length() > raw_text_base_offset + raw_text_base_length) break;
+
+#ifdef PRETOKENIZERDEBUG
+                    fprintf(stderr, "FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
+#endif
+                    auto source = std::distance(buffer.begin(), it);
+
+                    // if match is further than base offset
+                    //  then we have some text to the left of it
+                    if (match > raw_text_base_offset) {
+                        // left
+                        const int64_t left_reminder_offset = raw_text_base_offset + 0;
+                        const int64_t left_reminder_length = match - raw_text_base_offset;
+                        buffer.emplace_after(it, (*raw_text), left_reminder_offset, left_reminder_length);
+
+#ifdef PRETOKENIZERDEBUG
+                        fprintf(stderr, "FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
+#endif
+                        it++;
+                    }
+
+                    // special token
+                    buffer.emplace_after(it, special_id);
+                    it++;
+
+                    // right
+                    if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
+                        const int64_t right_reminder_offset = match + special_token.length();
+                        const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
+                        buffer.emplace_after(it, (*raw_text), right_reminder_offset, right_reminder_length);
+
+#ifdef PRETOKENIZERDEBUG
+                        fprintf(stderr, "FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
+#endif
+
+                        it++;
+
+                        if (source == 0) {
+                            buffer.erase_after(buffer.before_begin());
+                        } else {
+                            buffer.erase_after(std::next(buffer.begin(), (source-1)));
+                        }
+
+                        // repeat for the right side
+                        raw_text_base_offset = right_reminder_offset;
+                        raw_text_base_length = right_reminder_length;
+
+#ifdef PRETOKENIZERDEBUG
+                        fprintf(stderr, "RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
+#endif
+                    } else {
+                        if (source == 0) {
+                            buffer.erase_after(buffer.before_begin());
+                        } else {
+                            buffer.erase_after(std::next(buffer.begin(), (source-1)));
+                        }
+                        break;
+                    }
+                }
+            }
+            it++;
+        }
+    }
+}
+
+static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special) {
     std::vector<llama_vocab::id> output;
 
     // OG tokenizer behavior:
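To make the effect of tokenizer_st_partition concrete, here is a deliberately simplified, self-contained sketch of the same idea (hypothetical types, a single special token, no offset bookkeeping or forward_list splicing): raw text is cut around each match, and the match itself is emitted as an already-resolved token id.

#include <string>
#include <vector>

// Hypothetical simplified fragment: either raw text or a resolved token id.
struct toy_fragment { bool is_token; int token; std::string text; };

static std::vector<toy_fragment> toy_partition(const std::string & text,
                                               const std::string & special, int special_id) {
    std::vector<toy_fragment> out;
    size_t pos = 0;
    while (true) {
        const size_t match = text.find(special, pos);
        if (match == std::string::npos) break;
        if (match > pos) out.push_back({ false, -1, text.substr(pos, match - pos) });
        out.push_back({ true, special_id, "" });
        pos = match + special.length();
    }
    if (pos < text.length()) out.push_back({ false, -1, text.substr(pos) });
    return out;
}

// toy_partition("<|user|>Hello", "<|user|>", 32001) yields the (made-up) token id 32001
// followed by the raw text "Hello"; only the raw-text fragment is then run through SPM/BPE.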
@@ -6475,20 +6724,58 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
         return output;
     }
 
+    std::forward_list<fragment_buffer_variant> fragment_buffer;
+    fragment_buffer.emplace_front( raw_text, 0, raw_text.length() );
+
+    if (special) tokenizer_st_partition( vocab, fragment_buffer );
+
     switch (vocab.type) {
         case LLAMA_VOCAB_TYPE_SPM:
             {
-                // without adding this leading whitespace, we do not get the same results as the original tokenizer
-                raw_text = " " + raw_text;
+                for (const auto & fragment: fragment_buffer)
+                {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT)
+                    {
+                        // without adding this leading whitespace, we do not get the same results as the original tokenizer
 
-                llm_tokenizer_spm tokenizer(vocab);
-                llama_escape_whitespace(raw_text);
-                tokenizer.tokenize(raw_text, output);
+                        // TODO: It's likely possible to get rid of this string copy entirely
+                        //  by modifying llm_tokenizer_x to operate with string offsets like pre-tokenizer
+                        //  and passing 'add space prefix' as bool argument
+                        //
+                        auto raw_text = (special ? "" : " ") + fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        fprintf(stderr,"TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+#endif
+                        llm_tokenizer_spm tokenizer(vocab);
+                        llama_escape_whitespace(raw_text);
+                        tokenizer.tokenize(raw_text, output);
+                    }
+                    else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                    {
+                        output.push_back(fragment.token);
+                    }
+                }
             } break;
         case LLAMA_VOCAB_TYPE_BPE:
             {
-                llm_tokenizer_bpe tokenizer(vocab);
-                tokenizer.tokenize(raw_text, output);
+                for (const auto & fragment: fragment_buffer)
+                {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT)
+                    {
+                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        fprintf(stderr,"TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+#endif
+                        llm_tokenizer_bpe tokenizer(vocab);
+                        tokenizer.tokenize(raw_text, output);
+                    }
+                    else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                    {
+                        output.push_back(fragment.token);
+                    }
+                }
             } break;
     }
 
@@ -6761,7 +7048,7 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
     std::vector<llama_grammar_candidate> rejects;
 
     if (stack.empty()) {
-        for (auto tok : candidates) {
+        for (const auto & tok : candidates) {
             if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
                 rejects.push_back(tok);
             }
@@ -6772,7 +7059,7 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
     const llama_grammar_element * stack_pos = stack.back();
 
     std::vector<llama_grammar_candidate> next_candidates;
-    for (auto tok : candidates) {
+    for (const auto & tok : candidates) {
         if (*tok.code_points == 0) {
             // reached end of full codepoints in token, reject iff it ended in a partial sequence
             // that cannot satisfy this position in grammar
@@ -6798,7 +7085,7 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
         llama_grammar_advance_stack(rules, stack_after, next_stacks);
 
         auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
-        for (auto tok : next_rejects) {
+        for (const auto & tok : next_rejects) {
             rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
         }
 
@@ -8831,6 +9118,9 @@ void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
 }
 
 void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    if (seq_id_src == seq_id_dst) {
+        return;
+    }
     llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
 }
 
@@ -9283,7 +9573,7 @@ int llama_eval_embd(
                              int   n_past) {
     llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
 
-    llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, };
+    llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
 
     const int ret = llama_decode_internal(*ctx, batch);
     if (ret < 0) {
@@ -9304,20 +9594,21 @@ struct llama_batch llama_batch_get_one(
                                  llama_pos pos_0,
                               llama_seq_id seq_id) {
     return {
-        /*n_tokens =*/ n_tokens,
-        /*tokens =*/ tokens,
-        /*embd =*/ nullptr,
-        /*pos =*/ nullptr,
-        /*seq_id =*/ nullptr,
-        /*logits =*/ nullptr,
-        /*all_pos_0 =*/ pos_0,
-        /*all_pos_1 =*/ 1,
-        /*all_seq_id =*/ seq_id,
+        /*n_tokens   =*/ n_tokens,
+        /*tokens     =*/ tokens,
+        /*embd       =*/ nullptr,
+        /*pos        =*/ nullptr,
+        /*n_seq_id   =*/ nullptr,
+        /*seq_id     =*/ nullptr,
+        /*logits     =*/ nullptr,
+        /*all_pos_0  =*/ pos_0,
+        /*all_pos_1  =*/ 1,
+        /*all_seq_id =*/ seq_id,
     };
 }
 
-struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) {
-    llama_batch batch = { -1, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
+struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max) {
+    llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
 
     if (embd) {
         batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd);
@@ -9325,19 +9616,29 @@ struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) {
         batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
     }
 
-    batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
-    batch.seq_id = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_tokens);
-    batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
+    batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens);
+    batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens);
+    batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
+    for (int i = 0; i < n_tokens; ++i) {
+        batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
+    }
+    batch.logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens);
 
     return batch;
 }
 
 void llama_batch_free(struct llama_batch batch) {
-    if (batch.token) free(batch.token);
-    if (batch.embd) free(batch.embd);
-    if (batch.pos) free(batch.pos);
-    if (batch.seq_id) free(batch.seq_id);
-    if (batch.logits) free(batch.logits);
+    if (batch.token)    free(batch.token);
+    if (batch.embd)     free(batch.embd);
+    if (batch.pos)      free(batch.pos);
+    if (batch.n_seq_id) free(batch.n_seq_id);
+    if (batch.seq_id) {
+        for (int i = 0; i < batch.n_tokens; ++i) {
+            free(batch.seq_id[i]);
+        }
+        free(batch.seq_id);
+    }
+    if (batch.logits)   free(batch.logits);
 }
 
 int llama_decode(
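A hedged usage sketch of the updated batch API as it now appears in this diff: reserve room for up to two sequence ids per token, submit one prompt token shared by sequences 0 and 1, and release the per-token arrays again (all concrete values are illustrative):

llama_batch batch = llama_batch_init(/*n_tokens=*/512, /*embd=*/0, /*n_seq_max=*/2);

batch.n_tokens     = 1;
batch.token[0]     = 1;     // illustrative token id
batch.pos[0]       = 0;
batch.n_seq_id[0]  = 2;     // this token belongs to two sequences
batch.seq_id[0][0] = 0;
batch.seq_id[0][1] = 1;
batch.logits[0]    = 1;

// ... llama_decode(ctx, batch); ...

llama_batch_free(batch);    // also frees each per-token seq_id[i] array allocated above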
@@ -9402,15 +9703,15 @@ llama_token llama_token_eot(const struct llama_context * ctx) {
     return ctx->model.vocab.special_eot_id;
 }
 
-
 int llama_tokenize(
     const struct llama_model * model,
                   const char * text,
                          int   text_len,
                  llama_token * tokens,
                          int   n_max_tokens,
-                        bool   add_bos) {
-    auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos);
+                        bool   add_bos,
+                        bool   special) {
+    auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, special);
 
     if (n_max_tokens < (int) res.size()) {
         // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
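Finally, a usage sketch of the extended public tokenizer entry point with the new special flag enabled, so markers in the prompt are matched through the special tokens cache instead of being split into plain text pieces. It assumes an already loaded llama_model * model; the prompt and buffer size are illustrative.

#include <cstring>
#include <vector>

const char * prompt = "<|user|>Hello";

std::vector<llama_token> tokens(64);
int n = llama_tokenize(model, prompt, (int) strlen(prompt),
                       tokens.data(), (int) tokens.size(),
                       /*add_bos=*/true, /*special=*/true);
if (n >= 0) {
    tokens.resize(n);  // with special=false the marker would be tokenized as ordinary text
}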