cui-llama.rn 1.2.2 → 1.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -63,6 +63,30 @@ static void llama_log_softmax(float * array, size_t size) {
63
63
  }
64
64
  */
65
65
 
66
+ static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
67
+ if (temp <= 0.0f) {
68
+ // find the token with the highest logit and set the rest to -inf
69
+ size_t max_i = 0;
70
+ float max_l = cur_p->data[0].logit;
71
+
72
+ for (size_t i = 1; i < cur_p->size; ++i) {
73
+ if (cur_p->data[i ].logit > max_l) {
74
+ cur_p->data[max_i].logit = -INFINITY;
75
+ max_i = i;
76
+ max_l = cur_p->data[i].logit;
77
+ } else {
78
+ cur_p->data[i].logit = -INFINITY;
79
+ }
80
+ }
81
+
82
+ return;
83
+ }
84
+
85
+ for (size_t i = 0; i < cur_p->size; ++i) {
86
+ cur_p->data[i].logit /= temp;
87
+ }
88
+ }
89
+
66
90
  static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
67
91
  LM_GGML_ASSERT(cur_p->size > 0);
68
92
 
@@ -428,6 +452,9 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
428
452
 
429
453
  static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
430
454
  auto * ctx = (llama_sampler_dist *) smpl->ctx;
455
+
456
+ llama_sampler_softmax_impl(cur_p);
457
+
431
458
  cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
432
459
  }
433
460
 
@@ -709,6 +736,7 @@ struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
709
736
 
710
737
  // xtc
711
738
 
739
+ /*
712
740
  struct llama_sampler_xtc {
713
741
  const uint32_t seed;
714
742
  std::mt19937 rng;
@@ -717,7 +745,7 @@ struct llama_sampler_xtc {
717
745
  const size_t min_keep;
718
746
  };
719
747
 
720
- static const char * llama_sampler_xtc_name(const struct llama_sampler * /* smpl */) {
748
+ static const char * llama_sampler_xtc_name(const struct llama_sampler * /* smpl /) {
721
749
  return "xtc";
722
750
  }
723
751
 
@@ -830,27 +858,27 @@ static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
830
858
  }
831
859
 
832
860
  static struct llama_sampler_i llama_sampler_xtc_i = {
833
- /* .name = */ llama_sampler_xtc_name,
834
- /* .accept = */ nullptr,
835
- /* .apply = */ llama_sampler_xtc_apply,
836
- /* .reset = */ nullptr,
837
- /* .clone = */ llama_sampler_xtc_clone,
838
- /* .free = */ llama_sampler_xtc_free,
861
+ /* .name = / llama_sampler_xtc_name,
862
+ /* .accept = / nullptr,
863
+ /* .apply = / llama_sampler_xtc_apply,
864
+ /* .reset = / nullptr,
865
+ /* .clone = / llama_sampler_xtc_clone,
866
+ /* .free = / llama_sampler_xtc_free,
839
867
  };
840
868
 
841
869
  struct llama_sampler * llama_sampler_init_xtc(float xtc_p, float xtc_t, size_t min_keep, uint32_t seed) {
842
870
  return new llama_sampler {
843
- /* .iface = */ &llama_sampler_xtc_i,
844
- /* .ctx = */ new llama_sampler_xtc {
845
- /* .seed = */ seed,
846
- /* .rng = */ std::mt19937(seed),
847
- /* .xtc_p = */ xtc_p,
848
- /* .xtc_t = */ xtc_t,
849
- /* .min_keep = */ min_keep
871
+ /* .iface = / &llama_sampler_xtc_i,
872
+ /* .ctx = / new llama_sampler_xtc {
873
+ /* .seed = / seed,
874
+ /* .rng = / std::mt19937(seed),
875
+ /* .xtc_p = / xtc_p,
876
+ /* .xtc_t = / xtc_t,
877
+ /* .min_keep = / min_keep
850
878
  },
851
879
  };
852
880
  }
853
-
881
+ */
854
882
  // tail-free
855
883
 
856
884
  struct llama_sampler_tail_free {
@@ -1057,9 +1085,8 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*
1057
1085
 
1058
1086
  static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1059
1087
  const auto * ctx = (llama_sampler_temp *) smpl->ctx;
1060
- for (size_t i = 0; i < cur_p->size; ++i) {
1061
- cur_p->data[i].logit /= ctx->temp;
1062
- }
1088
+
1089
+ llama_sampler_temp_impl(cur_p, ctx->temp);
1063
1090
  }
1064
1091
 
1065
1092
  static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
@@ -1106,6 +1133,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
1106
1133
  if (ctx->delta > 0) {
1107
1134
  const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
1108
1135
  const float max_temp = ctx->temp + ctx->delta;
1136
+
1109
1137
  float exponent_val = ctx->exponent;
1110
1138
 
1111
1139
  // no need to do anything if there is only one (or zero) candidates
@@ -1143,9 +1171,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
1143
1171
  #endif
1144
1172
 
1145
1173
  // Apply the dynamically calculated temperature scaling
1146
- for (size_t i = 0; i < cur_p->size; ++i) {
1147
- cur_p->data[i].logit /= dyn_temp;
1148
- }
1174
+ llama_sampler_temp_impl(cur_p, dyn_temp);
1149
1175
 
1150
1176
  // Re-compute softmax probabilities after scaling logits with dynamic temperature
1151
1177
  const double max_l_double = cur_p->data[0].logit;
@@ -1169,9 +1195,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
1169
1195
  }
1170
1196
  #endif
1171
1197
  } else {
1172
- for (size_t i = 0; i < cur_p->size; ++i) {
1173
- cur_p->data[i].logit /= ctx->temp;
1174
- }
1198
+ llama_sampler_temp_impl(cur_p, ctx->temp);
1175
1199
  }
1176
1200
  }
1177
1201
 
@@ -1204,6 +1228,101 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
1204
1228
  };
