@fugood/llama.node 1.4.7 → 1.4.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/lib/binding.ts +8 -0
- package/package.json +15 -15
- package/scripts/llama.cpp.patch +22 -23
- package/src/LlamaContext.cpp +2 -2
- package/src/llama.cpp/common/CMakeLists.txt +2 -0
- package/src/llama.cpp/common/arg.cpp +364 -193
- package/src/llama.cpp/common/arg.h +43 -2
- package/src/llama.cpp/common/chat-peg-parser.cpp +16 -2
- package/src/llama.cpp/common/chat.cpp +140 -0
- package/src/llama.cpp/common/common.cpp +130 -67
- package/src/llama.cpp/common/common.h +40 -16
- package/src/llama.cpp/common/console.cpp +98 -18
- package/src/llama.cpp/common/console.h +30 -8
- package/src/llama.cpp/common/download.cpp +69 -25
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +132 -3
- package/src/llama.cpp/common/json-schema-to-grammar.h +20 -0
- package/src/llama.cpp/common/log.cpp +5 -0
- package/src/llama.cpp/common/log.h +1 -0
- package/src/llama.cpp/common/peg-parser.cpp +1 -1
- package/src/llama.cpp/common/preset.cpp +206 -0
- package/src/llama.cpp/common/preset.h +32 -0
- package/src/llama.cpp/common/sampling.cpp +91 -92
- package/src/llama.cpp/common/sampling.h +11 -6
- package/src/llama.cpp/common/speculative.cpp +1 -1
- package/src/llama.cpp/ggml/CMakeLists.txt +4 -0
- package/src/llama.cpp/ggml/include/ggml-alloc.h +9 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +1 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +1 -0
- package/src/llama.cpp/ggml/include/ggml.h +7 -8
- package/src/llama.cpp/ggml/src/CMakeLists.txt +3 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +60 -39
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +2 -1
- package/src/llama.cpp/include/llama.h +18 -1
- package/src/llama.cpp/src/llama-arch.cpp +1890 -2248
- package/src/llama.cpp/src/llama-arch.h +9 -2
- package/src/llama.cpp/src/llama-batch.cpp +12 -2
- package/src/llama.cpp/src/llama-batch.h +4 -2
- package/src/llama.cpp/src/llama-context.cpp +93 -23
- package/src/llama.cpp/src/llama-context.h +8 -2
- package/src/llama.cpp/src/llama-graph.cpp +84 -16
- package/src/llama.cpp/src/llama-graph.h +17 -4
- package/src/llama.cpp/src/llama-hparams.cpp +6 -0
- package/src/llama.cpp/src/llama-hparams.h +5 -1
- package/src/llama.cpp/src/llama-impl.cpp +4 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +90 -42
- package/src/llama.cpp/src/llama-kv-cache.h +19 -2
- package/src/llama.cpp/src/llama-memory-hybrid.cpp +1 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +2 -0
- package/src/llama.cpp/src/llama-model-loader.h +2 -0
- package/src/llama.cpp/src/llama-model.cpp +103 -44
- package/src/llama.cpp/src/llama-model.h +1 -0
- package/src/llama.cpp/src/llama-quant.cpp +1 -1
- package/src/llama.cpp/src/llama-vocab.cpp +2 -1
- package/src/llama.cpp/src/llama.cpp +675 -1
- package/src/llama.cpp/src/models/deepseek2.cpp +9 -5
- package/src/llama.cpp/src/models/glm4-moe.cpp +28 -11
- package/src/llama.cpp/src/models/glm4.cpp +27 -4
- package/src/llama.cpp/src/models/models.h +5 -5
- package/src/llama.cpp/src/models/nemotron-h.cpp +35 -6
- package/src/llama.cpp/src/models/qwen2.cpp +12 -3
- package/src/llama.cpp/src/models/qwen3next.cpp +81 -266
The hunks shown below expand the changes to the llama.cpp `common` library.

package/src/llama.cpp/common/json-schema-to-grammar.h:

```diff
@@ -3,11 +3,31 @@
 #include <nlohmann/json_fwd.hpp>
 
 #include <functional>
+#include <memory>
 #include <string>
 
 std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
     bool force_gbnf = false);
 
+class common_schema_converter;
+
+// Probes a JSON schema to extract information about its structure and type constraints.
+class common_schema_info {
+    std::unique_ptr<common_schema_converter> impl_;
+
+  public:
+    common_schema_info();
+    ~common_schema_info();
+
+    common_schema_info(const common_schema_info &) = delete;
+    common_schema_info & operator=(const common_schema_info &) = delete;
+    common_schema_info(common_schema_info &&) noexcept;
+    common_schema_info & operator=(common_schema_info &&) noexcept;
+
+    void resolve_refs(nlohmann::ordered_json & schema);
+    bool resolves_to_string(const nlohmann::ordered_json & schema);
+};
+
 struct common_grammar_builder {
     std::function<std::string(const std::string &, const std::string &)> add_rule;
     std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
```
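A hedged usage sketch of the new probe class. The exact semantics of `resolve_refs()` (in-place `$ref` inlining) and `resolves_to_string()` are inferred from the declarations and comment above; the sample schema is made up:

```cpp
// Hedged sketch: probe whether a schema ultimately constrains values to strings.
// resolve_refs() mutating the schema in place is an inference from its signature.
#include "json-schema-to-grammar.h"

#include <nlohmann/json.hpp>

int main() {
    auto schema = nlohmann::ordered_json::parse(R"({
        "$ref": "#/definitions/name",
        "definitions": { "name": { "type": "string" } }
    })");

    common_schema_info info;
    info.resolve_refs(schema);
    return info.resolves_to_string(schema) ? 0 : 1; // expected: true for this schema
}
```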
package/src/llama.cpp/common/log.cpp:

```diff
@@ -420,6 +420,11 @@ void common_log_set_timestamps(struct common_log * log, bool timestamps) {
     log->set_timestamps(timestamps);
 }
 
+void common_log_flush(struct common_log * log) {
+    log->pause();
+    log->resume();
+}
+
 static int common_get_verbosity(enum ggml_log_level level) {
     switch (level) {
         case GGML_LOG_LEVEL_DEBUG: return LOG_LEVEL_DEBUG;
```
package/src/llama.cpp/common/log.h:

```diff
@@ -84,6 +84,7 @@ void common_log_set_file (struct common_log * log, const char * file); // n
 void common_log_set_colors    (struct common_log * log, log_colors colors); // not thread-safe
 void common_log_set_prefix    (struct common_log * log, bool prefix);       // whether to output prefix to each log
 void common_log_set_timestamps(struct common_log * log, bool timestamps);   // whether to output timestamps in the prefix
+void common_log_flush         (struct common_log * log);                    // flush all pending log messages
 
 // helper macros for logging
 // use these to avoid computing log arguments if the verbosity of the log is higher than the threshold
```
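As the log.cpp hunk shows, `common_log_flush()` is implemented as a `pause()`/`resume()` pair, which appears to rely on `pause()` draining the worker thread's queue before returning. A minimal usage sketch; `common_log_main()` and `LOG_INF()` come from common/log.h and are unchanged by this diff:

```cpp
// Hedged usage sketch: drain pending asynchronous log writes at a known point.
// Assumes common_log_main() and LOG_INF() from common/log.h (not part of this diff).
#include "log.h"

void checkpoint() {
    LOG_INF("checkpoint reached\n");
    common_log_flush(common_log_main()); // ensure the line above is written out before continuing
}
```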
package/src/llama.cpp/common/peg-parser.cpp:

```diff
@@ -425,7 +425,7 @@ struct parser_executor {
 
         if (result.need_more_input()) {
             // Propagate - need to know what child would match before negating
-            return
+            return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos);
         }
 
         // Child failed, so negation succeeds
```
package/src/llama.cpp/common/preset.cpp (new file):

```diff
@@ -0,0 +1,206 @@
+#include "arg.h"
+#include "preset.h"
+#include "peg-parser.h"
+#include "log.h"
+
+#include <fstream>
+#include <sstream>
+#include <filesystem>
+
+static std::string rm_leading_dashes(const std::string & str) {
+    size_t pos = 0;
+    while (pos < str.size() && str[pos] == '-') {
+        ++pos;
+    }
+    return str.substr(pos);
+}
+
+std::vector<std::string> common_preset::to_args() const {
+    std::vector<std::string> args;
+
+    for (const auto & [opt, value] : options) {
+        args.push_back(opt.args.back()); // use the last arg as the main arg
+        if (opt.value_hint == nullptr && opt.value_hint_2 == nullptr) {
+            // flag option, no value
+            if (common_arg_utils::is_falsey(value)) {
+                // use negative arg if available
+                if (!opt.args_neg.empty()) {
+                    args.back() = opt.args_neg.back();
+                } else {
+                    // otherwise, skip the flag
+                    // TODO: maybe throw an error instead?
+                    args.pop_back();
+                }
+            }
+        }
+        if (opt.value_hint != nullptr) {
+            // single value
+            args.push_back(value);
+        }
+        if (opt.value_hint != nullptr && opt.value_hint_2 != nullptr) {
+            throw std::runtime_error(string_format(
+                "common_preset::to_args(): option '%s' has two values, which is not supported yet",
+                opt.args.back()
+            ));
+        }
+    }
+
+    return args;
+}
+
+std::string common_preset::to_ini() const {
+    std::ostringstream ss;
+
+    ss << "[" << name << "]\n";
+    for (const auto & [opt, value] : options) {
+        auto espaced_value = value;
+        string_replace_all(espaced_value, "\n", "\\\n");
+        ss << rm_leading_dashes(opt.args.back()) << " = ";
+        ss << espaced_value << "\n";
+    }
+    ss << "\n";
+
+    return ss.str();
+}
+
+static std::map<std::string, std::map<std::string, std::string>> parse_ini_from_file(const std::string & path) {
+    std::map<std::string, std::map<std::string, std::string>> parsed;
+
+    if (!std::filesystem::exists(path)) {
+        throw std::runtime_error("preset file does not exist: " + path);
+    }
+
+    std::ifstream file(path);
+    if (!file.good()) {
+        throw std::runtime_error("failed to open server preset file: " + path);
+    }
+
+    std::string contents((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
+
+    static const auto parser = build_peg_parser([](auto & p) {
+        // newline ::= "\r\n" / "\n" / "\r"
+        auto newline = p.rule("newline", p.literal("\r\n") | p.literal("\n") | p.literal("\r"));
+
+        // ws ::= [ \t]*
+        auto ws = p.rule("ws", p.chars("[ \t]", 0, -1));
+
+        // comment ::= [;#] (!newline .)*
+        auto comment = p.rule("comment", p.chars("[;#]", 1, 1) + p.zero_or_more(p.negate(newline) + p.any()));
+
+        // eol ::= ws comment? (newline / EOF)
+        auto eol = p.rule("eol", ws + p.optional(comment) + (newline | p.end()));
+
+        // ident ::= [a-zA-Z_] [a-zA-Z0-9_.-]*
+        auto ident = p.rule("ident", p.chars("[a-zA-Z_]", 1, 1) + p.chars("[a-zA-Z0-9_.-]", 0, -1));
+
+        // value ::= (!eol-start .)*
+        auto eol_start = p.rule("eol-start", ws + (p.chars("[;#]", 1, 1) | newline | p.end()));
+        auto value = p.rule("value", p.zero_or_more(p.negate(eol_start) + p.any()));
+
+        // header-line ::= "[" ws ident ws "]" eol
+        auto header_line = p.rule("header-line", "[" + ws + p.tag("section-name", p.chars("[^]]")) + ws + "]" + eol);
+
+        // kv-line ::= ident ws "=" ws value eol
+        auto kv_line = p.rule("kv-line", p.tag("key", ident) + ws + "=" + ws + p.tag("value", value) + eol);
+
+        // comment-line ::= ws comment (newline / EOF)
+        auto comment_line = p.rule("comment-line", ws + comment + (newline | p.end()));
+
+        // blank-line ::= ws (newline / EOF)
+        auto blank_line = p.rule("blank-line", ws + (newline | p.end()));
+
+        // line ::= header-line / kv-line / comment-line / blank-line
+        auto line = p.rule("line", header_line | kv_line | comment_line | blank_line);
+
+        // ini ::= line* EOF
+        auto ini = p.rule("ini", p.zero_or_more(line) + p.end());
+
+        return ini;
+    });
+
+    common_peg_parse_context ctx(contents);
+    const auto result = parser.parse(ctx);
+    if (!result.success()) {
+        throw std::runtime_error("failed to parse server config file: " + path);
+    }
+
+    std::string current_section = COMMON_PRESET_DEFAULT_NAME;
+    std::string current_key;
+
+    ctx.ast.visit(result, [&](const auto & node) {
+        if (node.tag == "section-name") {
+            const std::string section = std::string(node.text);
+            current_section = section;
+            parsed[current_section] = {};
+        } else if (node.tag == "key") {
+            const std::string key = std::string(node.text);
+            current_key = key;
+        } else if (node.tag == "value" && !current_key.empty() && !current_section.empty()) {
+            parsed[current_section][current_key] = std::string(node.text);
+            current_key.clear();
+        }
+    });
+
+    return parsed;
+}
+
+static std::map<std::string, common_arg> get_map_key_opt(common_params_context & ctx_params) {
+    std::map<std::string, common_arg> mapping;
+    for (const auto & opt : ctx_params.options) {
+        for (const auto & env : opt.get_env()) {
+            mapping[env] = opt;
+        }
+        for (const auto & arg : opt.get_args()) {
+            mapping[rm_leading_dashes(arg)] = opt;
+        }
+    }
+    return mapping;
+}
+
+static bool is_bool_arg(const common_arg & arg) {
+    return !arg.args_neg.empty();
+}
+
+static std::string parse_bool_arg(const common_arg & arg, const std::string & key, const std::string & value) {
+    // if this is a negated arg, we need to reverse the value
+    for (const auto & neg_arg : arg.args_neg) {
+        if (rm_leading_dashes(neg_arg) == key) {
+            return common_arg_utils::is_truthy(value) ? "false" : "true";
+        }
+    }
+    // otherwise, not negated
+    return value;
+}
+
+common_presets common_presets_load(const std::string & path, common_params_context & ctx_params) {
+    common_presets out;
+    auto key_to_opt = get_map_key_opt(ctx_params);
+    auto ini_data = parse_ini_from_file(path);
+
+    for (auto section : ini_data) {
+        common_preset preset;
+        if (section.first.empty()) {
+            preset.name = COMMON_PRESET_DEFAULT_NAME;
+        } else {
+            preset.name = section.first;
+        }
+        LOG_DBG("loading preset: %s\n", preset.name.c_str());
+        for (const auto & [key, value] : section.second) {
+            LOG_DBG("option: %s = %s\n", key.c_str(), value.c_str());
+            if (key_to_opt.find(key) != key_to_opt.end()) {
+                auto & opt = key_to_opt[key];
+                if (is_bool_arg(opt)) {
+                    preset.options[opt] = parse_bool_arg(opt, key, value);
+                } else {
+                    preset.options[opt] = value;
+                }
+                LOG_DBG("accepted option: %s = %s\n", key.c_str(), preset.options[opt].c_str());
+            } else {
+                // TODO: maybe warn about unknown key?
+            }
+        }
+        out[preset.name] = preset;
+    }
+
+    return out;
+}
```
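The PEG grammar above accepts a small INI dialect: `[section]` headers, `key = value` pairs, comments introduced by `;` or `#`, and blank lines. A hypothetical preset file (section and option names are illustrative; keys are long-form CLI flags with the leading dashes removed, as `get_map_key_opt()` implies):

```ini
; hypothetical llama-server preset file -- names are illustrative
[default]
ctx-size = 4096   # inline comments are allowed after a value
temp = 0.8

[creative]
temp = 1.2
```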
package/src/llama.cpp/common/preset.h (new file):

```diff
@@ -0,0 +1,32 @@
+#pragma once
+
+#include "common.h"
+#include "arg.h"
+
+#include <string>
+#include <vector>
+#include <map>
+
+//
+// INI preset parser and writer
+//
+
+constexpr const char * COMMON_PRESET_DEFAULT_NAME = "default";
+
+struct common_preset {
+    std::string name;
+    // TODO: support repeated args in the future
+    std::map<common_arg, std::string> options;
+
+    // convert preset to CLI argument list
+    std::vector<std::string> to_args() const;
+
+    // convert preset to INI format string
+    std::string to_ini() const;
+
+    // TODO: maybe implement to_env() if needed
+};
+
+// interface for multiple presets in one file
+using common_presets = std::map<std::string, common_preset>;
+common_presets common_presets_load(const std::string & path, common_params_context & ctx_params);
```
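A hedged sketch of loading presets and turning one back into argv-style flags. `common_params_parser_init()` and `LLAMA_EXAMPLE_SERVER` come from common/arg.h and common/common.h (not shown in this diff); `presets.ini` is a made-up path:

```cpp
// Hedged sketch: load all presets from an INI file and print each one's CLI form.
// common_params_parser_init() / LLAMA_EXAMPLE_SERVER are assumed from common/arg.h
// and common/common.h; "presets.ini" is a made-up path.
#include "arg.h"
#include "preset.h"

#include <cstdio>

int main() {
    common_params params;
    common_params_context ctx_params = common_params_parser_init(params, LLAMA_EXAMPLE_SERVER);

    common_presets presets = common_presets_load("presets.ini", ctx_params);
    for (const auto & [name, preset] : presets) {
        printf("[%s]\n", name.c_str());
        for (const auto & arg : preset.to_args()) {
            printf("    %s\n", arg.c_str());
        }
    }
    return 0;
}
```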
package/src/llama.cpp/common/sampling.cpp:

```diff
@@ -104,9 +104,10 @@ struct ring_buffer {
 struct common_sampler {
     common_params_sampling params;
 
-    struct llama_sampler * grmr;
     struct llama_sampler * chain;
 
+    bool grammar;
+
     ring_buffer<llama_token> prev;
 
     std::vector<llama_token_data> cur;
@@ -116,7 +117,6 @@ struct common_sampler {
     void reset() {
         prev.clear();
 
-        llama_sampler_reset(grmr);
         llama_sampler_reset(chain);
     }
 
@@ -167,10 +167,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
     lparams.no_perf = params.no_perf;
 
-    struct llama_sampler * grmr;
+    llama_sampler * chain = llama_sampler_chain_init(lparams);
+
+    bool grammar = false;
+    std::vector<llama_sampler *> samplers;
+
     if (params.grammar.compare(0, 11, "%llguidance") == 0) {
 #ifdef LLAMA_USE_LLGUIDANCE
-        grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
+        samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()));
+        grammar = true;
 #else
         GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
 #endif // LLAMA_USE_LLGUIDANCE
@@ -217,30 +222,23 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             trigger_patterns_c.push_back(regex.c_str());
         }
 
-        grmr = params.grammar_lazy
-             ? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
-                   trigger_patterns_c.data(), trigger_patterns_c.size(),
-                   trigger_tokens.data(), trigger_tokens.size())
-             : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
-        if (!grmr) {
-            return nullptr;
+        if (!params.grammar.empty()) {
+            if (params.grammar_lazy) {
+                samplers.push_back(
+                    llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
+                        trigger_patterns_c.data(), trigger_patterns_c.size(),
+                        trigger_tokens.data(), trigger_tokens.size()));
+            } else {
+                samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"));
+            }
+
+            grammar = true;
         }
     }
 
-    auto * result = new common_sampler {
-        /* .params = */ params,
-        /* .grmr = */ grmr,
-        /* .chain = */ llama_sampler_chain_init(lparams),
-        /* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
-        /* .cur = */ {},
-        /* .cur_p = */ {},
-    };
-
-    llama_sampler_chain_add(result->chain,
-            llama_sampler_init_logit_bias(
-                llama_vocab_n_tokens(vocab),
-                params.logit_bias.size(),
-                params.logit_bias.data()));
+    if (params.has_logit_bias()) {
+        samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
+    }
 
     if (params.mirostat == 0) {
         for (const auto & cnstr : params.samplers) {
@@ -253,58 +251,70 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                         c_breakers.push_back(str.c_str());
                     }
 
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
+                    samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                 }
                 break;
             case COMMON_SAMPLER_TYPE_TOP_K:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
+                samplers.push_back(llama_sampler_init_top_k (params.top_k));
                 break;
             case COMMON_SAMPLER_TYPE_TOP_P:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
+                samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep));
                 break;
             case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma));
+                samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
                 break;
             case COMMON_SAMPLER_TYPE_MIN_P:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
+                samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep));
                 break;
            case COMMON_SAMPLER_TYPE_XTC:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
+                samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
                 break;
             case COMMON_SAMPLER_TYPE_TYPICAL_P:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
+                samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep));
                 break;
             case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                 break;
             case COMMON_SAMPLER_TYPE_INFILL:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
+                samplers.push_back(llama_sampler_init_infill (vocab));
                 break;
             case COMMON_SAMPLER_TYPE_PENALTIES:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
+                samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
                 break;
             default:
                 GGML_ASSERT(false && "unknown sampler type");
             }
         }
-        llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
+
+        samplers.push_back(llama_sampler_init_dist(params.seed));
     } else if (params.mirostat == 1) {
-        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
+        samplers.push_back(llama_sampler_init_temp(params.temp));
+        samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
    } else if (params.mirostat == 2) {
-        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
+        samplers.push_back(llama_sampler_init_temp(params.temp));
+        samplers.push_back(llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
    } else {
        GGML_ASSERT(false && "unknown mirostat version");
    }
 
+    for (auto * smpl : samplers) {
+        llama_sampler_chain_add(chain, smpl);
+    }
+
+    auto * result = new common_sampler {
+        /* .params = */ params,
+        /* .chain = */ chain,
+        /* .grammar = */ grammar,
+        /* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
+        /* .cur = */ {},
+        /* .cur_p = */ {},
+    };
+
     return result;
 }
 
 void common_sampler_free(struct common_sampler * gsmpl) {
     if (gsmpl) {
-        llama_sampler_free(gsmpl->grmr);
-
         llama_sampler_free(gsmpl->chain);
 
         delete gsmpl;
```
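This refactor replaces the separate `grmr` sampler with a single chain whose first element is the grammar sampler whenever `grammar == true`; `common_sampler_accept()` in the next hunk depends on that ordering. A hedged sketch of the contract, using only public llama.h functions (the top-k/dist tail and the caller-supplied `vocab`/`grammar_str` are illustrative):

```cpp
// Sketch of the ordering contract the refactor relies on: when a grammar is
// active it is pushed first, so chain index 0 is the grammar sampler and can
// be gated on accept_grammar while the rest of the chain always accepts.
#include "llama.h"

#include <vector>

llama_sampler * make_chain(const llama_vocab * vocab, const char * grammar_str) {
    std::vector<llama_sampler *> samplers;
    samplers.push_back(llama_sampler_init_grammar(vocab, grammar_str, "root")); // grammar first
    samplers.push_back(llama_sampler_init_top_k(40));                           // illustrative tail
    samplers.push_back(llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    for (auto * smpl : samplers) {
        llama_sampler_chain_add(chain, smpl);
    }
    // llama_sampler_chain_get(chain, 0) now returns the grammar sampler
    return chain;
}
```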
package/src/llama.cpp/common/sampling.cpp (continued):

```diff
@@ -314,11 +324,24 @@ void common_sampler_free(struct common_sampler * gsmpl) {
 void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
     const auto tm = gsmpl->tm();
 
-    if (accept_grammar) {
-        llama_sampler_accept(gsmpl->grmr, token);
-    }
+    if (gsmpl->grammar) {
+        const int n_smpl = llama_sampler_chain_n(gsmpl->chain);
 
-    llama_sampler_accept(gsmpl->chain, token);
+        for (int i = 0; i < n_smpl; i++) {
+            auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
+
+            // the grammar sampler is always the first one
+            if (i == 0) {
+                if (accept_grammar) {
+                    llama_sampler_accept(smpl, token);
+                }
+            } else {
+                llama_sampler_accept(smpl, token);
+            }
+        }
+    } else {
+        llama_sampler_accept(gsmpl->chain, token);
+    }
 
     gsmpl->prev.push_back(token);
 }
@@ -329,12 +352,12 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
 
 struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
     return new common_sampler {
-        /* .params = */ gsmpl->params,
-        /* .grmr = */ llama_sampler_clone(gsmpl->grmr),
-        /* .chain = */ llama_sampler_clone(gsmpl->chain),
-        /* .prev = */ gsmpl->prev,
-        /* .cur = */ gsmpl->cur,
-        /* .cur_p = */ gsmpl->cur_p,
+        /* .params = */ gsmpl->params,
+        /* .chain = */ llama_sampler_clone(gsmpl->chain),
+        /* .grammar = */ gsmpl->grammar,
+        /* .prev = */ gsmpl->prev,
+        /* .cur = */ gsmpl->cur,
+        /* .cur_p = */ gsmpl->cur_p,
     };
 }
 
@@ -383,58 +406,33 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
     }
 }
 
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
+struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
+    return gsmpl->chain;
+}
+
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
     llama_synchronize(ctx);
 
     // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
     const auto tm = gsmpl->tm();
 
-    gsmpl->set_logits(ctx, idx);
+    llama_token id = LLAMA_TOKEN_NULL;
 
-    auto & grmr = gsmpl->grmr;
     auto & chain = gsmpl->chain;
     auto & cur_p = gsmpl->cur_p; // initialized by set_logits
 
-    if (grammar_first) {
-        llama_sampler_apply(grmr, &cur_p);
-    }
+    gsmpl->set_logits(ctx, idx);
 
     llama_sampler_apply(chain, &cur_p);
 
     GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
 
-    const llama_token id = cur_p.data[cur_p.selected].id;
-
-    if (grammar_first) {
-        return id;
-    }
-
-    // check if it the sampled token fits the grammar
-    {
-        llama_token_data single_token_data = { id, 1.0f, 0.0f };
-        llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
-
-        llama_sampler_apply(grmr, &single_token_data_array);
-
-        const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
-        if (is_valid) {
-            return id;
-        }
-    }
-
-    // resampling:
-    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
-    gsmpl->set_logits(ctx, idx);
-
-    llama_sampler_apply(grmr, &cur_p);
-    llama_sampler_apply(chain, &cur_p);
-
-    GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
+    id = cur_p.data[cur_p.selected].id;
 
-    return cur_p.data[cur_p.selected].id;
+    return id;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft) {
     GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
 
     std::vector<llama_token> result;
```
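With the grammar sampler folded into the chain, the old two-pass flow in `common_sampler_sample` (sample, check the sampled token against the grammar, and re-sample with the grammar applied first if it was rejected) is gone: the chain, grammar included, is applied once and the selected token is returned directly. The `grammar_first` parameter disappears from the sampling entry points, which the remaining hunks clean up.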
package/src/llama.cpp/common/sampling.cpp (continued):

```diff
@@ -442,7 +440,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
 
     size_t i = 0;
     for (; i < draft.size(); i++) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
 
         common_sampler_accept(gsmpl, id, true);
 
@@ -454,7 +452,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     }
 
     if (i == draft.size()) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
 
         common_sampler_accept(gsmpl, id, true);
 
@@ -464,13 +462,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     return result;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) {
     std::vector<int> idxs(draft.size() + 1);
     for (size_t i = 0; i < idxs.size(); ++i) {
         idxs[i] = i;
     }
 
-    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
+    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
 }
 
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
@@ -515,7 +513,8 @@ std::string common_sampler_print(const struct common_sampler * gsmpl) {
 
     for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
         const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
-        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
+        result += std::string("-> ");
+        result += std::string(llama_sampler_name(smpl)) + " ";
     }
 
     return result;
```