@fugood/llama.node 1.4.7 → 1.4.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. package/lib/binding.ts +8 -0
  2. package/package.json +15 -15
  3. package/scripts/llama.cpp.patch +22 -23
  4. package/src/LlamaContext.cpp +2 -2
  5. package/src/llama.cpp/common/CMakeLists.txt +2 -0
  6. package/src/llama.cpp/common/arg.cpp +364 -193
  7. package/src/llama.cpp/common/arg.h +43 -2
  8. package/src/llama.cpp/common/chat-peg-parser.cpp +16 -2
  9. package/src/llama.cpp/common/chat.cpp +140 -0
  10. package/src/llama.cpp/common/common.cpp +130 -67
  11. package/src/llama.cpp/common/common.h +40 -16
  12. package/src/llama.cpp/common/console.cpp +98 -18
  13. package/src/llama.cpp/common/console.h +30 -8
  14. package/src/llama.cpp/common/download.cpp +69 -25
  15. package/src/llama.cpp/common/json-schema-to-grammar.cpp +132 -3
  16. package/src/llama.cpp/common/json-schema-to-grammar.h +20 -0
  17. package/src/llama.cpp/common/log.cpp +5 -0
  18. package/src/llama.cpp/common/log.h +1 -0
  19. package/src/llama.cpp/common/peg-parser.cpp +1 -1
  20. package/src/llama.cpp/common/preset.cpp +206 -0
  21. package/src/llama.cpp/common/preset.h +32 -0
  22. package/src/llama.cpp/common/sampling.cpp +91 -92
  23. package/src/llama.cpp/common/sampling.h +11 -6
  24. package/src/llama.cpp/common/speculative.cpp +1 -1
  25. package/src/llama.cpp/ggml/CMakeLists.txt +4 -0
  26. package/src/llama.cpp/ggml/include/ggml-alloc.h +9 -0
  27. package/src/llama.cpp/ggml/include/ggml-backend.h +1 -0
  28. package/src/llama.cpp/ggml/include/ggml-cpu.h +1 -0
  29. package/src/llama.cpp/ggml/include/ggml.h +7 -8
  30. package/src/llama.cpp/ggml/src/CMakeLists.txt +3 -0
  31. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2 -0
  32. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +60 -39
  33. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  34. package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +2 -1
  35. package/src/llama.cpp/include/llama.h +18 -1
  36. package/src/llama.cpp/src/llama-arch.cpp +1890 -2248
  37. package/src/llama.cpp/src/llama-arch.h +9 -2
  38. package/src/llama.cpp/src/llama-batch.cpp +12 -2
  39. package/src/llama.cpp/src/llama-batch.h +4 -2
  40. package/src/llama.cpp/src/llama-context.cpp +93 -23
  41. package/src/llama.cpp/src/llama-context.h +8 -2
  42. package/src/llama.cpp/src/llama-graph.cpp +84 -16
  43. package/src/llama.cpp/src/llama-graph.h +17 -4
  44. package/src/llama.cpp/src/llama-hparams.cpp +6 -0
  45. package/src/llama.cpp/src/llama-hparams.h +5 -1
  46. package/src/llama.cpp/src/llama-impl.cpp +4 -0
  47. package/src/llama.cpp/src/llama-kv-cache.cpp +90 -42
  48. package/src/llama.cpp/src/llama-kv-cache.h +19 -2
  49. package/src/llama.cpp/src/llama-memory-hybrid.cpp +1 -1
  50. package/src/llama.cpp/src/llama-model-loader.cpp +2 -0
  51. package/src/llama.cpp/src/llama-model-loader.h +2 -0
  52. package/src/llama.cpp/src/llama-model.cpp +103 -44
  53. package/src/llama.cpp/src/llama-model.h +1 -0
  54. package/src/llama.cpp/src/llama-quant.cpp +1 -1
  55. package/src/llama.cpp/src/llama-vocab.cpp +2 -1
  56. package/src/llama.cpp/src/llama.cpp +675 -1
  57. package/src/llama.cpp/src/models/deepseek2.cpp +9 -5
  58. package/src/llama.cpp/src/models/glm4-moe.cpp +28 -11
  59. package/src/llama.cpp/src/models/glm4.cpp +27 -4
  60. package/src/llama.cpp/src/models/models.h +5 -5
  61. package/src/llama.cpp/src/models/nemotron-h.cpp +35 -6
  62. package/src/llama.cpp/src/models/qwen2.cpp +12 -3
  63. package/src/llama.cpp/src/models/qwen3next.cpp +81 -266
package/src/llama.cpp/common/json-schema-to-grammar.h
@@ -3,11 +3,31 @@
 #include <nlohmann/json_fwd.hpp>
 
 #include <functional>
+#include <memory>
 #include <string>
 
 std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
                                    bool force_gbnf = false);
 
+class common_schema_converter;
+
+// Probes a JSON schema to extract information about its structure and type constraints.
+class common_schema_info {
+    std::unique_ptr<common_schema_converter> impl_;
+
+  public:
+    common_schema_info();
+    ~common_schema_info();
+
+    common_schema_info(const common_schema_info &) = delete;
+    common_schema_info & operator=(const common_schema_info &) = delete;
+    common_schema_info(common_schema_info &&) noexcept;
+    common_schema_info & operator=(common_schema_info &&) noexcept;
+
+    void resolve_refs(nlohmann::ordered_json & schema);
+    bool resolves_to_string(const nlohmann::ordered_json & schema);
+};
+
 struct common_grammar_builder {
     std::function<std::string(const std::string &, const std::string &)> add_rule;
     std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
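
The new `common_schema_info` wraps the schema converter behind a pImpl, so `common_schema_converter` stays private to the implementation file. A minimal sketch of how a caller might probe a schema, assuming the two methods behave as their names and the comment suggest (the schema literal is illustrative):

    #include <nlohmann/json.hpp>

    nlohmann::ordered_json schema = nlohmann::ordered_json::parse(R"({
        "$defs": { "name": { "type": "string" } },
        "$ref": "#/$defs/name"
    })");

    common_schema_info info;
    info.resolve_refs(schema);             // inline $ref targets into the schema in place
    if (info.resolves_to_string(schema)) { // expected: true, the $ref bottoms out at a string
        // the grammar for this schema can be treated as a plain string rule
    }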
package/src/llama.cpp/common/log.cpp
@@ -420,6 +420,11 @@ void common_log_set_timestamps(struct common_log * log, bool timestamps) {
     log->set_timestamps(timestamps);
 }
 
+void common_log_flush(struct common_log * log) {
+    log->pause();
+    log->resume();
+}
+
 static int common_get_verbosity(enum ggml_log_level level) {
     switch (level) {
         case GGML_LOG_LEVEL_DEBUG: return LOG_LEVEL_DEBUG;
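
`common_log_flush` drains queued messages by pausing the log worker and immediately resuming it; the pause path is what waits for buffered entries to be written out. A sketch of the intended use, assuming the `common_log_main()` singleton from log.h:

    LOG_INF("loading model...\n");
    common_log_flush(common_log_main()); // wait for queued log lines to reach the sink
    fprintf(stdout, "ready\n");          // now guaranteed to appear after the log output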
package/src/llama.cpp/common/log.h
@@ -84,6 +84,7 @@ void common_log_set_file (struct common_log * log, const char * file); // n
 void common_log_set_colors    (struct common_log * log, log_colors colors);  // not thread-safe
 void common_log_set_prefix    (struct common_log * log, bool prefix);        // whether to output prefix to each log
 void common_log_set_timestamps(struct common_log * log, bool timestamps);    // whether to output timestamps in the prefix
+void common_log_flush         (struct common_log * log);                     // flush all pending log messages
 
 // helper macros for logging
 // use these to avoid computing log arguments if the verbosity of the log is higher than the threshold
package/src/llama.cpp/common/peg-parser.cpp
@@ -425,7 +425,7 @@ struct parser_executor {
 
         if (result.need_more_input()) {
             // Propagate - need to know what child would match before negating
-            return result;
+            return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos);
         }
 
         // Child failed, so negation succeeds
package/src/llama.cpp/common/preset.cpp (new file)
@@ -0,0 +1,206 @@
+#include "arg.h"
+#include "preset.h"
+#include "peg-parser.h"
+#include "log.h"
+
+#include <fstream>
+#include <sstream>
+#include <filesystem>
+
+static std::string rm_leading_dashes(const std::string & str) {
+    size_t pos = 0;
+    while (pos < str.size() && str[pos] == '-') {
+        ++pos;
+    }
+    return str.substr(pos);
+}
+
+std::vector<std::string> common_preset::to_args() const {
+    std::vector<std::string> args;
+
+    for (const auto & [opt, value] : options) {
+        args.push_back(opt.args.back()); // use the last arg as the main arg
+        if (opt.value_hint == nullptr && opt.value_hint_2 == nullptr) {
+            // flag option, no value
+            if (common_arg_utils::is_falsey(value)) {
+                // use negative arg if available
+                if (!opt.args_neg.empty()) {
+                    args.back() = opt.args_neg.back();
+                } else {
+                    // otherwise, skip the flag
+                    // TODO: maybe throw an error instead?
+                    args.pop_back();
+                }
+            }
+        }
+        if (opt.value_hint != nullptr) {
+            // single value
+            args.push_back(value);
+        }
+        if (opt.value_hint != nullptr && opt.value_hint_2 != nullptr) {
+            throw std::runtime_error(string_format(
+                "common_preset::to_args(): option '%s' has two values, which is not supported yet",
+                opt.args.back()
+            ));
+        }
+    }
+
+    return args;
+}
+
+std::string common_preset::to_ini() const {
+    std::ostringstream ss;
+
+    ss << "[" << name << "]\n";
+    for (const auto & [opt, value] : options) {
+        auto espaced_value = value;
+        string_replace_all(espaced_value, "\n", "\\\n");
+        ss << rm_leading_dashes(opt.args.back()) << " = ";
+        ss << espaced_value << "\n";
+    }
+    ss << "\n";
+
+    return ss.str();
+}
+
+static std::map<std::string, std::map<std::string, std::string>> parse_ini_from_file(const std::string & path) {
+    std::map<std::string, std::map<std::string, std::string>> parsed;
+
+    if (!std::filesystem::exists(path)) {
+        throw std::runtime_error("preset file does not exist: " + path);
+    }
+
+    std::ifstream file(path);
+    if (!file.good()) {
+        throw std::runtime_error("failed to open server preset file: " + path);
+    }
+
+    std::string contents((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
+
+    static const auto parser = build_peg_parser([](auto & p) {
+        // newline ::= "\r\n" / "\n" / "\r"
+        auto newline = p.rule("newline", p.literal("\r\n") | p.literal("\n") | p.literal("\r"));
+
+        // ws ::= [ \t]*
+        auto ws = p.rule("ws", p.chars("[ \t]", 0, -1));
+
+        // comment ::= [;#] (!newline .)*
+        auto comment = p.rule("comment", p.chars("[;#]", 1, 1) + p.zero_or_more(p.negate(newline) + p.any()));
+
+        // eol ::= ws comment? (newline / EOF)
+        auto eol = p.rule("eol", ws + p.optional(comment) + (newline | p.end()));
+
+        // ident ::= [a-zA-Z_] [a-zA-Z0-9_.-]*
+        auto ident = p.rule("ident", p.chars("[a-zA-Z_]", 1, 1) + p.chars("[a-zA-Z0-9_.-]", 0, -1));
+
+        // value ::= (!eol-start .)*
+        auto eol_start = p.rule("eol-start", ws + (p.chars("[;#]", 1, 1) | newline | p.end()));
+        auto value = p.rule("value", p.zero_or_more(p.negate(eol_start) + p.any()));
+
+        // header-line ::= "[" ws ident ws "]" eol
+        auto header_line = p.rule("header-line", "[" + ws + p.tag("section-name", p.chars("[^]]")) + ws + "]" + eol);
+
+        // kv-line ::= ident ws "=" ws value eol
+        auto kv_line = p.rule("kv-line", p.tag("key", ident) + ws + "=" + ws + p.tag("value", value) + eol);
+
+        // comment-line ::= ws comment (newline / EOF)
+        auto comment_line = p.rule("comment-line", ws + comment + (newline | p.end()));
+
+        // blank-line ::= ws (newline / EOF)
+        auto blank_line = p.rule("blank-line", ws + (newline | p.end()));
+
+        // line ::= header-line / kv-line / comment-line / blank-line
+        auto line = p.rule("line", header_line | kv_line | comment_line | blank_line);
+
+        // ini ::= line* EOF
+        auto ini = p.rule("ini", p.zero_or_more(line) + p.end());
+
+        return ini;
+    });
+
+    common_peg_parse_context ctx(contents);
+    const auto result = parser.parse(ctx);
+    if (!result.success()) {
+        throw std::runtime_error("failed to parse server config file: " + path);
+    }
+
+    std::string current_section = COMMON_PRESET_DEFAULT_NAME;
+    std::string current_key;
+
+    ctx.ast.visit(result, [&](const auto & node) {
+        if (node.tag == "section-name") {
+            const std::string section = std::string(node.text);
+            current_section = section;
+            parsed[current_section] = {};
+        } else if (node.tag == "key") {
+            const std::string key = std::string(node.text);
+            current_key = key;
+        } else if (node.tag == "value" && !current_key.empty() && !current_section.empty()) {
+            parsed[current_section][current_key] = std::string(node.text);
+            current_key.clear();
+        }
+    });
+
+    return parsed;
+}
+
+static std::map<std::string, common_arg> get_map_key_opt(common_params_context & ctx_params) {
+    std::map<std::string, common_arg> mapping;
+    for (const auto & opt : ctx_params.options) {
+        for (const auto & env : opt.get_env()) {
+            mapping[env] = opt;
+        }
+        for (const auto & arg : opt.get_args()) {
+            mapping[rm_leading_dashes(arg)] = opt;
+        }
+    }
+    return mapping;
+}
+
+static bool is_bool_arg(const common_arg & arg) {
+    return !arg.args_neg.empty();
+}
+
+static std::string parse_bool_arg(const common_arg & arg, const std::string & key, const std::string & value) {
+    // if this is a negated arg, we need to reverse the value
+    for (const auto & neg_arg : arg.args_neg) {
+        if (rm_leading_dashes(neg_arg) == key) {
+            return common_arg_utils::is_truthy(value) ? "false" : "true";
+        }
+    }
+    // otherwise, not negated
+    return value;
+}
+
+common_presets common_presets_load(const std::string & path, common_params_context & ctx_params) {
+    common_presets out;
+    auto key_to_opt = get_map_key_opt(ctx_params);
+    auto ini_data = parse_ini_from_file(path);
+
+    for (auto section : ini_data) {
+        common_preset preset;
+        if (section.first.empty()) {
+            preset.name = COMMON_PRESET_DEFAULT_NAME;
+        } else {
+            preset.name = section.first;
+        }
+        LOG_DBG("loading preset: %s\n", preset.name.c_str());
+        for (const auto & [key, value] : section.second) {
+            LOG_DBG("option: %s = %s\n", key.c_str(), value.c_str());
+            if (key_to_opt.find(key) != key_to_opt.end()) {
+                auto & opt = key_to_opt[key];
+                if (is_bool_arg(opt)) {
+                    preset.options[opt] = parse_bool_arg(opt, key, value);
+                } else {
+                    preset.options[opt] = value;
+                }
+                LOG_DBG("accepted option: %s = %s\n", key.c_str(), preset.options[opt].c_str());
+            } else {
+                // TODO: maybe warn about unknown key?
+            }
+        }
+        out[preset.name] = preset;
+    }
+
+    return out;
+}
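
For reference, the INI dialect accepted above: `[section]` headers name presets, keys are option names with leading dashes removed (or env var names), values run to end of line, and `;` or `#` begins a comment. A hypothetical preset file and what it converts to, using real llama.cpp option names but invented values:

    const char * example_ini = R"(
    ; used when no preset is selected
    [default]
    ctx-size = 4096

    [coder]
    model    = my-coder-model.gguf
    ctx-size = 8192
    )";
    // After common_presets_load(), presets["coder"].to_args() would yield roughly
    // {"--model", "my-coder-model.gguf", "--ctx-size", "8192"}; boolean keys pass
    // through parse_bool_arg(), so a negated spelling such as "no-mmap = true"
    // stores the flipped value against the positive option.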
package/src/llama.cpp/common/preset.h (new file)
@@ -0,0 +1,32 @@
+#pragma once
+
+#include "common.h"
+#include "arg.h"
+
+#include <string>
+#include <vector>
+#include <map>
+
+//
+// INI preset parser and writer
+//
+
+constexpr const char * COMMON_PRESET_DEFAULT_NAME = "default";
+
+struct common_preset {
+    std::string name;
+    // TODO: support repeated args in the future
+    std::map<common_arg, std::string> options;
+
+    // convert preset to CLI argument list
+    std::vector<std::string> to_args() const;
+
+    // convert preset to INI format string
+    std::string to_ini() const;
+
+    // TODO: maybe implement to_env() if needed
+};
+
+// interface for multiple presets in one file
+using common_presets = std::map<std::string, common_preset>;
+common_presets common_presets_load(const std::string & path, common_params_context & ctx_params);
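
A sketch of the intended load path, assuming a `common_params_context` obtained from the usual `common_params_parser_init()` entry point in arg.h:

    common_params params;
    common_params_context ctx_params = common_params_parser_init(params, LLAMA_EXAMPLE_SERVER);

    common_presets presets = common_presets_load("presets.ini", ctx_params);
    for (const auto & [name, preset] : presets) {
        // each preset converts back into argv form for the normal CLI parser
        std::vector<std::string> args = preset.to_args();
    }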
package/src/llama.cpp/common/sampling.cpp
@@ -104,9 +104,10 @@ struct ring_buffer {
 struct common_sampler {
     common_params_sampling params;
 
-    struct llama_sampler * grmr;
     struct llama_sampler * chain;
 
+    bool grammar;
+
     ring_buffer<llama_token> prev;
 
     std::vector<llama_token_data> cur;
@@ -116,7 +117,6 @@ struct common_sampler {
     void reset() {
        prev.clear();
 
-        llama_sampler_reset(grmr);
        llama_sampler_reset(chain);
     }
 
@@ -167,10 +167,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
     lparams.no_perf = params.no_perf;
 
-    struct llama_sampler * grmr;
+    llama_sampler * chain = llama_sampler_chain_init(lparams);
+
+    bool grammar = false;
+    std::vector<llama_sampler *> samplers;
+
     if (params.grammar.compare(0, 11, "%llguidance") == 0) {
 #ifdef LLAMA_USE_LLGUIDANCE
-        grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
+        samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()));
+        grammar = true;
 #else
         GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
 #endif // LLAMA_USE_LLGUIDANCE
@@ -217,30 +222,23 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             trigger_patterns_c.push_back(regex.c_str());
         }
 
-        grmr = params.grammar_lazy
-            ? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
-                  trigger_patterns_c.data(), trigger_patterns_c.size(),
-                  trigger_tokens.data(), trigger_tokens.size())
-            : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
-        if (!grmr) {
-            return nullptr;
+        if (!params.grammar.empty()) {
+            if (params.grammar_lazy) {
+                samplers.push_back(
+                    llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
+                        trigger_patterns_c.data(), trigger_patterns_c.size(),
+                        trigger_tokens.data(), trigger_tokens.size()));
+            } else {
+                samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"));
+            }
+
+            grammar = true;
         }
     }
 
-    auto * result = new common_sampler {
-        /* .params = */ params,
-        /* .grmr   = */ grmr,
-        /* .chain  = */ llama_sampler_chain_init(lparams),
-        /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
-        /* .cur    = */ {},
-        /* .cur_p  = */ {},
-    };
-
-    llama_sampler_chain_add(result->chain,
-        llama_sampler_init_logit_bias(
-            llama_vocab_n_tokens(vocab),
-            params.logit_bias.size(),
-            params.logit_bias.data()));
+    if (params.has_logit_bias()) {
+        samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
+    }
 
     if (params.mirostat == 0) {
         for (const auto & cnstr : params.samplers) {
@@ -253,58 +251,70 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                         c_breakers.push_back(str.c_str());
                     }
 
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
+                    samplers.push_back(llama_sampler_init_dry      (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                 }
                 break;
             case COMMON_SAMPLER_TYPE_TOP_K:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
+                samplers.push_back(llama_sampler_init_top_k    (params.top_k));
                 break;
             case COMMON_SAMPLER_TYPE_TOP_P:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_top_p    (params.top_p, params.min_keep));
+                samplers.push_back(llama_sampler_init_top_p    (params.top_p, params.min_keep));
                 break;
             case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
+                samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
                 break;
             case COMMON_SAMPLER_TYPE_MIN_P:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
+                samplers.push_back(llama_sampler_init_min_p    (params.min_p, params.min_keep));
                 break;
             case COMMON_SAMPLER_TYPE_XTC:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
+                samplers.push_back(llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
                 break;
             case COMMON_SAMPLER_TYPE_TYPICAL_P:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_typical  (params.typ_p, params.min_keep));
+                samplers.push_back(llama_sampler_init_typical  (params.typ_p, params.min_keep));
                 break;
             case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                 break;
             case COMMON_SAMPLER_TYPE_INFILL:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (vocab));
+                samplers.push_back(llama_sampler_init_infill   (vocab));
                 break;
             case COMMON_SAMPLER_TYPE_PENALTIES:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
+                samplers.push_back(llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
                 break;
             default:
                 GGML_ASSERT(false && "unknown sampler type");
             }
         }
-        llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
+
+        samplers.push_back(llama_sampler_init_dist(params.seed));
     } else if (params.mirostat == 1) {
-        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
+        samplers.push_back(llama_sampler_init_temp(params.temp));
+        samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
     } else if (params.mirostat == 2) {
-        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
+        samplers.push_back(llama_sampler_init_temp(params.temp));
+        samplers.push_back(llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
     } else {
         GGML_ASSERT(false && "unknown mirostat version");
     }
 
+    for (auto * smpl : samplers) {
+        llama_sampler_chain_add(chain, smpl);
+    }
+
+    auto * result = new common_sampler {
+        /* .params  = */ params,
+        /* .chain   = */ chain,
+        /* .grammar = */ grammar,
+        /* .prev    = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
+        /* .cur     = */ {},
+        /* .cur_p   = */ {},
+    };
+
     return result;
 }
 
 void common_sampler_free(struct common_sampler * gsmpl) {
     if (gsmpl) {
-        llama_sampler_free(gsmpl->grmr);
-
         llama_sampler_free(gsmpl->chain);
 
         delete gsmpl;
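
Net effect of the init refactor: every stage, grammar included, is collected into one `samplers` vector and added to a single chain, with the grammar (when present) in slot 0, an optional logit-bias stage next, then the configured samplers, and finally the dist/mirostat selection stage. The resulting order can be checked with the existing print helper; a small sketch (output format approximate):

    common_sampler * smpl = common_sampler_init(model, params);
    // e.g. "... -> logit-bias -> penalties -> dry -> top-k -> ... -> dist"
    LOG_INF("chain: %s\n", common_sampler_print(smpl).c_str());
    common_sampler_free(smpl);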
@@ -314,11 +324,24 @@ void common_sampler_free(struct common_sampler * gsmpl) {
 void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
     const auto tm = gsmpl->tm();
 
-    if (accept_grammar) {
-        llama_sampler_accept(gsmpl->grmr, token);
-    }
+    if (gsmpl->grammar) {
+        const int n_smpl = llama_sampler_chain_n(gsmpl->chain);
 
-    llama_sampler_accept(gsmpl->chain, token);
+        for (int i = 0; i < n_smpl; i++) {
+            auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
+
+            // the grammar sampler is always the first one
+            if (i == 0) {
+                if (accept_grammar) {
+                    llama_sampler_accept(smpl, token);
+                }
+            } else {
+                llama_sampler_accept(smpl, token);
+            }
+        }
+    } else {
+        llama_sampler_accept(gsmpl->chain, token);
+    }
 
     gsmpl->prev.push_back(token);
 }
@@ -329,12 +352,12 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
 
 struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
     return new common_sampler {
-        /* .params = */ gsmpl->params,
-        /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
-        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
-        /* .prev   = */ gsmpl->prev,
-        /* .cur    = */ gsmpl->cur,
-        /* .cur_p  = */ gsmpl->cur_p,
+        /* .params  = */ gsmpl->params,
+        /* .chain   = */ llama_sampler_clone(gsmpl->chain),
+        /* .grammar = */ gsmpl->grammar,
+        /* .prev    = */ gsmpl->prev,
+        /* .cur     = */ gsmpl->cur,
+        /* .cur_p   = */ gsmpl->cur_p,
     };
 }
 
@@ -383,58 +406,33 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
     }
 }
 
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
+struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
+    return gsmpl->chain;
+}
+
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
     llama_synchronize(ctx);
 
     // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
     const auto tm = gsmpl->tm();
 
-    gsmpl->set_logits(ctx, idx);
+    llama_token id = LLAMA_TOKEN_NULL;
 
-    auto & grmr  = gsmpl->grmr;
     auto & chain = gsmpl->chain;
     auto & cur_p = gsmpl->cur_p; // initialized by set_logits
 
-    if (grammar_first) {
-        llama_sampler_apply(grmr, &cur_p);
-    }
+    gsmpl->set_logits(ctx, idx);
 
     llama_sampler_apply(chain, &cur_p);
 
     GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
 
-    const llama_token id = cur_p.data[cur_p.selected].id;
-
-    if (grammar_first) {
-        return id;
-    }
-
-    // check if it the sampled token fits the grammar
-    {
-        llama_token_data       single_token_data       = { id, 1.0f, 0.0f };
-        llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
-
-        llama_sampler_apply(grmr, &single_token_data_array);
-
-        const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
-        if (is_valid) {
-            return id;
-        }
-    }
-
-    // resampling:
-    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
-    gsmpl->set_logits(ctx, idx);
-
-    llama_sampler_apply(grmr,  &cur_p);
-    llama_sampler_apply(chain, &cur_p);
-
-    GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
+    id = cur_p.data[cur_p.selected].id;
 
-    return cur_p.data[cur_p.selected].id;
+    return id;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft) {
     GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
 
     std::vector<llama_token> result;
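
With the grammar sampler inside the chain, one `llama_sampler_apply` pass already yields a grammar-valid token, so the old sample, validate, resample dance and its `grammar_first` flag are gone. Call sites simply drop the trailing argument:

    // 1.4.7: the sampled token could violate the grammar and force a resample
    //   llama_token id = common_sampler_sample(smpl, ctx, idx, /*grammar_first=*/false);

    // 1.4.8: the grammar (chain slot 0) constrains candidates up front
    llama_token id = common_sampler_sample(smpl, ctx, idx);
    common_sampler_accept(smpl, id, /*accept_grammar=*/true);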
@@ -442,7 +440,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
 
     size_t i = 0;
     for (; i < draft.size(); i++) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
 
         common_sampler_accept(gsmpl, id, true);
 
@@ -454,7 +452,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     }
 
     if (i == draft.size()) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
 
         common_sampler_accept(gsmpl, id, true);
 
@@ -464,13 +462,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     return result;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) {
     std::vector<int> idxs(draft.size() + 1);
     for (size_t i = 0; i < idxs.size(); ++i) {
         idxs[i] = i;
     }
 
-    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
+    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
 }
 
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
@@ -515,7 +513,8 @@ std::string common_sampler_print(const struct common_sampler * gsmpl) {
 
     for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
         const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
-        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
+        result += std::string("-> ");
+        result += std::string(llama_sampler_name(smpl)) + " ";
     }
 
     return result;