1205
1229
  }
1206
1230
 
1231
+ // xtc
1232
+
1233
+ struct llama_sampler_xtc {
1234
+ const float probability;
1235
+ const float threshold;
1236
+ const size_t min_keep;
1237
+
1238
+ const uint32_t seed;
1239
+ uint32_t seed_cur;
1240
+
1241
+ std::mt19937 rng;
1242
+ };
1243
+
1244
+ static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
1245
+ return "xtc";
1246
+ }
1247
+
1248
+ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1249
+ auto * ctx = (llama_sampler_xtc *) smpl->ctx;
1250
+
1251
+ if (ctx->probability <= 0.0f
1252
+ || ctx->threshold > 0.5f
1253
+ || cur_p->size < 2) {
1254
+ return;
1255
+ }
1256
+
1257
+ std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
1258
+ float chance = distribution(ctx->rng);
1259
+ if (chance > ctx->probability) return;
1260
+
1261
+ // in case it's not sorted/recalculated yet
1262
+ llama_sampler_softmax_impl(cur_p);
1263
+
1264
+ int pos_last = 0;
1265
+
1266
+ for (size_t i = 0; i < cur_p->size; ++i) {
1267
+ if (cur_p->data[i].p >= ctx->threshold) {
1268
+ pos_last = i;
1269
+ } else break;
1270
+ }
1271
+
1272
+ if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
1273
+ cur_p->data += pos_last;
1274
+ cur_p->size -= pos_last;
1275
+ }
1276
+ }
1277
+
1278
+ static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
1279
+ const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
1280
+ auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed);
1281
+
1282
+ // copy the state
1283
+ {
1284
+ auto * result_ctx = (llama_sampler_xtc *) result->ctx;
1285
+
1286
+ result_ctx->rng = ctx->rng;
1287
+ }
1288
+
1289
+ return result;
1290
+ }
1291
+
1292
+ static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
1293
+ delete (llama_sampler_xtc *) smpl->ctx;
1294
+ }
1295
+
1296
+ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
1297
+ auto * ctx = (llama_sampler_xtc *) smpl->ctx;
1298
+ ctx->seed_cur = get_rng_seed(ctx->seed);
1299
+ ctx->rng.seed(ctx->seed_cur);
1300
+ }
1301
+
1302
+ static struct llama_sampler_i llama_sampler_xtc_i = {
1303
+ /* .name = */ llama_sampler_xtc_name,
1304
+ /* .accept = */ nullptr,
1305
+ /* .apply = */ llama_sample_xtc_apply,
1306
+ /* .reset = */ llama_sampler_xtc_reset,
1307
+ /* .clone = */ llama_sampler_xtc_clone,
1308
+ /* .free = */ llama_sampler_xtc_free,
1309
+ };
1310
+
1311
+ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
1312
+ auto seed_cur = get_rng_seed(seed);
1313
+ return new llama_sampler {
1314
+ /* .iface = */ &llama_sampler_xtc_i,
1315
+ /* .ctx = */ new llama_sampler_xtc {
1316
+ /* .probability = */ p,
1317
+ /* .threshold = */ t,
1318
+ /* .min_keep = */ min_keep,
1319
+ /* .seed = */ seed,
1320
+ /* .seed_cur = */ seed_cur,
1321
+ /* .rng = */ std::mt19937(seed_cur),
1322
+ },
1323
+ };
1324
+ }
1325
+
1207
1326
  // mirostat
