cui-llama.rn 1.1.2 → 1.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/cpp/llama-sampling.h CHANGED
@@ -1,57 +1,29 @@
  #pragma once
 
- #include "llama-impl.h"
+ // TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?
 
- struct llama_sampling {
-     llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
+ #include "llama-grammar.h"
 
-     std::mt19937 rng;
+ #include <unordered_map>
 
-     int32_t n_vocab = 0;
+ struct llama_vocab;
+ struct llama_grammar;
 
-     mutable int64_t t_sample_us = 0;
-     mutable int32_t n_sample = 0;
+ // sampler chain
 
-     void reset_timings() const {
-         t_sample_us = 0;
-         n_sample = 0;
-     }
- };
+ struct llama_sampler_chain {
+     llama_sampler_chain_params params;
+
+     std::vector<struct llama_sampler *> samplers;
+
+     // timing
 
- //
- // internal API
- //
-
- void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed);
-
- void llama_sample_softmax_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates);
- void llama_sample_top_k_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
- void llama_sample_top_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
- void llama_sample_min_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
- void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
- void llama_sample_typical_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
- void llama_sample_entropy_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
- void llama_sample_temp_impl     (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
- void llama_sample_xtc_impl      (struct llama_sampling * smpl, llama_token_data_array * candidates, float xtc_threshold, float xtc_probability, size_t min_keep, std::mt19937 & rng);
-
- void llama_sample_repetition_penalties_impl(
-         struct llama_sampling * smpl,
-         llama_token_data_array * candidates,
-         const llama_token * last_tokens,
-         size_t penalty_last_n,
-         float penalty_repeat,
-         float penalty_freq,
-         float penalty_present);
-
- void llama_sample_apply_guidance_impl(
-         struct llama_sampling * smpl,
-         float * logits,
-         float * logits_guidance,
-         float scale);
-
- llama_token llama_sample_token_mirostat_impl   (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
- llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
- llama_token llama_sample_token_greedy_impl     (struct llama_sampling * smpl, llama_token_data_array * candidates);
- llama_token llama_sample_token_with_rng_impl   (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
- llama_token llama_sample_token_impl            (struct llama_sampling * smpl, llama_token_data_array * candidates);
+     mutable int64_t t_sample_us;
+
+     mutable int32_t n_sample;
+ };
 
+ struct llama_sampler * llama_sampler_init_grammar_impl(
+         const struct llama_vocab & vocab,
+         const char * grammar_str,
+         const char * grammar_root);
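
For context: the monolithic `llama_sampling` state (RNG, vocab size, per-call `_impl` functions) is replaced by composable `llama_sampler` stages collected in a `llama_sampler_chain`. A minimal sketch of how a chain is typically driven, assuming the standard upstream llama.h sampler-chain entry points are available through this package:

// Sketch only: assumes the upstream llama.h sampler-chain API.
#include "llama.h"

static llama_token sample_one(llama_context * ctx) {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    // -1 samples from the logits of the last evaluated token
    const llama_token id = llama_sampler_sample(chain, ctx, -1);

    llama_sampler_free(chain); // also frees the samplers added to the chain
    return id;
}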
package/cpp/llama-vocab.cpp CHANGED
@@ -58,17 +58,17 @@ struct naive_trie {
          auto res = children.find(c);
          if (res != children.end()) {
              return res->second.get_longest_prefix(key, len, offset + 1);
-         } else {
-             return std::make_pair(key, offset);
          }
+
+         return std::make_pair(key, offset);
      }
-     struct naive_trie * traverse(const char c) {
+     const struct naive_trie * traverse(const char c) const {
          auto res = children.find(c);
          if (res != children.end()) {
              return &res->second;
-         } else {
-             return NULL;
          }
+
+         return NULL;
      }
      std::map<char, struct naive_trie> children;
      bool has_value;
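
The `traverse` change adds const-correctness so lookups can run against a `const` trie, which the RWKV tokenizer added later in this diff relies on. A small illustrative sketch, using the `insert` signature and member names seen elsewhere in this diff:

// Illustrative sketch; value 7 is an arbitrary token id.
naive_trie t;
t.insert("ab", 2, 7);                    // key bytes, length, value

const naive_trie & ct = t;               // lookups now work on const objects
const naive_trie * n = ct.traverse('a'); // const overload added above
if (n != NULL) {
    n = n->traverse('b');                // n->has_value / n->value hold the match
}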
@@ -843,7 +843,7 @@ struct llm_tokenizer_ugm {
          // traverse the token matcher trie to find a matching token
          bool single_codepoint_token_found = false;
          const struct best_tokenization & current_best = tokenization_results[input_offset];
-         struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
+         const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
 
          while (prefix_offset <= input_len && node != NULL) {
              // check if we found valid token in prefix
@@ -963,7 +963,7 @@ private:
      /*
       * This structure is a view wrapper for XOR-compressed double array (XCDA)
       * See Shunsuke Kanda (2018). Space- and Time-Efficient String Dictionaries.
-      * Eeach bit-packed entry contains:
+      * Each bit-packed entry contains:
       * - BASE array value in bits 10-30
       * - LCHECK array value in bits 0-7
       * - LEAF array value in bit 9
@@ -1097,6 +1097,111 @@ private:
      struct naive_trie token_matcher;
  };
 
+ //
+ // RWKV tokenizer
+ //
+
+ static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escaped) {
+     std::vector<uint8_t> output;
+     output.reserve(escaped.size());
+
+     // Parser state
+     bool escaping = false;
+     uint8_t hex_remaining = 0;
+     uint8_t hex_acc = 0;
+
+     // Step through characters, performing parsing
+     for (const char & c : escaped) {
+         // If we're parsing a hex code, interpret the next character
+         if (hex_remaining != 0) {
+             uint8_t value = (c >= 'a') ? (c - 'a' + 10) : (c - '0');
+             hex_acc = (hex_acc << 4) + value;
+
+             hex_remaining -= 1;
+             if (hex_remaining == 0) {
+                 output.push_back(hex_acc);
+                 hex_acc = 0;
+             }
+
+             continue;
+         }
+
+         // If we got an escape character, interpret it
+         if (escaping) {
+             if (c == 't') {
+                 output.push_back('\t');
+             } else if (c == 'n') {
+                 output.push_back('\n');
+             } else if (c == 'r') {
+                 output.push_back('\r');
+             } else if (c == 'x') {
+                 hex_remaining = 2;
+             } else {
+                 output.push_back(c);
+             }
+
+             escaping = false;
+             continue;
+         }
+
+         if (c == '\\') {
+             escaping = true;
+             continue;
+         }
+
+         output.push_back(c);
+     }
+
+     return output;
+ }
+
+ struct llm_tokenizer_rwkv {
+     llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) {
+         // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
+         // For now, we decode the vocab here into the lookup we'll use for tokenization.
+
+         // build trie
+         for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
+             const auto & token = vocab.id_to_token[id];
+             const auto data = llama_unescape_rwkv_token(token.text);
+             token_matcher.insert((const char *) data.data(), data.size(), id);
+         }
+     }
+
+     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+         uint32_t position = 0;
+
+         while (position < text.size()) {
+             const struct naive_trie * node = token_matcher.traverse(text[position]);
+             if (node == NULL) {
+                 // no matching token found, add unknown token
+                 output.push_back(vocab.special_unk_id);
+                 position += 1;
+                 continue;
+             }
+
+             // traverse the trie to find the longest matching token
+             uint32_t token_id = 0;
+             uint32_t token_length = 0;
+             while (node != NULL) {
+                 if (node->has_value) {
+                     token_id = node->value;
+                     token_length = position + 1;
+                 }
+                 node = node->traverse(text[++position]);
+             }
+
+             // add the longest matching token
+             output.push_back(token_id);
+             position = token_length;
+         }
+     }
+
+     const llama_vocab & vocab;
+
+     struct naive_trie token_matcher;
+ };
+
  //
  // (de-) tokenize
  //
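
The escape scheme handled by `llama_unescape_rwkv_token` covers `\t`, `\n`, `\r`, two-digit `\xNN` hex bytes, and pass-through for any other escaped character. A worked example of the expected behavior; since the function is file-static, this sketch assumes it is visible in the same translation unit:

#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// "\n\x41B" as stored in the vocab decodes to {0x0A, 0x41, 'B'}:
// 'n' after a backslash maps to newline, 'x' consumes two hex digits,
// and the trailing 'B' is copied through unchanged.
static void check_unescape() {
    std::vector<uint8_t> out = llama_unescape_rwkv_token("\\n\\x41B");
    assert(out.size() == 3);
    assert(out[0] == 0x0A);
    assert(out[1] == 0x41);
    assert(out[2] == 'B');
}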
@@ -1401,6 +1506,23 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
                      output.push_back(vocab.special_eos_id);
                  }
              } break;
+         case LLAMA_VOCAB_TYPE_RWKV:
+             {
+                 for (const auto & fragment : fragment_buffer) {
+                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+ #ifdef PRETOKENIZERDEBUG
+                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+ #endif
+
+                         llm_tokenizer_rwkv tokenizer(vocab);
+                         tokenizer.tokenize(raw_text, output);
+                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                         output.push_back(fragment.token);
+                     }
+                 }
+             } break;
          case LLAMA_VOCAB_TYPE_NONE:
              LM_GGML_ABORT("fatal error");
      }
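
Note that the RWKV branch constructs a fresh `llm_tokenizer_rwkv` per raw-text fragment, so the trie is rebuilt from `vocab.id_to_token` on every call. A minimal usage sketch, where `vocab` stands for an already-loaded `llama_vocab`:

// Sketch: greedy longest-match tokenization over the byte trie.
// With a toy vocab {1:"a", 2:"ab", 3:"abc"}, input "abab" yields {2, 2}:
// at each position the longest token present in the trie wins.
std::vector<llama_vocab::id> ids;
llm_tokenizer_rwkv tokenizer(vocab); // builds the trie: O(total vocab bytes)
tokenizer.tokenize("abab", ids);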
@@ -1616,6 +1738,17 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
              }
              break;
          }
+         case LLAMA_VOCAB_TYPE_RWKV: {
+             std::vector<uint8_t> result = llama_unescape_rwkv_token(token_text);
+
+             // If we don't have enough space, return an error
+             if (result.size() > (size_t)length) {
+                 return -(int)result.size();
+             }
+
+             memcpy(buf, result.data(), result.size());
+             return (int)result.size();
+         }
          default:
              LM_GGML_ABORT("fatal error");
      }
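
The RWKV case follows the existing size-negotiation convention of `llama_token_to_piece_impl`: if the destination buffer is too small, the required size is returned negated. A caller-side sketch of the grow-and-retry pattern; `to_piece` is a hypothetical stand-in for whichever wrapper this package exposes:

// Hedged sketch: to_piece is hypothetical; the convention (negative
// return = required size) is what the diff above implements.
std::string piece(8, '\0');
int32_t n = to_piece(token, piece.data(), (int32_t) piece.size());
if (n < 0) {
    piece.resize(-n);                                           // grow to the required size...
    n = to_piece(token, piece.data(), (int32_t) piece.size());  // ...and retry
}
piece.resize(n);                                                // trim to the actual piece length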
package/cpp/llama-vocab.h CHANGED
@@ -18,6 +18,8 @@ struct llama_vocab {
          tattr attr;
      };
 
+     uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab
+
      enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
      enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
@@ -62,8 +64,6 @@
      int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
  };
 
- const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx);
-
  //
  // internal API
  //
@@ -76,6 +76,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
          bool add_special,
          bool parse_special = false);
 
+ // TODO: move the API below as member functions of llama_vocab
  llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
 
  const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);