llama_cpp 0.7.1 → 0.9.0

@@ -75,6 +75,7 @@
  #include <thread>
  #include <unordered_map>
  #include <set>
+ #include <forward_list>

  #if defined(_MSC_VER)
  #pragma warning(disable: 4244 4267) // possible loss of data
@@ -969,14 +970,15 @@ static void llama_nop(struct ggml_tensor * tensor) { // don't offload by default
  (void) tensor;
  }

- static std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) {
+ static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
  std::vector<char> result(8, 0);
  const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
  if (n_tokens < 0) {
  result.resize(-n_tokens);
  int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
  GGML_ASSERT(check == -n_tokens);
- } else {
+ }
+ else {
  result.resize(n_tokens);
  }

@@ -1012,8 +1014,8 @@ enum e_model {
  };

  static const size_t kB = 1024;
- static const size_t MB = kB*kB;
- static const size_t GB = kB*kB*kB;
+ static const size_t MB = 1024*kB;
+ static const size_t GB = 1024*MB;

  struct llama_hparams {
  bool vocab_only;
@@ -1036,21 +1038,21 @@ struct llama_hparams {
  float f_max_alibi_bias;

  bool operator!=(const llama_hparams & other) const {
- if (this->vocab_only != other.vocab_only) return true;
- if (this->n_vocab != other.n_vocab) return true;
+ if (this->vocab_only != other.vocab_only) return true;
+ if (this->n_vocab != other.n_vocab) return true;
  if (this->n_ctx_train != other.n_ctx_train) return true;
- if (this->n_embd != other.n_embd) return true;
- if (this->n_head != other.n_head) return true;
- if (this->n_head_kv != other.n_head_kv) return true;
- if (this->n_layer != other.n_layer) return true;
- if (this->n_rot != other.n_rot) return true;
- if (this->n_ff != other.n_ff) return true;
+ if (this->n_embd != other.n_embd) return true;
+ if (this->n_head != other.n_head) return true;
+ if (this->n_head_kv != other.n_head_kv) return true;
+ if (this->n_layer != other.n_layer) return true;
+ if (this->n_rot != other.n_rot) return true;
+ if (this->n_ff != other.n_ff) return true;

  const float EPSILON = 1e-9;

- if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true;
- if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true;
- if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true;
+ if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true;
+ if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true;
+ if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true;
  if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true;

  return false;
@@ -1178,6 +1180,8 @@ struct llama_vocab {
  std::unordered_map<token, id> token_to_id;
  std::vector<token_data> id_to_token;

+ std::unordered_map<token, id> special_tokens_cache;
+
  std::map<std::pair<std::string, std::string>, int> bpe_ranks;

  // default LLaMA special tokens
@@ -1187,17 +1191,17 @@ struct llama_vocab {
  id special_sep_id = -1;
  id special_pad_id = -1;

- id linefeed_id = 13;
+ id linefeed_id = 13;
  id special_prefix_id = 32007;
  id special_middle_id = 32009;
  id special_suffix_id = 32008;
- id special_eot_id = 32010;
+ id special_eot_id = 32010;

  int find_bpe_rank(std::string token_left, std::string token_right) const {
- replace_all(token_left, " ", "\u0120");
- replace_all(token_left, "\n", "\u010A");
- replace_all(token_right, " ", "\u0120");
- replace_all(token_right, "\n", "\u010A");
+ GGML_ASSERT(token_left.find(" ") == std::string::npos);
+ GGML_ASSERT(token_left.find("\n") == std::string::npos);
+ GGML_ASSERT(token_right.find(" ") == std::string::npos);
+ GGML_ASSERT(token_right.find("\n") == std::string::npos);

  auto it = bpe_ranks.find(std::make_pair(token_left, token_right));
  if (it == bpe_ranks.end()) {
@@ -1351,10 +1355,7 @@ static bool llama_kv_cache_init(
  cache.cells.clear();
  cache.cells.resize(n_ctx);

- // TODO: this should be:
- // cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*ggml_tensor_overhead());
- // change it and test that it works
- cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
+ cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*ggml_tensor_overhead());
  memset(cache.buf.data, 0, cache.buf.size);

  struct ggml_init_params params;
@@ -1442,7 +1443,10 @@ static bool llama_kv_cache_find_slot(

  for (uint32_t i = 0; i < n_tokens; i++) {
  cache.cells[cache.head + i].pos = batch.pos[i];
- cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i]);
+
+ for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
+ cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
+ }
  }

  return true;
@@ -1522,6 +1526,9 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
  cache.cells[i].pos = -1;
  cache.cells[i].seq_id.clear();
  if (new_head == cache.size) new_head = i;
+ } else {
+ cache.cells[i].seq_id.clear();
+ cache.cells[i].seq_id.insert(seq_id);
  }
  }

@@ -2120,7 +2127,7 @@ static void llm_load_hparams(
  }

  // TODO: This should probably be in llama.h
- static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos);
+ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special = false);
  static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch);

  static void llm_load_vocab(
@@ -2227,15 +2234,130 @@ static void llm_load_vocab(
  if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
  vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
  } else {
- vocab.linefeed_id = llama_tokenize_internal(vocab, "\u010A", false)[0];
+ const std::vector<int> ids = llama_tokenize_internal(vocab, "\u010A", false);
+ GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
+ vocab.linefeed_id = ids[0];
  }

  // special tokens
- GGUF_GET_KEY(ctx, vocab.special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_BOS_ID));
- GGUF_GET_KEY(ctx, vocab.special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_EOS_ID));
- GGUF_GET_KEY(ctx, vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID));
- GGUF_GET_KEY(ctx, vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID));
- GGUF_GET_KEY(ctx, vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID));
+ {
+ const std::vector<std::pair<enum llm_kv, int32_t &>> special_token_types = {
+ { LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id },
+ { LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id },
+ { LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id },
+ { LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id },
+ { LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id },
+ };
+ for (const auto & it : special_token_types) {
+ const std::string & key = kv(std::get<0>(it));
+ int32_t & id = std::get<1>(it), old_id = id;
+
+ GGUF_GET_KEY(ctx, id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, key);
+ // Must be >= -1 and < vocab size. Since the key is unsigned, -1
+ // can only come from the default value, so there's no point in
+ // validating that.
+ if (size_t(id + 1) > vocab.id_to_token.size()) {
+ LLAMA_LOG_WARN("%s: bad special token: '%s' = %d, using default id %d\n",
+ __func__, key.c_str(), id, old_id);
+ id = old_id;
+ }
+ }
+ }
+
+ // build special tokens cache
+ {
+ // TODO: It is unclear (to me) at this point, whether special tokens are guaranteed to be of a deterministic type,
+ // and will always be correctly labeled in 'added_tokens.json' etc.
+ // The assumption is, since special tokens aren't meant to be exposed to end user, they are designed
+ // to be unmatchable by the tokenizer, therefore tokens from the vocab, which are unmatchable by the tokenizer
+ // are special tokens.
+ // From testing, this appears to correlate 1:1 with special tokens.
+ //
+
+ // Counting special tokens and verifying in only one direction
+ // is sufficient to detect difference in those two sets.
+ //
+ uint32_t special_tokens_count_by_type = 0;
+ uint32_t special_tokens_count_from_verification = 0;
+
+ bool special_tokens_definition_mismatch = false;
+
+ for (const auto & t : vocab.token_to_id) {
+ const auto & token = t.first;
+ const auto & id = t.second;
+
+ // Count all non-normal tokens in the vocab while iterating
+ if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) {
+ special_tokens_count_by_type++;
+ }
+
+ // Skip single character tokens
+ if (token.length() > 1) {
+ bool is_tokenizable = false;
+
+ // Split token string representation in two, in all possible ways
+ // and check if both halves can be matched to a valid token
+ for (unsigned i = 1; i < token.length();) {
+ const auto left = token.substr(0, i);
+ const auto right = token.substr(i);
+
+ // check if we didn't partition in the middle of a utf sequence
+ auto utf = utf8_len(left.at(left.length() - 1));
+
+ if (utf == 1) {
+ if (vocab.token_to_id.find(left) != vocab.token_to_id.end() &&
+ vocab.token_to_id.find(right) != vocab.token_to_id.end() ) {
+ is_tokenizable = true;
+ break;
+ }
+ i++;
+ } else {
+ // skip over the rest of multibyte utf sequence
+ i += utf - 1;
+ }
+ }
+
+ if (!is_tokenizable) {
+ // Some tokens are multibyte, but they are utf sequences with equivalent text length of 1
+ // it's faster to re-filter them here, since there are far fewer candidates now
+
+ // Calculate a total "utf" length of a token string representation
+ size_t utf8_str_len = 0;
+ for (unsigned i = 0; i < token.length();) {
+ utf8_str_len++;
+ i += utf8_len(token.at(i));
+ }
+
+ // And skip the ones which are one character
+ if (utf8_str_len > 1) {
+ // At this point what we have left are special tokens only
+ vocab.special_tokens_cache[token] = id;
+
+ // Count manually found special tokens
+ special_tokens_count_from_verification++;
+
+ // If this manually found special token is not marked as such, flag a mismatch
+ if (vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL) {
+ special_tokens_definition_mismatch = true;
+ }
+ }
+ }
+ }
+ }
+
+ if (special_tokens_definition_mismatch || special_tokens_count_from_verification != special_tokens_count_by_type) {
+ LLAMA_LOG_WARN("%s: mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n",
+ __func__,
+ special_tokens_count_from_verification, vocab.id_to_token.size(),
+ special_tokens_count_by_type, vocab.id_to_token.size()
+ );
+ } else {
+ LLAMA_LOG_INFO("%s: special tokens definition check successful ( %u/%zu ).\n",
+ __func__,
+ special_tokens_count_from_verification, vocab.id_to_token.size()
+ );
+ }
+ }
  }

  static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
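For illustration, a minimal, self-contained sketch of the split heuristic above, assuming a plain std::unordered_map<std::string, int> vocabulary; unlike the real code it ignores UTF-8 boundaries and single-character tokens:

#include <string>
#include <unordered_map>

// A token string is treated as "special" when no split into two halves
// maps onto two existing vocab entries, i.e. the tokenizer could never
// reassemble it from smaller pieces.
static bool is_unmatchable(const std::unordered_map<std::string, int> & token_to_id,
                           const std::string & token) {
    for (size_t i = 1; i < token.size(); ++i) {
        if (token_to_id.count(token.substr(0, i)) > 0 &&
            token_to_id.count(token.substr(i))    > 0) {
            return false;
        }
    }
    return true;
}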
@@ -2834,8 +2956,8 @@ static void llm_load_tensors(
  auto & layer = model.layers[i];

  layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
- layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3*n_embd}, backend_split);
- layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
+ layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
+ layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);

  layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);

@@ -3075,7 +3197,7 @@ static struct ggml_cgraph * llm_build_llama(
  for (int h = 0; h < 1; ++h) {
  for (int j = 0; j < n_tokens; ++j) {
  const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];

  for (int i = 0; i < n_kv; ++i) {
  if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -3461,7 +3583,7 @@ static struct ggml_cgraph * llm_build_baichaun(
  for (int h = 0; h < 1; ++h) {
  for (int j = 0; j < n_tokens; ++j) {
  const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];

  for (int i = 0; i < n_kv; ++i) {
  if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -3860,7 +3982,7 @@ static struct ggml_cgraph * llm_build_refact(
  for (int h = 0; h < 1; ++h) {
  for (int j = 0; j < n_tokens; ++j) {
  const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];

  for (int i = 0; i < n_kv; ++i) {
  if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -4212,7 +4334,7 @@ static struct ggml_cgraph * llm_build_falcon(
  for (int h = 0; h < 1; ++h) {
  for (int j = 0; j < n_tokens; ++j) {
  const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];

  for (int i = 0; i < n_kv; ++i) {
  if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -4564,7 +4686,7 @@ static struct ggml_cgraph * llm_build_starcoder(
  for (int h = 0; h < 1; ++h) {
  for (int j = 0; j < n_tokens; ++j) {
  const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];

  for (int i = 0; i < n_kv; ++i) {
  if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -4795,7 +4917,7 @@ static struct ggml_cgraph * llm_build_persimmon(
  for (int h = 0; h < 1; ++h) {
  for (int j = 0; j < n_tokens; ++j) {
  const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];
  for (int i = 0; i < n_kv; ++i) {
  if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
  data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
@@ -5193,7 +5315,7 @@ static struct ggml_cgraph * llm_build_bloom(
  for (int h = 0; h < 1; ++h) {
  for (int j = 0; j < n_tokens; ++j) {
  const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];

  for (int i = 0; i < n_kv; ++i) {
  if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -5363,7 +5485,7 @@ static struct ggml_cgraph * llm_build_mpt(
  const int64_t n_layer = hparams.n_layer;
  const int64_t n_ctx = cparams.n_ctx;
  const int64_t n_head = hparams.n_head;
- const int64_t n_head_kv = hparams.n_head_kv; // == n_head for MPT, as there's no MQA/GQA
+ const int64_t n_head_kv = hparams.n_head_kv;
  const int64_t n_embd_head = hparams.n_embd_head();
  const int64_t n_embd_gqa = hparams.n_embd_gqa();

@@ -5461,7 +5583,7 @@ static struct ggml_cgraph * llm_build_mpt(
  for (int h = 0; h < 1; ++h) {
  for (int j = 0; j < n_tokens; ++j) {
  const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];

  for (int i = 0; i < n_kv; ++i) {
  if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@@ -5761,8 +5883,11 @@ static int llama_decode_internal(

  // helpers for smoother batch API transistion
  // after deprecating the llama_eval calls, these will be removed
- std::vector<llama_pos> pos;
- std::vector<llama_seq_id> seq_id;
+ std::vector<llama_pos> pos;
+
+ std::vector<int32_t> n_seq_id;
+ std::vector<llama_seq_id *> seq_id_arr;
+ std::vector<std::vector<llama_seq_id>> seq_id;

  if (batch.pos == nullptr) {
  pos.resize(n_tokens);
@@ -5774,12 +5899,18 @@ static int llama_decode_internal(
  }

  if (batch.seq_id == nullptr) {
+ n_seq_id.resize(n_tokens);
  seq_id.resize(n_tokens);
+ seq_id_arr.resize(n_tokens);
  for (uint32_t i = 0; i < n_tokens; i++) {
- seq_id[i] = batch.all_seq_id;
+ n_seq_id[i] = 1;
+ seq_id[i].resize(1);
+ seq_id[i][0] = batch.all_seq_id;
+ seq_id_arr[i] = seq_id[i].data();
  }

- batch.seq_id = seq_id.data();
+ batch.n_seq_id = n_seq_id.data();
+ batch.seq_id = seq_id_arr.data();
  }

  if (!llama_kv_cache_find_slot(kv_self, batch)) {
@@ -5800,6 +5931,13 @@ static int llama_decode_internal(

  ggml_allocr_alloc_graph(lctx.alloc, gf);

+ struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
+ struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
+
+ GGML_ASSERT(strcmp(res->name, "result_output") == 0);
+ GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
+
+
  #ifdef GGML_USE_CUBLAS
  for (int i = 0; i < gf->n_leafs; i++) {
  ggml_tensor * node = gf->leafs[i];
@@ -5817,6 +5955,12 @@ static int llama_decode_internal(
  }

  ggml_cuda_set_mul_mat_q(cparams.mul_mat_q);
+
+ // HACK: ggml-alloc may change the tensor backend when reusing a parent, so force output to be on the CPU here if needed
+ if (!lctx.embedding.empty()) {
+ embeddings->backend = GGML_BACKEND_CPU;
+ }
+ res->backend = GGML_BACKEND_CPU;
  #endif

  // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
@@ -5841,12 +5985,6 @@ static int llama_decode_internal(
  n_threads = 1;
  }

- struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
- struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
-
- GGML_ASSERT(strcmp(res->name, "result_output") == 0);
- GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
-
  #if GGML_USE_MPI
  const int64_t n_layer = hparams.n_layer;
  ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer);
@@ -5981,11 +6119,10 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
  }

  static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
+ static const char * hex = "0123456789ABCDEF";
  switch (llama_vocab_get_type(vocab)) {
  case LLAMA_VOCAB_TYPE_SPM: {
- char buf[7];
- int result = snprintf(buf, sizeof(buf), "<0x%02X>", ch);
- GGML_ASSERT(0 <= result && result < 7);
+ const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
  return vocab.token_to_id.at(buf);
  }
  case LLAMA_VOCAB_TYPE_BPE: {
@@ -6199,7 +6336,6 @@ struct llm_tokenizer_bpe {
  llm_symbol sym;
  size_t char_len = std::min(word.size() - offset, (size_t) ::utf8_len(word[offset]));
  sym.text = word.c_str() + offset;
- sym.n = 1;
  sym.n = char_len;
  offset += sym.n;
  sym.prev = index - 1;
@@ -6459,7 +6595,137 @@ private:
  llm_bigram_bpe::queue work_queue;
  };

- static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos) {
+ typedef enum FRAGMENT_BUFFER_VARIANT_TYPE{
+ FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN,
+ FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT
+ } FRAGMENT_BUFFER_VARIANT_TYPE;
+
+ struct fragment_buffer_variant{
+ fragment_buffer_variant(llama_vocab::id _token)
+ :
+ type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN),
+ token(_token),
+ raw_text(_dummy),
+ offset(0),
+ length(0){}
+ fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length)
+ :
+ type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT),
+ token((llama_vocab::id)-1),
+ raw_text(_raw_text),
+ offset(_offset),
+ length(_length){
+ GGML_ASSERT( _offset >= 0 );
+ GGML_ASSERT( _length >= 1 );
+ GGML_ASSERT( offset + length <= raw_text.length() );
+ }
+
+ const FRAGMENT_BUFFER_VARIANT_TYPE type;
+ const llama_vocab::id token;
+ const std::string _dummy;
+ const std::string & raw_text;
+ const uint64_t offset;
+ const uint64_t length;
+ };
+
+ // #define PRETOKENIZERDEBUG
+
+ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer)
+ {
+ // for each special token
+ for (const auto & st: vocab.special_tokens_cache) {
+ const auto & special_token = st.first;
+ const auto & special_id = st.second;
+
+ // for each text fragment
+ std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
+ while (it != buffer.end()) {
+ auto & fragment = (*it);
+
+ // if a fragment is text ( not yet processed )
+ if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+ auto * raw_text = &(fragment.raw_text);
+
+ auto raw_text_base_offset = fragment.offset;
+ auto raw_text_base_length = fragment.length;
+
+ // loop over the text
+ while (true) {
+ // find the first occurrence of a given special token in this fragment
+ // passing the offset argument only limits the "search area", but match coordinates
+ // are still relative to the source full raw_text
+ auto match = raw_text->find(special_token, raw_text_base_offset);
+
+ // no occurrences found, stop processing this fragment for a given special token
+ if (match == std::string::npos) break;
+
+ // check if match is within bounds of offset <-> length
+ if (match + special_token.length() > raw_text_base_offset + raw_text_base_length) break;
+
+ #ifdef PRETOKENIZERDEBUG
+ fprintf(stderr, "FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
+ #endif
+ auto source = std::distance(buffer.begin(), it);
+
+ // if match is further than base offset
+ // then we have some text to the left of it
+ if (match > raw_text_base_offset) {
+ // left
+ const int64_t left_reminder_offset = raw_text_base_offset + 0;
+ const int64_t left_reminder_length = match - raw_text_base_offset;
+ buffer.emplace_after(it, (*raw_text), left_reminder_offset, left_reminder_length);
+
+ #ifdef PRETOKENIZERDEBUG
+ fprintf(stderr, "FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
+ #endif
+ it++;
+ }
+
+ // special token
+ buffer.emplace_after(it, special_id);
+ it++;
+
+ // right
+ if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
+ const int64_t right_reminder_offset = match + special_token.length();
+ const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
+ buffer.emplace_after(it, (*raw_text), right_reminder_offset, right_reminder_length);
+
+ #ifdef PRETOKENIZERDEBUG
+ fprintf(stderr, "FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
+ #endif
+
+ it++;
+
+ if (source == 0) {
+ buffer.erase_after(buffer.before_begin());
+ } else {
+ buffer.erase_after(std::next(buffer.begin(), (source-1)));
+ }
+
+ // repeat for the right side
+ raw_text_base_offset = right_reminder_offset;
+ raw_text_base_length = right_reminder_length;
+
+ #ifdef PRETOKENIZERDEBUG
+ fprintf(stderr, "RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
+ #endif
+ } else {
+ if (source == 0) {
+ buffer.erase_after(buffer.before_begin());
+ } else {
+ buffer.erase_after(std::next(buffer.begin(), (source-1)));
+ }
+ break;
+ }
+ }
+ }
+ it++;
+ }
+ }
+ }
+
+ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special) {
  std::vector<llama_vocab::id> output;

  // OG tokenizer behavior:
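As a hedged, standalone illustration of the partitioning above: with a hypothetical special token "<EOT>", the raw text "foo<EOT>bar" is split into a RAW_TEXT fragment "foo", a TOKEN fragment carrying the id of "<EOT>", and a RAW_TEXT fragment "bar"; only the RAW_TEXT pieces are later passed to the SPM/BPE tokenizer.

#include <cstdio>
#include <string>

int main() {
    const std::string text = "foo<EOT>bar";
    const std::string st   = "<EOT>"; // stands in for an entry of special_tokens_cache
    const size_t match = text.find(st);
    if (match != std::string::npos) {
        std::printf("RAW('%s') TOKEN('%s') RAW('%s')\n",
                    text.substr(0, match).c_str(),
                    st.c_str(),
                    text.substr(match + st.size()).c_str());
    }
    return 0;
}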
@@ -6475,20 +6741,58 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
  return output;
  }

+ std::forward_list<fragment_buffer_variant> fragment_buffer;
+ fragment_buffer.emplace_front( raw_text, 0, raw_text.length() );
+
+ if (special) tokenizer_st_partition( vocab, fragment_buffer );
+
  switch (vocab.type) {
  case LLAMA_VOCAB_TYPE_SPM:
  {
- // without adding this leading whitespace, we do not get the same results as the original tokenizer
- raw_text = " " + raw_text;
+ for (const auto & fragment: fragment_buffer)
+ {
+ if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT)
+ {
+ // without adding this leading whitespace, we do not get the same results as the original tokenizer
+
+ // TODO: It's likely possible to get rid of this string copy entirely
+ // by modifying llm_tokenizer_x to operate with string offsets like pre-tokenizer
+ // and passing 'add space prefix' as bool argument
+ //
+ auto raw_text = (special ? "" : " ") + fragment.raw_text.substr(fragment.offset, fragment.length);

- llm_tokenizer_spm tokenizer(vocab);
- llama_escape_whitespace(raw_text);
- tokenizer.tokenize(raw_text, output);
+ #ifdef PRETOKENIZERDEBUG
+ fprintf(stderr,"TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+ #endif
+ llm_tokenizer_spm tokenizer(vocab);
+ llama_escape_whitespace(raw_text);
+ tokenizer.tokenize(raw_text, output);
+ }
+ else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+ {
+ output.push_back(fragment.token);
+ }
+ }
  } break;
  case LLAMA_VOCAB_TYPE_BPE:
  {
- llm_tokenizer_bpe tokenizer(vocab);
- tokenizer.tokenize(raw_text, output);
+ for (const auto & fragment: fragment_buffer)
+ {
+ if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT)
+ {
+ auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+ #ifdef PRETOKENIZERDEBUG
+ fprintf(stderr,"TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+ #endif
+ llm_tokenizer_bpe tokenizer(vocab);
+ tokenizer.tokenize(raw_text, output);
+ }
+ else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+ {
+ output.push_back(fragment.token);
+ }
+ }
  } break;
  }

@@ -6761,7 +7065,7 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
  std::vector<llama_grammar_candidate> rejects;

  if (stack.empty()) {
- for (auto tok : candidates) {
+ for (const auto & tok : candidates) {
  if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
  rejects.push_back(tok);
  }
@@ -6772,7 +7076,7 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
  const llama_grammar_element * stack_pos = stack.back();

  std::vector<llama_grammar_candidate> next_candidates;
- for (auto tok : candidates) {
+ for (const auto & tok : candidates) {
  if (*tok.code_points == 0) {
  // reached end of full codepoints in token, reject iff it ended in a partial sequence
  // that cannot satisfy this position in grammar
@@ -6798,7 +7102,7 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
  llama_grammar_advance_stack(rules, stack_after, next_stacks);

  auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
- for (auto tok : next_rejects) {
+ for (const auto & tok : next_rejects) {
  rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
  }

@@ -7125,37 +7429,15 @@ void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array
  llama_sample_temp(ctx, candidates_p, temp);
  }

- void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty) {
- if (last_tokens_size == 0 || penalty == 1.0f) {
- return;
- }
-
- const int64_t t_start_sample_us = ggml_time_us();
-
- for (size_t i = 0; i < candidates->size; ++i) {
- const auto * token_iter = std::find(last_tokens, last_tokens + last_tokens_size, candidates->data[i].id);
- if (token_iter == last_tokens + last_tokens_size) {
- continue;
- }
-
- // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
- // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
- if (candidates->data[i].logit <= 0) {
- candidates->data[i].logit *= penalty;
- } else {
- candidates->data[i].logit /= penalty;
- }
- }
-
- candidates->sorted = false;
-
- if (ctx) {
- ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
- }
- }
-
- void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) {
- if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) {
+ void llama_sample_repetition_penalties(
+ struct llama_context * ctx,
+ llama_token_data_array * candidates,
+ const llama_token * last_tokens,
+ size_t penalty_last_n,
+ float penalty_repeat,
+ float penalty_freq,
+ float penalty_present) {
+ if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
  return;
  }

@@ -7163,19 +7445,28 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l

  // Create a frequency map to count occurrences of each token in last_tokens
  std::unordered_map<llama_token, int> token_count;
- for (size_t i = 0; i < last_tokens_size; ++i) {
- token_count[last_tokens_p[i]]++;
+ for (size_t i = 0; i < penalty_last_n; ++i) {
+ token_count[last_tokens[i]]++;
  }

  // Apply frequency and presence penalties to the candidates
  for (size_t i = 0; i < candidates->size; ++i) {
- auto token_iter = token_count.find(candidates->data[i].id);
+ const auto token_iter = token_count.find(candidates->data[i].id);
  if (token_iter == token_count.end()) {
  continue;
  }

- int count = token_iter->second;
- candidates->data[i].logit -= float(count) * alpha_frequency + float(count > 0) * alpha_presence;
+ const int count = token_iter->second;
+
+ // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
+ // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
+ if (candidates->data[i].logit <= 0) {
+ candidates->data[i].logit *= penalty_repeat;
+ } else {
+ candidates->data[i].logit /= penalty_repeat;
+ }
+
+ candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
  }

  candidates->sorted = false;
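A minimal usage sketch of the merged penalty API above, assuming an existing context, a populated llama_token_data_array and a history of recent tokens; all names below are illustrative:

#include "llama.h"
#include <algorithm>
#include <vector>

static void apply_penalties(llama_context * ctx,
                            llama_token_data_array * candidates,
                            const std::vector<llama_token> & history) {
    // consider at most the last 64 generated tokens
    const size_t penalty_last_n = std::min<size_t>(64, history.size());
    llama_sample_repetition_penalties(
        ctx, candidates,
        history.data() + history.size() - penalty_last_n,
        penalty_last_n,
        /*penalty_repeat  =*/ 1.1f,
        /*penalty_freq    =*/ 0.0f,
        /*penalty_present =*/ 0.0f);
}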
@@ -7197,14 +7488,14 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
  }
  }

- const llama_token eos = llama_token_eos(ctx);
+ const llama_token eos = llama_token_eos(&ctx->model);

  std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
  std::vector<llama_grammar_candidate> candidates_grammar;

  for (size_t i = 0; i < candidates->size; ++i) {
  const llama_token id = candidates->data[i].id;
- const std::string piece = llama_token_to_str(ctx, id);
+ const std::string piece = llama_token_to_piece(ctx, id);
  if (id == eos) {
  if (!allow_eos) {
  candidates->data[i].logit = -INFINITY;
@@ -7407,7 +7698,7 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
  void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
  const int64_t t_start_sample_us = ggml_time_us();

- if (token == llama_token_eos(ctx)) {
+ if (token == llama_token_eos(&ctx->model)) {
  for (const auto & stack : grammar->stacks) {
  if (stack.empty()) {
  return;
@@ -7416,7 +7707,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
  GGML_ASSERT(false);
  }

- const std::string piece = llama_token_to_str(ctx, token);
+ const std::string piece = llama_token_to_piece(ctx, token);

  // Note terminating 0 in decoded string
  const auto decoded = decode_utf8(piece.c_str(), grammar->partial_utf8);
@@ -8616,7 +8907,7 @@ struct llama_context * llama_new_context_with_model(
  // build worst-case graph
  int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch);
  int n_past = cparams.n_ctx - n_tokens;
- llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+ llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
  ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0));

  #ifdef GGML_USE_METAL
@@ -8831,6 +9122,9 @@ void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llam
  }

  void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+ if (seq_id_src == seq_id_dst) {
+ return;
+ }
  llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
  }

@@ -9283,7 +9577,7 @@ int llama_eval_embd(
  int n_past) {
  llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);

- llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, };
+ llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };

  const int ret = llama_decode_internal(*ctx, batch);
  if (ret < 0) {
@@ -9304,20 +9598,21 @@ struct llama_batch llama_batch_get_one(
  llama_pos pos_0,
  llama_seq_id seq_id) {
  return {
- /*n_tokens =*/ n_tokens,
- /*tokens =*/ tokens,
- /*embd =*/ nullptr,
- /*pos =*/ nullptr,
- /*seq_id =*/ nullptr,
- /*logits =*/ nullptr,
- /*all_pos_0 =*/ pos_0,
- /*all_pos_1 =*/ 1,
- /*all_seq_id =*/ seq_id,
+ /*n_tokens =*/ n_tokens,
+ /*tokens =*/ tokens,
+ /*embd =*/ nullptr,
+ /*pos =*/ nullptr,
+ /*n_seq_id =*/ nullptr,
+ /*seq_id =*/ nullptr,
+ /*logits =*/ nullptr,
+ /*all_pos_0 =*/ pos_0,
+ /*all_pos_1 =*/ 1,
+ /*all_seq_id =*/ seq_id,
  };
  }

- struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) {
- llama_batch batch = { -1, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
+ struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max) {
+ llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };

  if (embd) {
  batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd);
@@ -9325,19 +9620,29 @@ struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) {
  batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
  }

- batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
- batch.seq_id = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_tokens);
- batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
+ batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
+ batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
+ batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
+ for (int i = 0; i < n_tokens; ++i) {
+ batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
+ }
+ batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);

  return batch;
  }

  void llama_batch_free(struct llama_batch batch) {
- if (batch.token) free(batch.token);
- if (batch.embd) free(batch.embd);
- if (batch.pos) free(batch.pos);
- if (batch.seq_id) free(batch.seq_id);
- if (batch.logits) free(batch.logits);
+ if (batch.token) free(batch.token);
+ if (batch.embd) free(batch.embd);
+ if (batch.pos) free(batch.pos);
+ if (batch.n_seq_id) free(batch.n_seq_id);
+ if (batch.seq_id) {
+ for (int i = 0; i < batch.n_tokens; ++i) {
+ free(batch.seq_id[i]);
+ }
+ free(batch.seq_id);
+ }
+ if (batch.logits) free(batch.logits);
  }

  int llama_decode(
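A hedged sketch of the new batch lifecycle shown above (llama_batch_init now takes n_seq_max, and seq_id is a per-token array); the helper name and sizes are illustrative:

#include "llama.h"
#include <vector>

// Build a batch where every token belongs to sequence 0 and only the last
// token requests logits; release it later with llama_batch_free(batch).
static llama_batch make_example_batch(const std::vector<llama_token> & tokens) {
    const int32_t n = (int32_t) tokens.size();
    llama_batch batch = llama_batch_init(n, /*embd =*/ 0, /*n_seq_max =*/ 1);
    batch.n_tokens = n;
    for (int32_t i = 0; i < n; ++i) {
        batch.token[i]     = tokens[i];
        batch.pos[i]       = i;
        batch.n_seq_id[i]  = 1;
        batch.seq_id[i][0] = 0;
        batch.logits[i]    = (i == n - 1);
    }
    return batch;
}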
@@ -9363,45 +9668,45 @@ float * llama_get_embeddings(struct llama_context * ctx) {
  return ctx->embedding.data();
  }

- const char * llama_token_get_text(const struct llama_context * ctx, llama_token token) {
- return ctx->model.vocab.id_to_token[token].text.c_str();
+ const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
+ return model->vocab.id_to_token[token].text.c_str();
  }

- float llama_token_get_score(const struct llama_context * ctx, llama_token token) {
- return ctx->model.vocab.id_to_token[token].score;
+ float llama_token_get_score(const struct llama_model * model, llama_token token) {
+ return model->vocab.id_to_token[token].score;
  }

- llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token) {
- return ctx->model.vocab.id_to_token[token].type;
+ llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token) {
+ return model->vocab.id_to_token[token].type;
  }

- llama_token llama_token_bos(const struct llama_context * ctx) {
- return ctx->model.vocab.special_bos_id;
+ llama_token llama_token_bos(const struct llama_model * model) {
+ return model->vocab.special_bos_id;
  }

- llama_token llama_token_eos(const struct llama_context * ctx) {
- return ctx->model.vocab.special_eos_id;
+ llama_token llama_token_eos(const struct llama_model * model) {
+ return model->vocab.special_eos_id;
  }

- llama_token llama_token_nl(const struct llama_context * ctx) {
- return ctx->model.vocab.linefeed_id;
- }
- llama_token llama_token_prefix(const struct llama_context * ctx) {
- return ctx->model.vocab.special_prefix_id;
+ llama_token llama_token_nl(const struct llama_model * model) {
+ return model->vocab.linefeed_id;
  }

- llama_token llama_token_middle(const struct llama_context * ctx) {
- return ctx->model.vocab.special_middle_id;
+ llama_token llama_token_prefix(const struct llama_model * model) {
+ return model->vocab.special_prefix_id;
  }

- llama_token llama_token_suffix(const struct llama_context * ctx) {
- return ctx->model.vocab.special_suffix_id;
+ llama_token llama_token_middle(const struct llama_model * model) {
+ return model->vocab.special_middle_id;
  }

- llama_token llama_token_eot(const struct llama_context * ctx) {
- return ctx->model.vocab.special_eot_id;
+ llama_token llama_token_suffix(const struct llama_model * model) {
+ return model->vocab.special_suffix_id;
  }

+ llama_token llama_token_eot(const struct llama_model * model) {
+ return model->vocab.special_eot_id;
+ }

  int llama_tokenize(
  const struct llama_model * model,
@@ -9409,8 +9714,9 @@ int llama_tokenize(
  int text_len,
  llama_token * tokens,
  int n_max_tokens,
- bool add_bos) {
- auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos);
+ bool add_bos,
+ bool special) {
+ auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, special);

  if (n_max_tokens < (int) res.size()) {
  // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
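A small usage sketch of the extended llama_tokenize shown above; setting special to true lets substrings that match registered special tokens collapse to their single ids (buffer sizing and the helper name are illustrative):

#include "llama.h"
#include <algorithm>
#include <string>
#include <vector>

static std::vector<llama_token> tokenize(const llama_model * model,
                                         const std::string & text,
                                         bool add_bos) {
    // worst case: one token per byte, plus the optional BOS token
    std::vector<llama_token> tokens(text.size() + (add_bos ? 1 : 0));
    const int n = llama_tokenize(model, text.c_str(), (int) text.size(),
                                 tokens.data(), (int) tokens.size(),
                                 add_bos, /*special =*/ true);
    tokens.resize(std::max(n, 0)); // n is negative if the buffer was too small
    return tokens;
}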