1208
1327
 
1209
1328
  struct llama_sampler_mirostat {
@@ -1789,6 +1908,229 @@ struct llama_sampler * llama_sampler_init_logit_bias(
1789
1908
  };
1790
1909
  }
1791
1910
 
1911
+ // infill
1912
+
1913
+ //#define LM_GGML_DEBUG_SAMPLER_INFILL
1914
+
1915
+ struct llama_sampler_infill {
1916
+ const struct llama_vocab * vocab;
1917
+
1918
+ std::vector<char> buf0;
1919
+ std::vector<char> buf1;
1920
+ };
1921
+
1922
+ static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
1923
+ return "infill";
1924
+ }
1925
+
1926
+ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1927
+ auto * ctx = (llama_sampler_infill *) smpl->ctx;
1928
+
1929
+ llama_sampler_softmax_impl(cur_p);
1930
+
1931
+ #if defined(LM_GGML_DEBUG_SAMPLER_INFILL)
1932
+ #define LOG_DBG_CUR LLAMA_LOG_DEBUG
1933
+ #else
1934
+ #define LOG_DBG_CUR(...)
1935
+ #endif
1936
+
1937
+ for (size_t i = 0; i < cur_p->size; ++i) {
1938
+ LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
1939
+ }
1940
+
1941
+ float p_txt_sum = 0.0f;
1942
+ float p_eog_sum = 0.0f;
1943
+
1944
+ for (size_t i = 0; i < cur_p->size; ++i) {
1945
+ if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
1946
+ p_eog_sum += cur_p->data[i].p;
1947
+ } else {
1948
+ p_txt_sum += cur_p->data[i].p;
1949
+ }
1950
+ }
1951
+
1952
+ const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; LM_GGML_UNUSED(rat);
1953
+
1954
+ LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
1955
+
1956
+ if (3*p_eog_sum*cur_p->size > p_txt_sum) {
1957
+ LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
1958
+
1959
+ // keep just the EOG tokens
1960
+ const auto size_org = cur_p->size;
1961
+
1962
+ cur_p->size = 0;
1963
+
1964
+ float p_sum = 0.0f;
1965
+
1966
+ for (size_t i = 0; i < size_org; ++i) {
1967
+ if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
1968
+ p_sum += cur_p->data[i].p;
1969
+
1970
+ cur_p->data[cur_p->size++] = cur_p->data[i];
1971
+ }
1972
+ }
1973
+
1974
+ // normalize probs
1975
+ for (size_t i = 0; i < cur_p->size; ++i) {
1976
+ cur_p->data[i].p /= p_sum;
1977
+ }
1978
+
1979
+ return;
1980
+ }
1981
+
1982
+ size_t n_combined = 0; LM_GGML_UNUSED(n_combined);
1983
+
1984
+ // combine tokens with common prefix
1985
+ for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
1986
+ for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
1987
+ if (cur_p->data[i0].logit == -INFINITY) {
1988
+ break;
1989
+ }
1990
+
1991
+ if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
1992
+ continue;
1993
+ }
1994
+
1995
+ int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
1996
+ if (len0 < 0) {
1997
+ ctx->buf0.resize(len0);
1998
+ len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
1999
+ assert(len0 > 0);
2000
+ }
2001
+
2002
+ int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
2003
+ if (len1 < 0) {
2004
+ ctx->buf1.resize(len1);
2005
+ len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
2006
+ assert(len1 > 0);
2007
+ }
2008
+
2009
+ // token i0 is a prefix of token i1
2010
+ if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
2011
+ int dst = i0;
2012
+ int src = i1;
2013
+
2014
+ // merge into the token with higher probability
2015
+ if (cur_p->data[i1].p > cur_p->data[i0].p) {
2016
+ std::swap(dst, src);
2017
+ }
2018
+
2019
+ cur_p->data[dst].p += cur_p->data[src].p;
2020
+ cur_p->data[src].logit = -INFINITY;
2021
+ cur_p->data[src].p = 0.0f;
2022
+
2023
+ n_combined++;
2024
+ }
2025
+ }
2026
+ }
2027
+
2028
+ size_t n_non_eog = 0;
2029
+
2030
+ size_t size_org = cur_p->size;
2031
+
2032
+ float p_sum = 0.0f;
2033
+ float thold = 0.2f;
2034
+
2035
+ cur_p->size = 0;
2036
+
2037
+ LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
2038
+
2039
+ for (size_t i = 0; i < size_org; ++i) {
2040
+ const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
2041
+
2042
+ if (cur_p->data[i].p < thold && !is_eog) {
2043
+ continue;
2044
+ }
2045
+
2046
+ if (!is_eog) {
2047
+ ++n_non_eog;
2048
+ }
2049
+
2050
+ p_sum += cur_p->data[i].p;
2051
+
2052
+ // keep this token
2053
+ cur_p->data[cur_p->size++] = cur_p->data[i];
2054
+ }
2055
+
2056
+ LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
2057
+
2058
+ // if no non-EOG tokens are left -> reduce cur_p to single EOT token
2059
+ if (n_non_eog == 0) {
2060
+ cur_p->size = 1;
2061
+ cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
2062
+ cur_p->data[0].logit = 1.0f;
2063
+
2064
+ return;
2065
+ }
2066
+
2067
+ // normalize probs
2068
+ for (size_t i = 0; i < cur_p->size; ++i) {
2069
+ cur_p->data[i].p /= p_sum;
2070
+
2071
+ LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2072
+ }
2073
+
2074
+ size_org = cur_p->size;
2075
+ p_sum = 0.0f;
2076
+ thold = 1.0/(n_non_eog + 1);
2077
+
2078
+ cur_p->size = 0;
2079
+
2080
+ LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
2081
+
2082
+ for (size_t i = 0; i < size_org; ++i) {
2083
+ const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
2084
+
2085
+ if (cur_p->data[i].p < thold && !is_eog) {
2086
+ continue;
2087
+ }
2088
+
2089
+ p_sum += cur_p->data[i].p;
2090
+
2091
+ cur_p->data[cur_p->size++] = cur_p->data[i];
2092
+ }
2093
+
2094
+ // normalize probs
2095
+ for (size_t i = 0; i < cur_p->size; ++i) {
2096
+ cur_p->data[i].p /= p_sum;
2097
+
2098
+ LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
2099
+ }
2100
+
2101
+ #undef LOG_DBG_CUR
2102
+ }
2103
+
2104
+ static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
2105
+ const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
2106
+ return llama_sampler_init_infill_impl(*ctx->vocab);
2107
+ }
2108
+
2109
+ static void llama_sampler_infill_free(struct llama_sampler * smpl) {
2110
+ delete (llama_sampler_infill *) smpl->ctx;
2111
+ }
2112
+
2113
+ static struct llama_sampler_i llama_sampler_infill_i = {
2114
+ /* .name = */ llama_sampler_infill_name,
2115
+ /* .accept = */ nullptr,
2116
+ /* .apply = */ llama_sampler_infill_apply,
2117
+ /* .reset = */ nullptr,
2118
+ /* .clone = */ llama_sampler_infill_clone,
2119
+ /* .free = */ llama_sampler_infill_free,
2120
+ };
2121
+
2122
+ struct llama_sampler * llama_sampler_init_infill_impl(
2123
+ const struct llama_vocab & vocab) {
2124
+ return new llama_sampler {
2125
+ /* .iface = */ &llama_sampler_infill_i,
2126
+ /* .ctx = */ new llama_sampler_infill {
2127
+ /* .vocab = */ &vocab,
2128
+ /* .buf0 = */ std::vector<char>(512),
2129
+ /* .buf1 = */ std::vector<char>(512),
2130
+ },
2131
+ };
2132
+ }
2133
+
1792
2134
  // utils
