cui-llama.rn 1.1.2 → 1.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/cpp/sampling.cpp CHANGED
@@ -1,464 +1,463 @@
- #define LLAMA_API_INTERNAL
  #include "sampling.h"
- #include <random>

- struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
-     struct llama_sampling_context * result = new llama_sampling_context();
+ #include "common.h"

-     result->params = params;
-     result->grammar = nullptr;
+ #include <cmath>
+ #include <unordered_map>

-     // if there is a grammar, parse it
-     if (!params.grammar.empty()) {
-         result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+ // the ring buffer works similarly to std::deque, but with a fixed capacity
+ // TODO: deduplicate with llama-impl.h
+ template<typename T>
+ struct ring_buffer {
+     ring_buffer(size_t cap) : capacity(cap), data(cap) {}

-         // will be empty (default) if there are parse errors
-         if (result->parsed_grammar.rules.empty()) {
-             fprintf(stderr, "%s: failed to parse grammar\n", __func__);
-             delete result;
-             return nullptr;
+     T & front() {
+         if (sz == 0) {
+             throw std::runtime_error("ring buffer is empty");
          }
+         return data[first];
+     }

-         // Ensure that there is a "root" node.
-         if (result->parsed_grammar.symbol_ids.find("root") == result->parsed_grammar.symbol_ids.end()) {
-             fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
-             delete result;
-             return nullptr;
+     const T & front() const {
+         if (sz == 0) {
+             throw std::runtime_error("ring buffer is empty");
          }
+         return data[first];
+     }

-         std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
-
-         struct llama_grammar * grammar = llama_grammar_init(
-                 grammar_rules.data(),
-                 grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
-         if (grammar == nullptr) {
-             throw std::runtime_error("Failed to initialize llama_grammar");
+     T & back() {
+         if (sz == 0) {
+             throw std::runtime_error("ring buffer is empty");
          }
-         result->grammar = grammar;
+         return data[pos];
      }

-     result->prev.resize(params.n_prev);
-
-     result->n_valid = 0;
-
-     llama_sampling_set_rng_seed(result, params.seed);
-
-     return result;
- }
-
- void llama_sampling_free(struct llama_sampling_context * ctx) {
-     if (ctx->grammar != NULL) {
-         llama_grammar_free(ctx->grammar);
+     const T & back() const {
+         if (sz == 0) {
+             throw std::runtime_error("ring buffer is empty");
+         }
+         return data[pos];
      }

-     delete ctx;
- }
-
- void llama_sampling_reset(llama_sampling_context * ctx) {
-     if (ctx->grammar != NULL) {
-         llama_grammar_free(ctx->grammar);
-         ctx->grammar = NULL;
+     void push_back(const T & value) {
+         if (sz == capacity) {
+             // advance the start when buffer is full
+             first = (first + 1) % capacity;
+         } else {
+             sz++;
+         }
+         data[pos] = value;
+         pos = (pos + 1) % capacity;
      }

-     if (!ctx->parsed_grammar.rules.empty()) {
-         std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
+     T pop_front() {
+         if (sz == 0) {
+             throw std::runtime_error("ring buffer is empty");
+         }
+         T value = data[first];
+         first = (first + 1) % capacity;
+         sz--;
+         return value;
+     }

-         struct llama_grammar * grammar = llama_grammar_init(
-                 grammar_rules.data(),
-                 grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
-         if (grammar == nullptr) {
-             throw std::runtime_error("Failed to initialize llama_grammar");
+     const T & rat(size_t i) const {
+         if (i >= sz) {
+             throw std::runtime_error("ring buffer: index out of bounds");
          }
-         ctx->grammar = grammar;
+         return data[(first + sz - i - 1) % capacity];
      }

-     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
-     ctx->cur.clear();
-     ctx->n_valid = 0;
- }
+     std::vector<T> to_vector() const {
+         std::vector<T> result;
+         result.reserve(sz);
+         for (size_t i = 0; i < sz; i++) {
+             result.push_back(data[(first + i) % capacity]);
+         }
+         return result;
+     }

- void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
-     if (seed == LLAMA_DEFAULT_SEED) {
-         seed = std::random_device{}();
+     void clear() {
+         // here only reset the status of the buffer
+         sz = 0;
+         first = 0;
+         pos = 0;
      }
-     ctx->rng.seed(seed);
- }

- void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
-     if (dst->grammar) {
-         llama_grammar_free(dst->grammar);
-         dst->grammar = nullptr;
+     bool empty() const {
+         return sz == 0;
      }

-     if (src->grammar) {
-         dst->grammar = llama_grammar_copy(src->grammar);
+     size_t size() const {
+         return sz;
      }

-     dst->prev = src->prev;
- }
+     size_t capacity = 0;
+     size_t sz = 0;
+     size_t first = 0;
+     size_t pos = 0;
+     std::vector<T> data;
+ };

- llama_token llama_sampling_last(llama_sampling_context * ctx) {
-     return ctx->prev.back();
- }
+ struct gpt_sampler {
+     gpt_sampler_params params;

- std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) {
-     const int size = ctx_sampling->prev.size();
+     struct llama_sampler * grmr;
+     struct llama_sampler * chain;

-     n = std::min(n, size);
+     ring_buffer<llama_token> prev;

-     std::string result;
+     std::vector<llama_token_data> cur;

-     for (int i = size - n; i < size; i++) {
-         result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]);
-     }
+     llama_token_data_array cur_p;

-     return result;
- }
+     void set_logits(struct llama_context * ctx, int idx) {
+         const auto * logits = llama_get_logits_ith(ctx, idx);
+
+         const int n_vocab = llama_n_vocab(llama_get_model(ctx));

- std::string llama_sampling_print(const llama_sampling_params & params) {
+         cur.resize(n_vocab);
+
+         for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+             cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+         }
+
+         cur_p = { cur.data(), cur.size(), -1, false };
+     }
+ };
+
+ std::string gpt_sampler_params::print() const {
      char result[1024];

      snprintf(result, sizeof(result),
              "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
              "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
              "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
-             params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
-             params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
-             params.mirostat, params.mirostat_eta, params.mirostat_tau);
+             penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
+             top_k, tfs_z, top_p, min_p, typ_p, temp,
+             mirostat, mirostat_eta, mirostat_tau);

      return std::string(result);
  }

- std::string llama_sampling_order_print(const llama_sampling_params & params) {
-     std::string result = "CFG -> Penalties ";
-     if (params.mirostat == 0) {
-         for (auto sampler_type : params.samplers_sequence) {
-             const auto sampler_type_name = llama_sampling_type_to_str(sampler_type);
-             if (!sampler_type_name.empty()) {
-                 result += "-> " + sampler_type_name + " ";
+ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) {
+     llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
+
+     lparams.no_perf = params.no_perf;
+
+     auto * result = new gpt_sampler {
+         /* .params = */ params,
+         /* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
+         /* .chain = */ llama_sampler_chain_init(lparams),
+         /* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
+         /* .cur = */ {},
+         /* .cur_p = */ {},
+     };
+
+     llama_sampler_chain_add(result->chain,
+             llama_sampler_init_logit_bias(
+                 llama_n_vocab(model),
+                 params.logit_bias.size(),
+                 params.logit_bias.data()));
+
+     llama_sampler_chain_add(result->chain,
+             llama_sampler_init_penalties(
+                 llama_n_vocab (model),
+                 llama_token_eos(model),
+                 llama_token_nl (model),
+                 params.penalty_last_n,
+                 params.penalty_repeat,
+                 params.penalty_freq,
+                 params.penalty_present,
+                 params.penalize_nl,
+                 params.ignore_eos));
+
+     if (params.temp > 0.0f) {
+         if (params.mirostat == 0) {
+             for (const auto & cnstr : params.samplers) {
+                 switch (cnstr) {
+                     case GPT_SAMPLER_TYPE_TOP_K:
+                         llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
+                         break;
+                     case GPT_SAMPLER_TYPE_TOP_P:
+                         llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
+                         break;
+                     case GPT_SAMPLER_TYPE_MIN_P:
+                         llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
+                         break;
+                     case GPT_SAMPLER_TYPE_TFS_Z:
+                         llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
+                         break;
+                     case GPT_SAMPLER_TYPE_TYPICAL_P:
+                         llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
+                         break;
+                     case GPT_SAMPLER_TYPE_XTC:
+                         llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_p, params.xtc_t, params.min_keep, params.seed));
+                         break;
+                     case GPT_SAMPLER_TYPE_TEMPERATURE:
+                         llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                         break;
+                     default:
+                         LM_GGML_ASSERT(false && "unknown sampler type");
+                 }
              }
+             llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
+             llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
+         } else if (params.mirostat == 1) {
+             llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
+             llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
+         } else if (params.mirostat == 2) {
+             llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
+             llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
+         } else {
+             LM_GGML_ASSERT(false && "unknown mirostat version");
          }
      } else {
-         result += "-> mirostat ";
+         llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
+         llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
      }

      return result;
  }

- std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
-     switch (sampler_type) {
-         case llama_sampler_type::TOP_K: return "top_k";
-         case llama_sampler_type::TFS_Z: return "tfs_z";
-         case llama_sampler_type::TYPICAL_P: return "typical_p";
-         case llama_sampler_type::TOP_P: return "top_p";
-         case llama_sampler_type::MIN_P: return "min_p";
-         case llama_sampler_type::TEMPERATURE: return "temperature";
-         default : return "";
+ void gpt_sampler_free(struct gpt_sampler * gsmpl) {
+     if (gsmpl) {
+         llama_sampler_free(gsmpl->grmr);
+
+         llama_sampler_free(gsmpl->chain);
+
+         delete gsmpl;
      }
  }

- std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
-     std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
-         {"top_k", llama_sampler_type::TOP_K},
-         {"top_p", llama_sampler_type::TOP_P},
-         {"typical_p", llama_sampler_type::TYPICAL_P},
-         {"min_p", llama_sampler_type::MIN_P},
-         {"tfs_z", llama_sampler_type::TFS_Z},
-         {"temperature", llama_sampler_type::TEMPERATURE}
-     };
+ void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar) {
+     if (accept_grammar) {
+         llama_sampler_accept(gsmpl->grmr, token);
+     }

-     // since samplers names are written multiple ways
-     // make it ready for both system names and input names
-     std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map {
-         {"top-k", llama_sampler_type::TOP_K},
-         {"top-p", llama_sampler_type::TOP_P},
-         {"nucleus", llama_sampler_type::TOP_P},
-         {"typical-p", llama_sampler_type::TYPICAL_P},
-         {"typical", llama_sampler_type::TYPICAL_P},
-         {"min-p", llama_sampler_type::MIN_P},
-         {"tfs-z", llama_sampler_type::TFS_Z},
-         {"tfs", llama_sampler_type::TFS_Z},
-         {"temp", llama_sampler_type::TEMPERATURE}
-     };
+     llama_sampler_accept(gsmpl->chain, token);

-     std::vector<llama_sampler_type> sampler_types;
-     sampler_types.reserve(names.size());
-     for (const auto & name : names)
-     {
-         auto sampler_item = sampler_canonical_name_map.find(name);
-         if (sampler_item != sampler_canonical_name_map.end())
-         {
-             sampler_types.push_back(sampler_item->second);
-         }
-         else
-         {
-             if (allow_alt_names)
-             {
-                 sampler_item = sampler_alt_name_map.find(name);
-                 if (sampler_item != sampler_alt_name_map.end())
-                 {
-                     sampler_types.push_back(sampler_item->second);
-                 }
-             }
-         }
-     }
-     return sampler_types;
+     gsmpl->prev.push_back(token);
  }

- std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string) {
-     std::unordered_map<char, llama_sampler_type> sampler_name_map {
-         {'k', llama_sampler_type::TOP_K},
-         {'p', llama_sampler_type::TOP_P},
-         {'y', llama_sampler_type::TYPICAL_P},
-         {'m', llama_sampler_type::MIN_P},
-         {'f', llama_sampler_type::TFS_Z},
-         {'t', llama_sampler_type::TEMPERATURE}
-     };
+ void gpt_sampler_reset(struct gpt_sampler * gsmpl) {
+     llama_sampler_reset(gsmpl->grmr);

-     std::vector<llama_sampler_type> sampler_types;
-     sampler_types.reserve(names_string.size());
-     for (const auto & c : names_string) {
-         const auto sampler_item = sampler_name_map.find(c);
-         if (sampler_item != sampler_name_map.end()) {
-             sampler_types.push_back(sampler_item->second);
-         }
-     }
-     return sampler_types;
+     llama_sampler_reset(gsmpl->chain);
  }

- // no reasons to expose this function in header
- static void sampler_queue(
-         struct llama_context * ctx_main,
-         struct llama_sampling_context * ctx_sampling,
-         const llama_sampling_params & params,
-         llama_token_data_array & cur_p,
-         size_t min_keep) {
-     const float temp = params.temp;
-     const float dynatemp_range = params.dynatemp_range;
-     const float dynatemp_exponent = params.dynatemp_exponent;
-     const int32_t top_k = params.top_k;
-     const float top_p = params.top_p;
-     const float min_p = params.min_p;
-     const float xtc_t = params.xtc_t;
-     const float xtc_p = params.xtc_p;
-     const float tfs_z = params.tfs_z;
-     const float typical_p = params.typical_p;
-     const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;
-
-     for (auto sampler_type : samplers_sequence) {
-         switch (sampler_type) {
-             case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
-             case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
-             case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
-             case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
-             case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
-             case llama_sampler_type::XTC : llama_sample_xtc (ctx_main, &cur_p, xtc_t, xtc_p, min_keep, ctx_sampling->rng); break;
-             case llama_sampler_type::TEMPERATURE:
-                 if (dynatemp_range > 0) {
-                     float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
-                     float dynatemp_max = std::max(0.0f, temp + dynatemp_range);
-                     llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent);
-                 } else {
-                     llama_sample_temp(ctx_main, &cur_p, temp);
-                 }
-                 break;
-             default : break;
-         }
-     }
+ struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
+     return new gpt_sampler {
+         /* .params = */ gsmpl->params,
+         /* .grmr = */ llama_sampler_clone(gsmpl->grmr),
+         /* .chain = */ llama_sampler_clone(gsmpl->chain),
+         /* .prev = */ gsmpl->prev,
+         /* .cur = */ gsmpl->cur,
+         /* .cur_p = */ gsmpl->cur_p,
+     };
  }

- static llama_token llama_sampling_sample_impl(
-         struct llama_sampling_context * ctx_sampling,
-         struct llama_context * ctx_main,
-         struct llama_context * ctx_cfg,
-         const int idx,
-         bool is_resampling) {
-     const llama_sampling_params & params = ctx_sampling->params;
-
-     const float temp = params.temp;
-     const int mirostat = params.mirostat;
-     const float mirostat_tau = params.mirostat_tau;
-     const float mirostat_eta = params.mirostat_eta;
-
-     std::vector<float> original_logits;
-     auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
-     if (ctx_sampling->grammar != NULL && !is_resampling) {
-         LM_GGML_ASSERT(!original_logits.empty());
-     }
-     llama_token id = 0;
-
-     if (temp < 0.0) {
-         // greedy sampling, with probs
-         llama_sample_softmax(ctx_main, &cur_p);
-         id = cur_p.data[0].id;
-     } else if (temp == 0.0) {
-         // greedy sampling, no probs
-         id = llama_sample_token_greedy(ctx_main, &cur_p);
-     } else {
-         if (mirostat == 1) {
-             const int mirostat_m = 100;
-             llama_sample_temp(ctx_main, &cur_p, temp);
-             id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
-         } else if (mirostat == 2) {
-             llama_sample_temp(ctx_main, &cur_p, temp);
-             id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
-         } else {
-             // temperature sampling
-             size_t min_keep = std::max(1, params.min_keep);
-
-             sampler_queue(ctx_main, ctx_sampling, params, cur_p, min_keep);
+ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl) {
+     // TODO: measure grammar performance

-             id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
+     if (gsmpl) {
+         llama_perf_sampler_print(gsmpl->chain);
+     }
+     if (ctx) {
+         llama_perf_context_print(ctx);
+     }
+ }

-             //{
-             //    const int n_top = 10;
-             //    LOG("top %d candidates:\n", n_top);
+ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
+     gsmpl->set_logits(ctx, idx);

-             //    for (int i = 0; i < n_top; i++) {
-             //        const llama_token id = cur_p.data[i].id;
-             //        (void)id; // To avoid a warning that id is unused when logging is disabled.
-             //        LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
-             //    }
-             //}
+     auto & grmr = gsmpl->grmr;
+     auto & chain = gsmpl->chain;
+     auto & cur_p = gsmpl->cur_p; // initialized by set_logits

-             //LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str());
-         }
+     if (grammar_first) {
+         llama_sampler_apply(grmr, &cur_p);
      }

-     if (ctx_sampling->grammar != NULL && !is_resampling) {
-         // Get a pointer to the logits
-         float * logits = llama_get_logits_ith(ctx_main, idx);
+     llama_sampler_apply(chain, &cur_p);

-         // Create an array with a single token data element for the sampled id
-         llama_token_data single_token_data = {id, logits[id], 0.0f};
-         llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
+     LM_GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

-         // Apply grammar constraints to the single token
-         llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array);
+     const llama_token id = cur_p.data[cur_p.selected].id;

-         // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
-         bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+     if (grammar_first) {
+         return id;
+     }

-         // If the token is not valid according to the grammar, perform resampling
-         if (!is_valid) {
-             LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
+     // check if it the sampled token fits the grammar
+     {
+         llama_token_data single_token_data = { id, 1.0f, 0.0f };
+         llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };

-             // Restore logits from the copy
-             std::copy(original_logits.begin(), original_logits.end(), logits);
+         llama_sampler_apply(grmr, &single_token_data_array);

-             return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true);
+         const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+         if (is_valid) {
+             return id;
          }
      }

-     ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p.size;
+     // resampling:
+     // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
+     gsmpl->set_logits(ctx, idx);

-     return id;
- }
+     llama_sampler_apply(grmr, &cur_p);
+     llama_sampler_apply(chain, &cur_p);

- static llama_token_data_array llama_sampling_prepare_impl(
-         struct llama_sampling_context * ctx_sampling,
-         struct llama_context * ctx_main,
-         struct llama_context * ctx_cfg,
-         const int idx,
-         bool apply_grammar,
-         std::vector<float> * original_logits) {
-     const llama_sampling_params & params = ctx_sampling->params;
+     LM_GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");

-     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+     return cur_p.data[cur_p.selected].id;
+ }

-     const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
-     const float penalty_repeat = params.penalty_repeat;
-     const float penalty_freq = params.penalty_freq;
-     const float penalty_present = params.penalty_present;
+ uint32_t gpt_sampler_get_seed(const struct gpt_sampler * gsmpl) {
+     return llama_sampler_get_seed(gsmpl->chain);
+ }

-     const bool penalize_nl = params.penalize_nl;
+ // helpers

-     auto & prev = ctx_sampling->prev;
-     auto & cur = ctx_sampling->cur;
+ llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) {
+     return &gsmpl->cur_p;
+ }

-     // Get a pointer to the logits
-     float * logits = llama_get_logits_ith(ctx_main, idx);
+ llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) {
+     return gsmpl->prev.rat(0);
+ }
+
+ std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
+     std::string result = "\tlogits ";

-     if (ctx_sampling->grammar != NULL && !apply_grammar) {
-         LM_GGML_ASSERT(original_logits != NULL);
-         // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
-         *original_logits = {logits, logits + n_vocab};
+     for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
+         const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
+         result += std::string("-> ") + llama_sampler_name(smpl) + " ";
      }

-     // apply params.logit_bias map
-     for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
-         logits[it->first] += it->second;
+     return result;
+ }
+
+ std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) {
+     n = std::min(n, (int) gsmpl->prev.size());
+
+     if (n <= 0) {
+         return "";
      }

-     if (ctx_cfg) {
-         float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
-         llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
+     std::string result;
+     result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab
+
+     for (int i = n - 1; i >= 0; i--) {
+         const llama_token id = gsmpl->prev.rat(i);
+
+         LM_GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");
+
+         result += llama_token_to_piece(ctx_main, id);
      }

-     cur.resize(n_vocab);
+     return result;
+ }

-     for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-         cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+ struct llama_sampler_timings gpt_sampler_get_timigs(const struct gpt_sampler * gsmpl){
+     return llama_sampler_chain_timings(gsmpl -> chain);
+ }
+
+ char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr) {
+     switch (cnstr) {
+         case GPT_SAMPLER_TYPE_TOP_K: return 'k';
+         case GPT_SAMPLER_TYPE_TFS_Z: return 'f';
+         case GPT_SAMPLER_TYPE_TYPICAL_P: return 'y';
+         case GPT_SAMPLER_TYPE_TOP_P: return 'p';
+         case GPT_SAMPLER_TYPE_MIN_P: return 'm';
+         case GPT_SAMPLER_TYPE_TEMPERATURE: return 't';
+         case GPT_SAMPLER_TYPE_XTC: return 'x';
+         default : return '?';
+     }
+ }
+
+ std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr) {
+     switch (cnstr) {
+         case GPT_SAMPLER_TYPE_TOP_K: return "top_k";
+         case GPT_SAMPLER_TYPE_TFS_Z: return "tfs_z";
+         case GPT_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
+         case GPT_SAMPLER_TYPE_TOP_P: return "top_p";
+         case GPT_SAMPLER_TYPE_MIN_P: return "min_p";
+         case GPT_SAMPLER_TYPE_XTC: return "xtc";
+         case GPT_SAMPLER_TYPE_TEMPERATURE: return "temperature";
+         default : return "";
      }
+ }

-     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
+ std::vector<gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
+     std::unordered_map<std::string, gpt_sampler_type> sampler_canonical_name_map {
+         { "top_k", GPT_SAMPLER_TYPE_TOP_K },
+         { "top_p", GPT_SAMPLER_TYPE_TOP_P },
+         { "typ_p", GPT_SAMPLER_TYPE_TYPICAL_P },
+         { "min_p", GPT_SAMPLER_TYPE_MIN_P },
+         { "tfs_z", GPT_SAMPLER_TYPE_TFS_Z },
+         { "xtc", GPT_SAMPLER_TYPE_XTC},
+         { "temperature", GPT_SAMPLER_TYPE_TEMPERATURE },
+     };

-     // apply penalties
-     const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
-     const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
-     if (penalty_tokens_used_size) {
-         const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
+     // since samplers names are written multiple ways
+     // make it ready for both system names and input names
+     std::unordered_map<std::string, gpt_sampler_type> sampler_alt_name_map {
+         { "top-k", GPT_SAMPLER_TYPE_TOP_K },
+         { "top-p", GPT_SAMPLER_TYPE_TOP_P },
+         { "nucleus", GPT_SAMPLER_TYPE_TOP_P },
+         { "typical-p", GPT_SAMPLER_TYPE_TYPICAL_P },
+         { "typical", GPT_SAMPLER_TYPE_TYPICAL_P },
+         { "typ-p", GPT_SAMPLER_TYPE_TYPICAL_P },
+         { "typ", GPT_SAMPLER_TYPE_TYPICAL_P },
+         { "min-p", GPT_SAMPLER_TYPE_MIN_P },
+         { "tfs-z", GPT_SAMPLER_TYPE_TFS_Z },
+         { "tfs", GPT_SAMPLER_TYPE_TFS_Z },
+         { "xtc_p", GPT_SAMPLER_TYPE_XTC},
+         { "xtc_t", GPT_SAMPLER_TYPE_XTC},
+         { "temp", GPT_SAMPLER_TYPE_TEMPERATURE },
+     };

-         llama_sample_repetition_penalties(ctx_main, &cur_p,
-                 penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
-                 penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
+     std::vector<gpt_sampler_type> samplers;
+     samplers.reserve(names.size());

-         if (!penalize_nl) {
-             for (size_t idx = 0; idx < cur_p.size; idx++) {
-                 if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
-                     cur_p.data[idx].logit = nl_logit;
-                     break;
+     for (const auto & name : names) {
+         auto sampler = sampler_canonical_name_map.find(name);
+         if (sampler != sampler_canonical_name_map.end()) {
+             samplers.push_back(sampler->second);
+         } else {
+             if (allow_alt_names) {
+                 sampler = sampler_alt_name_map.find(name);
+                 if (sampler != sampler_alt_name_map.end()) {
+                     samplers.push_back(sampler->second);
                  }
              }
          }
      }

-     // apply grammar checks before sampling logic
-     if (apply_grammar && ctx_sampling->grammar != NULL) {
-         llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
-     }
-
-     return cur_p;
+     return samplers;
  }

- llama_token llama_sampling_sample(
-         struct llama_sampling_context * ctx_sampling,
-         struct llama_context * ctx_main,
-         struct llama_context * ctx_cfg,
-         const int idx) {
-     // Call the implementation function with is_resampling set to false by default
-     return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false);
- }
-
- llama_token_data_array llama_sampling_prepare(
-         struct llama_sampling_context * ctx_sampling,
-         struct llama_context * ctx_main,
-         struct llama_context * ctx_cfg,
-         const int idx,
-         bool apply_grammar,
-         std::vector<float> * original_logits) {
-     return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
- }
+ std::vector<gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & chars) {
+     std::unordered_map<char, gpt_sampler_type> sampler_name_map = {
+         { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TOP_K), GPT_SAMPLER_TYPE_TOP_K },
+         { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TFS_Z), GPT_SAMPLER_TYPE_TFS_Z },
+         { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TYPICAL_P), GPT_SAMPLER_TYPE_TYPICAL_P },
+         { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TOP_P), GPT_SAMPLER_TYPE_TOP_P },
+         { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_MIN_P), GPT_SAMPLER_TYPE_MIN_P },
+         { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_XTC), GPT_SAMPLER_TYPE_XTC},
+         { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TEMPERATURE), GPT_SAMPLER_TYPE_TEMPERATURE }
+     };

- void llama_sampling_accept(
-         struct llama_sampling_context * ctx_sampling,
-         struct llama_context * ctx_main,
-         llama_token id,
-         bool apply_grammar) {
-     ctx_sampling->prev.erase(ctx_sampling->prev.begin());
-     ctx_sampling->prev.push_back(id);
+     std::vector<gpt_sampler_type> samplers;
+     samplers.reserve(chars.size());

-     if (ctx_sampling->grammar != NULL && apply_grammar) {
-         llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
+     for (const auto & c : chars) {
+         const auto sampler = sampler_name_map.find(c);
+         if (sampler != sampler_name_map.end()) {
+             samplers.push_back(sampler->second);
+         }
      }
+
+     return samplers;
  }
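
For orientation, this diff replaces the old llama_sampling_* helpers with the gpt_sampler API built on llama_sampler chains. Below is a minimal sketch (not part of the package) of how the functions shown above are typically driven; model/context setup, batching, end-of-generation detection, and the llama_decode calls are assumed and elided, and generate_n is a hypothetical helper name.

    #include "sampling.h"

    // Hypothetical decode loop using the gpt_sampler API from this diff.
    // Assumes `model` and `ctx` are already loaded and the prompt has been decoded.
    static void generate_n(struct llama_model * model, struct llama_context * ctx,
                           const struct gpt_sampler_params & sparams, int n_predict) {
        struct gpt_sampler * smpl = gpt_sampler_init(model, sparams);

        for (int i = 0; i < n_predict; i++) {
            // sample from the logits of the last decoded position (idx = -1);
            // grammar_first = false applies the chain first, then grammar-checks/resamples
            const llama_token id = gpt_sampler_sample(smpl, ctx, -1, false);

            // record the token in the sampler history and advance the grammar state
            gpt_sampler_accept(smpl, id, /* accept_grammar = */ true);

            // ... check for end-of-generation and feed `id` back through llama_decode() here ...
        }

        gpt_sampler_free(smpl);
    }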