@fugood/llama.node 1.4.13 → 1.4.15
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/lib/binding.ts +23 -2
- package/lib/index.js +2 -1
- package/lib/index.ts +8 -1
- package/lib/parallel.ts +2 -2
- package/package.json +15 -15
- package/scripts/llama.cpp.patch +9 -12
- package/src/LlamaContext.cpp +16 -4
- package/src/llama.cpp/CMakeLists.txt +24 -8
- package/src/llama.cpp/common/CMakeLists.txt +3 -34
- package/src/llama.cpp/common/arg.cpp +183 -60
- package/src/llama.cpp/common/arg.h +0 -8
- package/src/llama.cpp/common/chat-parser.cpp +115 -0
- package/src/llama.cpp/common/chat.cpp +67 -0
- package/src/llama.cpp/common/chat.h +1 -0
- package/src/llama.cpp/common/common.cpp +2 -1
- package/src/llama.cpp/common/common.h +12 -7
- package/src/llama.cpp/common/debug.cpp +165 -0
- package/src/llama.cpp/common/debug.h +43 -0
- package/src/llama.cpp/common/download.cpp +88 -369
- package/src/llama.cpp/common/download.h +32 -5
- package/src/llama.cpp/common/preset.cpp +87 -2
- package/src/llama.cpp/common/preset.h +10 -1
- package/src/llama.cpp/ggml/include/ggml.h +5 -0
- package/src/llama.cpp/include/llama.h +5 -2
- package/src/llama.cpp/src/CMakeLists.txt +1 -0
- package/src/llama.cpp/src/llama-arch.cpp +35 -0
- package/src/llama.cpp/src/llama-arch.h +1 -0
- package/src/llama.cpp/src/llama-chat.cpp +20 -0
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +31 -43
- package/src/llama.cpp/src/llama-mmap.cpp +78 -42
- package/src/llama.cpp/src/llama-mmap.h +5 -4
- package/src/llama.cpp/src/llama-model-loader.cpp +17 -5
- package/src/llama.cpp/src/llama-model-loader.h +2 -0
- package/src/llama.cpp/src/llama-model.cpp +225 -101
- package/src/llama.cpp/src/llama-quant.cpp +1 -1
- package/src/llama.cpp/src/llama-sampling.cpp +1 -1
- package/src/llama.cpp/src/llama-vocab.cpp +37 -24
- package/src/llama.cpp/src/llama-vocab.h +1 -0
- package/src/llama.cpp/src/llama.cpp +63 -27
- package/src/llama.cpp/src/models/exaone-moe.cpp +146 -0
- package/src/llama.cpp/src/models/gemma3n-iswa.cpp +13 -3
- package/src/llama.cpp/src/models/models.h +13 -2
- package/src/llama.cpp/src/models/qwen3next.cpp +198 -182
|
@@ -1,12 +1,27 @@
|
|
|
1
1
|
#pragma once
|
|
2
2
|
|
|
3
3
|
#include <string>
|
|
4
|
+
#include <vector>
|
|
4
5
|
|
|
5
6
|
struct common_params_model;
|
|
6
7
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
8
|
+
using common_header = std::pair<std::string, std::string>;
|
|
9
|
+
using common_header_list = std::vector<common_header>;
|
|
10
|
+
|
|
11
|
+
// options for a remote request; passed by const reference to common_remote_get_content()
struct common_remote_params {
|
|
12
|
+
// list of string pairs (see common_header) - presumably <header name, header value>; confirm against the implementation
common_header_list headers;
|
|
13
|
+
long timeout = 0; // in seconds, 0 means no timeout
|
|
14
|
+
long max_size = 0; // unlimited if 0
|
|
15
|
+
};
|
|
16
|
+
|
|
17
|
+
// get remote file content, returns <http_code, raw_response_body>
|
|
18
|
+
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params);
|
|
19
|
+
|
|
20
|
+
// split HF repo with tag into <repo, tag>
|
|
21
|
+
// for example: "user/model:tag" -> <"user/model", "tag">
|
|
22
|
+
// if tag is not present, default to "latest"
|
|
23
|
+
// example: "user/model" -> <"user/model", "latest">
|
|
24
|
+
std::pair<std::string, std::string> common_download_split_repo_tag(const std::string & hf_repo_with_tag);
|
|
10
25
|
|
|
11
26
|
struct common_cached_model_info {
|
|
12
27
|
std::string manifest_path;
|
|
@@ -41,17 +56,29 @@ struct common_hf_file_res {
|
|
|
41
56
|
common_hf_file_res common_get_hf_file(
|
|
42
57
|
const std::string & hf_repo_with_tag,
|
|
43
58
|
const std::string & bearer_token,
|
|
44
|
-
bool offline
|
|
59
|
+
bool offline,
|
|
60
|
+
const common_header_list & headers = {}
|
|
61
|
+
);
|
|
45
62
|
|
|
46
63
|
// returns true if download succeeded
|
|
47
64
|
bool common_download_model(
|
|
48
65
|
const common_params_model & model,
|
|
49
66
|
const std::string & bearer_token,
|
|
50
|
-
bool offline
|
|
67
|
+
bool offline,
|
|
68
|
+
const common_header_list & headers = {}
|
|
69
|
+
);
|
|
51
70
|
|
|
52
71
|
// returns list of cached models
|
|
53
72
|
std::vector<common_cached_model_info> common_list_cached_models();
|
|
54
73
|
|
|
74
|
+
// download single file from url to local path
|
|
75
|
+
// returns status code or -1 on error
|
|
76
|
+
int common_download_file_single(const std::string & url,
|
|
77
|
+
const std::string & path,
|
|
78
|
+
const std::string & bearer_token,
|
|
79
|
+
bool offline,
|
|
80
|
+
const common_header_list & headers = {});
|
|
81
|
+
|
|
55
82
|
// resolve and download model from Docker registry
|
|
56
83
|
// return local path to downloaded model file
|
|
57
84
|
std::string common_docker_resolve_model(const std::string & docker);
|
|
@@ -16,6 +16,48 @@ static std::string rm_leading_dashes(const std::string & str) {
|
|
|
16
16
|
return str.substr(pos);
|
|
17
17
|
}
|
|
18
18
|
|
|
19
|
+
// only allow a subset of args for remote presets for security reasons
|
|
20
|
+
// do not add more args unless absolutely necessary
|
|
21
|
+
// args that output to files are strictly prohibited
|
|
22
|
+
// build the set of keys that a remote preset is allowed to use
// key_to_opt: map from option key to its common_arg definition
// returns: for every option that is either listed in allowed_options or has is_sparam set,
//          the map key itself, each of its args with leading dashes removed, and each of its env names
static std::set<std::string> get_remote_preset_whitelist(const std::map<std::string, common_arg> & key_to_opt) {
|
|
23
|
+
// explicit whitelist, matched against the keys of key_to_opt
static const std::set<std::string> allowed_options = {
|
|
24
|
+
"model-url",
|
|
25
|
+
"hf-repo",
|
|
26
|
+
"hf-repo-draft",
|
|
27
|
+
"hf-repo-v", // vocoder
|
|
28
|
+
"hf-file-v", // vocoder
|
|
29
|
+
"mmproj-url",
|
|
30
|
+
"pooling",
|
|
31
|
+
"jinja",
|
|
32
|
+
"batch-size",
|
|
33
|
+
"ubatch-size",
|
|
34
|
+
"cache-reuse",
|
|
35
|
+
"chat-template-kwargs",
|
|
36
|
+
"mmap",
|
|
37
|
+
// note: sampling params are automatically allowed by default
|
|
38
|
+
// negated args will be added automatically if the positive arg is specified above
|
|
39
|
+
};
|
|
40
|
+
|
|
41
|
+
// accumulated result: every accepted spelling of every allowed option
std::set<std::string> allowed_keys;
|
|
42
|
+
|
|
43
|
+
for (const auto & it : key_to_opt) {
|
|
44
|
+
const std::string & key = it.first;
|
|
45
|
+
const common_arg & opt = it.second;
|
|
46
|
+
// accept the option if its key is whitelisted above, or if it is flagged is_sparam
if (allowed_options.find(key) != allowed_options.end() || opt.is_sparam) {
|
|
47
|
+
allowed_keys.insert(key);
|
|
48
|
+
// also add variant keys (args without leading dashes and env vars)
|
|
49
|
+
for (const auto & arg : opt.get_args()) {
|
|
50
|
+
allowed_keys.insert(rm_leading_dashes(arg));
|
|
51
|
+
}
|
|
52
|
+
for (const auto & env : opt.get_env()) {
|
|
53
|
+
allowed_keys.insert(env);
|
|
54
|
+
}
|
|
55
|
+
}
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
return allowed_keys;
|
|
59
|
+
}
|
|
60
|
+
|
|
19
61
|
std::vector<std::string> common_preset::to_args(const std::string & bin_path) const {
|
|
20
62
|
std::vector<std::string> args;
|
|
21
63
|
|
|
@@ -121,6 +163,29 @@ void common_preset::merge(const common_preset & other) {
|
|
|
121
163
|
}
|
|
122
164
|
}
|
|
123
165
|
|
|
166
|
+
// apply every (option, value) pair stored in this preset to params by calling the option's handler
// the first non-empty handler wins, checked in this order: string, int, bool, str_str, void
// throws std::runtime_error for options that only have a two-value (str_str) handler
// aborts via GGML_ABORT if an option has no handler at all
void common_preset::apply_to_params(common_params & params) const {
|
|
167
|
+
for (const auto & [opt, val] : options) {
|
|
168
|
+
// apply each option to params
|
|
169
|
+
if (opt.handler_string) {
|
|
170
|
+
opt.handler_string(params, val);
|
|
171
|
+
} else if (opt.handler_int) {
|
|
172
|
+
// note: std::stoi throws std::invalid_argument / std::out_of_range if val is not a valid int
opt.handler_int(params, std::stoi(val));
|
|
173
|
+
} else if (opt.handler_bool) {
|
|
174
|
+
opt.handler_bool(params, common_arg_utils::is_truthy(val));
|
|
175
|
+
} else if (opt.handler_str_str) {
|
|
176
|
+
// not supported yet
|
|
177
|
+
throw std::runtime_error(string_format(
|
|
178
|
+
"%s: option with two values is not supported yet",
|
|
179
|
+
__func__
|
|
180
|
+
));
|
|
181
|
+
} else if (opt.handler_void) {
|
|
182
|
+
// the stored value is not passed to the handler in this case
opt.handler_void(params);
|
|
183
|
+
} else {
|
|
184
|
+
GGML_ABORT("unknown handler type");
|
|
185
|
+
}
|
|
186
|
+
}
|
|
187
|
+
}
|
|
188
|
+
|
|
124
189
|
static std::map<std::string, std::map<std::string, std::string>> parse_ini_from_file(const std::string & path) {
|
|
125
190
|
std::map<std::string, std::map<std::string, std::string>> parsed;
|
|
126
191
|
|
|
@@ -230,10 +295,16 @@ static std::string parse_bool_arg(const common_arg & arg, const std::string & ke
|
|
|
230
295
|
return value;
|
|
231
296
|
}
|
|
232
297
|
|
|
233
|
-
common_preset_context::common_preset_context(llama_example ex)
|
|
298
|
+
common_preset_context::common_preset_context(llama_example ex, bool only_remote_allowed)
|
|
234
299
|
: ctx_params(common_params_parser_init(default_params, ex)) {
|
|
235
300
|
common_params_add_preset_options(ctx_params.options);
|
|
236
301
|
key_to_opt = get_map_key_opt(ctx_params);
|
|
302
|
+
|
|
303
|
+
// setup allowed keys if only_remote_allowed is true
|
|
304
|
+
if (only_remote_allowed) {
|
|
305
|
+
filter_allowed_keys = true;
|
|
306
|
+
allowed_keys = get_remote_preset_whitelist(key_to_opt);
|
|
307
|
+
}
|
|
237
308
|
}
|
|
238
309
|
|
|
239
310
|
common_presets common_preset_context::load_from_ini(const std::string & path, common_preset & global) const {
|
|
@@ -249,7 +320,18 @@ common_presets common_preset_context::load_from_ini(const std::string & path, co
|
|
|
249
320
|
}
|
|
250
321
|
LOG_DBG("loading preset: %s\n", preset.name.c_str());
|
|
251
322
|
for (const auto & [key, value] : section.second) {
|
|
323
|
+
if (key == "version") {
|
|
324
|
+
// skip version key (reserved for future use)
|
|
325
|
+
continue;
|
|
326
|
+
}
|
|
327
|
+
|
|
252
328
|
LOG_DBG("option: %s = %s\n", key.c_str(), value.c_str());
|
|
329
|
+
if (filter_allowed_keys && allowed_keys.find(key) == allowed_keys.end()) {
|
|
330
|
+
throw std::runtime_error(string_format(
|
|
331
|
+
"option '%s' is not allowed in remote presets",
|
|
332
|
+
key.c_str()
|
|
333
|
+
));
|
|
334
|
+
}
|
|
253
335
|
if (key_to_opt.find(key) != key_to_opt.end()) {
|
|
254
336
|
const auto & opt = key_to_opt.at(key);
|
|
255
337
|
if (is_bool_arg(opt)) {
|
|
@@ -259,7 +341,10 @@ common_presets common_preset_context::load_from_ini(const std::string & path, co
|
|
|
259
341
|
}
|
|
260
342
|
LOG_DBG("accepted option: %s = %s\n", key.c_str(), preset.options[opt].c_str());
|
|
261
343
|
} else {
|
|
262
|
-
|
|
344
|
+
throw std::runtime_error(string_format(
|
|
345
|
+
"option '%s' not recognized in preset '%s'",
|
|
346
|
+
key.c_str(), preset.name.c_str()
|
|
347
|
+
));
|
|
263
348
|
}
|
|
264
349
|
}
|
|
265
350
|
|
|
@@ -6,6 +6,7 @@
|
|
|
6
6
|
#include <string>
|
|
7
7
|
#include <vector>
|
|
8
8
|
#include <map>
|
|
9
|
+
#include <set>
|
|
9
10
|
|
|
10
11
|
//
|
|
11
12
|
// INI preset parser and writer
|
|
@@ -40,6 +41,9 @@ struct common_preset {
|
|
|
40
41
|
|
|
41
42
|
// merge another preset into this one, overwriting existing options
|
|
42
43
|
void merge(const common_preset & other);
|
|
44
|
+
|
|
45
|
+
// apply preset options to common_params
|
|
46
|
+
void apply_to_params(common_params & params) const;
|
|
43
47
|
};
|
|
44
48
|
|
|
45
49
|
// interface for multiple presets in one file
|
|
@@ -50,7 +54,12 @@ struct common_preset_context {
|
|
|
50
54
|
common_params default_params; // unused for now
|
|
51
55
|
common_params_context ctx_params;
|
|
52
56
|
std::map<std::string, common_arg> key_to_opt;
|
|
53
|
-
|
|
57
|
+
|
|
58
|
+
bool filter_allowed_keys = false;
|
|
59
|
+
std::set<std::string> allowed_keys;
|
|
60
|
+
|
|
61
|
+
// if only_remote_allowed is true, only accept whitelisted keys
|
|
62
|
+
common_preset_context(llama_example ex, bool only_remote_allowed = false);
|
|
54
63
|
|
|
55
64
|
// load presets from INI file
|
|
56
65
|
common_presets load_from_ini(const std::string & path, common_preset & global) const;
|
|
@@ -234,6 +234,11 @@
|
|
|
234
234
|
|
|
235
235
|
#if UINTPTR_MAX == 0xFFFFFFFF
|
|
236
236
|
#define GGML_MEM_ALIGN 4
|
|
237
|
+
#elif defined(__EMSCRIPTEN__)
|
|
238
|
+
// emscripten uses max_align_t == 8, so we need GGML_MEM_ALIGN == 8 for 64-bit wasm.
|
|
239
|
+
// (for 32-bit wasm, the first conditional is true and GGML_MEM_ALIGN stays 4.)
|
|
240
|
+
// ref: https://github.com/ggml-org/llama.cpp/pull/18628
|
|
241
|
+
#define GGML_MEM_ALIGN 8
|
|
237
242
|
#else
|
|
238
243
|
#define GGML_MEM_ALIGN 16
|
|
239
244
|
#endif
|
|
@@ -309,6 +309,7 @@ extern "C" {
|
|
|
309
309
|
// Keep the booleans together to avoid misalignment during copy-by-value.
|
|
310
310
|
bool vocab_only; // only load the vocabulary, no weights
|
|
311
311
|
bool use_mmap; // use mmap if possible
|
|
312
|
+
bool use_direct_io; // use direct io, takes precedence over use_mmap
|
|
312
313
|
bool use_mlock; // force system to keep model in RAM
|
|
313
314
|
bool check_tensors; // validate model tensor data
|
|
314
315
|
bool use_extra_bufts; // use extra buffer types (used for weight repacking)
|
|
@@ -494,7 +495,7 @@ extern "C" {
|
|
|
494
495
|
struct llama_context_params * cparams,
|
|
495
496
|
float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements
|
|
496
497
|
struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements
|
|
497
|
-
size_t
|
|
498
|
+
size_t * margins, // margins of memory to leave per device in bytes
|
|
498
499
|
uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use
|
|
499
500
|
enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log
|
|
500
501
|
|
|
@@ -1291,7 +1292,9 @@ extern "C" {
|
|
|
1291
1292
|
// available samplers:
|
|
1292
1293
|
|
|
1293
1294
|
LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
|
|
1294
|
-
|
|
1295
|
+
|
|
1296
|
+
/// seed == LLAMA_DEFAULT_SEED to use a random seed.
|
|
1297
|
+
LLAMA_API struct llama_sampler * llama_sampler_init_dist(uint32_t seed);
|
|
1295
1298
|
|
|
1296
1299
|
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
|
1297
1300
|
/// Setting k <= 0 makes this a noop
|
|
@@ -81,6 +81,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|
|
81
81
|
{ LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" },
|
|
82
82
|
{ LLM_ARCH_EXAONE, "exaone" },
|
|
83
83
|
{ LLM_ARCH_EXAONE4, "exaone4" },
|
|
84
|
+
{ LLM_ARCH_EXAONE_MOE, "exaone-moe" },
|
|
84
85
|
{ LLM_ARCH_RWKV6, "rwkv6" },
|
|
85
86
|
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
|
|
86
87
|
{ LLM_ARCH_RWKV7, "rwkv7" },
|
|
@@ -950,6 +951,8 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
|
|
|
950
951
|
LLM_TENSOR_ATTN_K_NORM,
|
|
951
952
|
LLM_TENSOR_ATTN_V,
|
|
952
953
|
LLM_TENSOR_ATTN_OUT,
|
|
954
|
+
LLM_TENSOR_ATTN_QKV,
|
|
955
|
+
LLM_TENSOR_ATTN_GATE,
|
|
953
956
|
LLM_TENSOR_FFN_NORM,
|
|
954
957
|
LLM_TENSOR_FFN_GATE_INP,
|
|
955
958
|
LLM_TENSOR_FFN_GATE_EXPS,
|
|
@@ -1726,6 +1729,38 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
|
|
|
1726
1729
|
LLM_TENSOR_FFN_UP,
|
|
1727
1730
|
LLM_TENSOR_FFN_POST_NORM,
|
|
1728
1731
|
};
|
|
1732
|
+
case LLM_ARCH_EXAONE_MOE:
|
|
1733
|
+
return {
|
|
1734
|
+
LLM_TENSOR_TOKEN_EMBD,
|
|
1735
|
+
LLM_TENSOR_OUTPUT_NORM,
|
|
1736
|
+
LLM_TENSOR_OUTPUT,
|
|
1737
|
+
LLM_TENSOR_ROPE_FREQS,
|
|
1738
|
+
LLM_TENSOR_ATTN_NORM,
|
|
1739
|
+
LLM_TENSOR_ATTN_Q,
|
|
1740
|
+
LLM_TENSOR_ATTN_Q_NORM,
|
|
1741
|
+
LLM_TENSOR_ATTN_K,
|
|
1742
|
+
LLM_TENSOR_ATTN_K_NORM,
|
|
1743
|
+
LLM_TENSOR_ATTN_V,
|
|
1744
|
+
LLM_TENSOR_ATTN_OUT,
|
|
1745
|
+
LLM_TENSOR_FFN_NORM,
|
|
1746
|
+
LLM_TENSOR_FFN_GATE,
|
|
1747
|
+
LLM_TENSOR_FFN_DOWN,
|
|
1748
|
+
LLM_TENSOR_FFN_UP,
|
|
1749
|
+
LLM_TENSOR_FFN_GATE_INP,
|
|
1750
|
+
LLM_TENSOR_FFN_GATE_EXPS,
|
|
1751
|
+
LLM_TENSOR_FFN_DOWN_EXPS,
|
|
1752
|
+
LLM_TENSOR_FFN_UP_EXPS,
|
|
1753
|
+
LLM_TENSOR_FFN_GATE_SHEXP,
|
|
1754
|
+
LLM_TENSOR_FFN_UP_SHEXP,
|
|
1755
|
+
LLM_TENSOR_FFN_DOWN_SHEXP,
|
|
1756
|
+
LLM_TENSOR_FFN_EXP_PROBS_B,
|
|
1757
|
+
LLM_TENSOR_NEXTN_EH_PROJ,
|
|
1758
|
+
LLM_TENSOR_NEXTN_EMBED_TOKENS,
|
|
1759
|
+
LLM_TENSOR_NEXTN_ENORM,
|
|
1760
|
+
LLM_TENSOR_NEXTN_HNORM,
|
|
1761
|
+
LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
|
|
1762
|
+
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
|
|
1763
|
+
};
|
|
1729
1764
|
case LLM_ARCH_RWKV6:
|
|
1730
1765
|
return {
|
|
1731
1766
|
LLM_TENSOR_TOKEN_EMBD,
|
|
@@ -57,6 +57,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
|
|
57
57
|
{ "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
|
|
58
58
|
{ "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
|
|
59
59
|
{ "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 },
|
|
60
|
+
{ "exaone-moe", LLM_CHAT_TEMPLATE_EXAONE_MOE },
|
|
60
61
|
{ "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
|
|
61
62
|
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
|
|
62
63
|
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
|
|
@@ -137,6 +138,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
|
|
137
138
|
} else if (tmpl_contains("[gMASK]<sop>")) {
|
|
138
139
|
return LLM_CHAT_TEMPLATE_CHATGLM_4;
|
|
139
140
|
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
|
|
141
|
+
if (tmpl_contains("<|tool_declare|>")) {
|
|
142
|
+
return LLM_CHAT_TEMPLATE_EXAONE_MOE;
|
|
143
|
+
}
|
|
140
144
|
return tmpl_contains("</s>") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE;
|
|
141
145
|
} else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) {
|
|
142
146
|
return LLM_CHAT_TEMPLATE_GLMEDGE;
|
|
@@ -576,6 +580,22 @@ int32_t llm_chat_apply_template(
|
|
|
576
580
|
if (add_ass) {
|
|
577
581
|
ss << "[|assistant|]";
|
|
578
582
|
}
|
|
583
|
+
} else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_MOE) {
|
|
584
|
+
for (auto message : chat) {
|
|
585
|
+
std::string role(message->role);
|
|
586
|
+
if (role == "system") {
|
|
587
|
+
ss << "<|system|>\n" << trim(message->content) << "<|endofturn|>\n";
|
|
588
|
+
} else if (role == "user") {
|
|
589
|
+
ss << "<|user|>\n" << trim(message->content) << "<|endofturn|>\n";
|
|
590
|
+
} else if (role == "assistant") {
|
|
591
|
+
ss << "<|assistant|>\n" << trim(message->content) << "<|endofturn|>\n";
|
|
592
|
+
} else if (role == "tool") {
|
|
593
|
+
ss << "<|tool|>\n" << trim(message->content) << "<|endofturn|>\n";
|
|
594
|
+
}
|
|
595
|
+
}
|
|
596
|
+
if (add_ass) {
|
|
597
|
+
ss << "<|assistant|>\n";
|
|
598
|
+
}
|
|
579
599
|
} else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
|
|
580
600
|
// this template requires the model to have "\n\n" as EOT token
|
|
581
601
|
for (size_t i = 0; i < chat.size(); i++) {
|
|
@@ -96,11 +96,9 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
|
|
|
96
96
|
|
|
97
97
|
int32_t * data = (int32_t *) pos_bucket->data;
|
|
98
98
|
|
|
99
|
-
for (int
|
|
100
|
-
for (int
|
|
101
|
-
|
|
102
|
-
data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
|
|
103
|
-
}
|
|
99
|
+
for (int j = 0; j < n_tokens; ++j) {
|
|
100
|
+
for (int i = 0; i < n_tokens; ++i) {
|
|
101
|
+
data[j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
|
|
104
102
|
}
|
|
105
103
|
}
|
|
106
104
|
}
|
|
@@ -323,34 +321,32 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|
|
323
321
|
const int64_t n_tokens = ubatch->n_tokens;
|
|
324
322
|
|
|
325
323
|
const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
|
|
326
|
-
for (int
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
const llama_pos p1 = ubatch->pos[i1];
|
|
324
|
+
for (int i1 = 0; i1 < n_tokens; ++i1) {
|
|
325
|
+
const llama_seq_id s1 = ubatch->seq_id[i1][0];
|
|
326
|
+
const llama_pos p1 = ubatch->pos[i1];
|
|
330
327
|
|
|
331
|
-
|
|
328
|
+
const uint64_t idst = i1*n_kv;
|
|
332
329
|
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
// mask different sequences
|
|
338
|
-
if (s0 != s1) {
|
|
339
|
-
continue;
|
|
340
|
-
}
|
|
330
|
+
for (int i0 = 0; i0 < n_tokens; ++i0) {
|
|
331
|
+
const llama_seq_id s0 = ubatch->seq_id[i0][0];
|
|
332
|
+
const llama_pos p0 = ubatch->pos[i0];
|
|
341
333
|
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
334
|
+
// mask different sequences
|
|
335
|
+
if (s0 != s1) {
|
|
336
|
+
continue;
|
|
337
|
+
}
|
|
346
338
|
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
339
|
+
// mask future tokens
|
|
340
|
+
if (cparams.causal_attn && p0 > p1) {
|
|
341
|
+
continue;
|
|
342
|
+
}
|
|
351
343
|
|
|
352
|
-
|
|
344
|
+
// apply SWA if any
|
|
345
|
+
if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
|
|
346
|
+
continue;
|
|
353
347
|
}
|
|
348
|
+
|
|
349
|
+
data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
|
|
354
350
|
}
|
|
355
351
|
}
|
|
356
352
|
};
|
|
@@ -454,27 +450,19 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
|
|
454
450
|
|
|
455
451
|
float * data = (float *) cross_kq_mask->data;
|
|
456
452
|
|
|
457
|
-
for (int
|
|
458
|
-
for (int
|
|
459
|
-
|
|
460
|
-
float f = -INFINITY;
|
|
453
|
+
for (int i = 0; i < n_tokens; ++i) {
|
|
454
|
+
for (int j = 0; j < n_enc; ++j) {
|
|
455
|
+
float f = -INFINITY;
|
|
461
456
|
|
|
462
|
-
|
|
463
|
-
|
|
457
|
+
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
|
458
|
+
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
|
464
459
|
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
}
|
|
460
|
+
if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
|
|
461
|
+
f = 0.0f;
|
|
468
462
|
}
|
|
469
|
-
|
|
470
|
-
data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
|
|
471
463
|
}
|
|
472
|
-
}
|
|
473
464
|
|
|
474
|
-
|
|
475
|
-
for (int j = 0; j < n_enc; ++j) {
|
|
476
|
-
data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
|
|
477
|
-
}
|
|
465
|
+
data[i*n_enc + j] = f;
|
|
478
466
|
}
|
|
479
467
|
}
|
|
480
468
|
}
|