1793
2135
 
1794
2136
  uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
@@ -4,8 +4,6 @@
4
4
 
5
5
  #include "llama-grammar.h"
6
6
 
7
- #include <unordered_map>
8
-
9
7
  struct llama_vocab;
10
8
  struct llama_grammar;
11
9
 
@@ -27,3 +25,6 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
27
25
  const struct llama_vocab & vocab,
28
26
  const char * grammar_str,
29
27
  const char * grammar_root);
28
+
29
+ struct llama_sampler * llama_sampler_init_infill_impl(
30
+ const struct llama_vocab & vocab);
@@ -221,7 +221,7 @@ struct llm_tokenizer_spm_session {
221
221
  }
222
222
 
223
223
  // seed the work queue with all possible 2-character tokens.
224
- for (size_t i = 1; i < symbols.size(); ++i) {
224
+ for (int i = 1; i < (int) symbols.size(); ++i) {
225
225
  try_add_bigram(i - 1, i);
226
226
  }
227
227
 
@@ -563,7 +563,7 @@ struct llm_tokenizer_bpe_session {
563
563
  index++;
564
564
  symbols.emplace_back(sym);
565
565
  }
566
- for (size_t i = 1; i < symbols.size(); ++i) {
566
+ for (int i = 1; i < (int) symbols.size(); ++i) {
567
567
  add_new_bigram(i - 1, i);
568
568
  }
569
569
 
@@ -1663,6 +1663,14 @@ llama_token llama_token_eos_impl(const struct llama_vocab & vocab) {
1663
1663
  return vocab.special_eos_id;
1664
1664
  }
1665
1665
 
1666
+ llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
1667
+ return vocab.special_eot_id;
1668
+ }
1669
+
1670
+ llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
1671
+ return vocab.special_eom_id;
1672
+ }
1673
+
1666
1674
  llama_token llama_token_cls_impl(const struct llama_vocab & vocab) {
1667
1675
  return vocab.special_cls_id;
1668
1676
  }
@@ -1688,23 +1696,39 @@ bool llama_add_eos_token_impl(const struct llama_vocab & vocab) {
1688
1696
  }
1689
1697
 
1690
1698
  llama_token llama_token_prefix_impl(const struct llama_vocab & vocab) {
1691
- return vocab.special_prefix_id;
1699
+ return vocab.special_fim_pre_id;
1692
1700
  }
1693
1701
 
1694
1702
  llama_token llama_token_middle_impl(const struct llama_vocab & vocab) {
1695
- return vocab.special_middle_id;
1703
+ return vocab.special_fim_mid_id;
1696
1704
  }
1697
1705
 
1698
1706
  llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) {
1699
- return vocab.special_suffix_id;
1707
+ return vocab.special_fim_suf_id;
1700
1708
  }
1701
1709
 
1702
- llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
1703
- return vocab.special_eot_id;
1710
+ llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab) {
1711
+ return vocab.special_fim_pre_id;
1704
1712
  }
1705
1713
 
1706
- llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
1707
- return vocab.special_eom_id;
1714
+ llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab) {
1715
+ return vocab.special_fim_suf_id;
1716
+ }
1717
+
1718
+ llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab) {
1719
+ return vocab.special_fim_mid_id;
1720
+ }
1721
+
1722
+ llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab) {
1723
+ return vocab.special_fim_pad_id;
1724
+ }
1725
+
1726
+ llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab) {
1727
+ return vocab.special_fim_rep_id;
1728
+ }
1729
+
1730
+ llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab) {
1731
+ return vocab.special_fim_sep_id;
1708
1732
  }
1709
1733
 
1710
1734
  int32_t llama_tokenize_impl(
package/cpp/llama-vocab.h CHANGED
@@ -37,20 +37,26 @@ struct llama_vocab {
37
37
  std::map<std::pair<std::string, std::string>, int> bpe_ranks;
38
38
 
39
39
  // default LLaMA special tokens
40
+ // TODO: should we set all of these to LLAMA_TOKEN_NULL?
40
41
  id special_bos_id = 1;
41
42
  id special_eos_id = 2;
43
+ id special_eot_id = LLAMA_TOKEN_NULL;
44
+ id special_eom_id = LLAMA_TOKEN_NULL;
42
45
  id special_unk_id = 0;
43
46
  id special_sep_id = LLAMA_TOKEN_NULL;
44
47
  id special_pad_id = LLAMA_TOKEN_NULL;
45
48
  id special_cls_id = LLAMA_TOKEN_NULL;
46
49
  id special_mask_id = LLAMA_TOKEN_NULL;
47
50
 
48
- id linefeed_id = 13;
49
- id special_prefix_id = LLAMA_TOKEN_NULL;
50
- id special_suffix_id = LLAMA_TOKEN_NULL;
51
- id special_middle_id = LLAMA_TOKEN_NULL;
52
- id special_eot_id = LLAMA_TOKEN_NULL; // TODO: move above after "eos_id", and here add "file separator" token
53
- id special_eom_id = LLAMA_TOKEN_NULL;
51
+ id linefeed_id = 13;
52
+
53
+ // fim tokens
54
+ id special_fim_pre_id = LLAMA_TOKEN_NULL;
55
+ id special_fim_suf_id = LLAMA_TOKEN_NULL;
56
+ id special_fim_mid_id = LLAMA_TOKEN_NULL;
57
+ id special_fim_pad_id = LLAMA_TOKEN_NULL;
58
+ id special_fim_rep_id = LLAMA_TOKEN_NULL; // repo
59
+ id special_fim_sep_id = LLAMA_TOKEN_NULL; // file separator
54
60
 
55
61
  // set of all tokens that cause "end of generation"
56
62
  std::set<id> special_eog_ids;
@@ -104,19 +110,26 @@ bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token t
104
110
 
105
111
  llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
106
112
  llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
113
+ llama_token llama_token_eot_impl(const struct llama_vocab & vocab);
114
+ llama_token llama_token_eom_impl(const struct llama_vocab & vocab);
107
115
  llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
108
116
  llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
109
117
  llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
110
118
  llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
111
119
 
112
- bool llama_add_bos_token_impl(const struct llama_vocab & vocab);
113
- bool llama_add_eos_token_impl(const struct llama_vocab & vocab);
114
-
115
120
  llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
116
121
  llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
117
122
  llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
118
- llama_token llama_token_eot_impl (const struct llama_vocab & vocab);
119
- llama_token llama_token_eom_impl (const struct llama_vocab & vocab);
123
+
124
+ llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab);
125
+ llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab);
126
+ llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab);
127
+ llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab);
128
+ llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab);
129
+ llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab);
130
+
131
+ bool llama_add_bos_token_impl(const struct llama_vocab & vocab);
132
+ bool llama_add_eos_token_impl(const struct llama_vocab & vocab);
120
133
 
121
134
  int32_t llama_tokenize_impl(
122
135
  const struct llama_vocab & vocab,
@@ -136,6 +149,12 @@ int32_t llama_token_to_piece_impl(
136
149
  int32_t lstrip,
137
150
  bool special);
138
151
 
152
+ // check if token0 is contained as a prefix in token1
153
+ bool llama_token_is_prefix_impl(
154
+ const struct llama_vocab & vocab,
155
+ llama_token token0,
156
+ llama_token token1);
157
+
139
158
  int32_t llama_detokenize_impl(
140
159
  const struct llama_vocab & vocab,
141
160
  const llama_token * tokens,