@fugood/llama.node 1.4.13 → 1.4.15

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. package/lib/binding.ts +23 -2
  2. package/lib/index.js +2 -1
  3. package/lib/index.ts +8 -1
  4. package/lib/parallel.ts +2 -2
  5. package/package.json +15 -15
  6. package/scripts/llama.cpp.patch +9 -12
  7. package/src/LlamaContext.cpp +16 -4
  8. package/src/llama.cpp/CMakeLists.txt +24 -8
  9. package/src/llama.cpp/common/CMakeLists.txt +3 -34
  10. package/src/llama.cpp/common/arg.cpp +183 -60
  11. package/src/llama.cpp/common/arg.h +0 -8
  12. package/src/llama.cpp/common/chat-parser.cpp +115 -0
  13. package/src/llama.cpp/common/chat.cpp +67 -0
  14. package/src/llama.cpp/common/chat.h +1 -0
  15. package/src/llama.cpp/common/common.cpp +2 -1
  16. package/src/llama.cpp/common/common.h +12 -7
  17. package/src/llama.cpp/common/debug.cpp +165 -0
  18. package/src/llama.cpp/common/debug.h +43 -0
  19. package/src/llama.cpp/common/download.cpp +88 -369
  20. package/src/llama.cpp/common/download.h +32 -5
  21. package/src/llama.cpp/common/preset.cpp +87 -2
  22. package/src/llama.cpp/common/preset.h +10 -1
  23. package/src/llama.cpp/ggml/include/ggml.h +5 -0
  24. package/src/llama.cpp/include/llama.h +5 -2
  25. package/src/llama.cpp/src/CMakeLists.txt +1 -0
  26. package/src/llama.cpp/src/llama-arch.cpp +35 -0
  27. package/src/llama.cpp/src/llama-arch.h +1 -0
  28. package/src/llama.cpp/src/llama-chat.cpp +20 -0
  29. package/src/llama.cpp/src/llama-chat.h +1 -0
  30. package/src/llama.cpp/src/llama-graph.cpp +31 -43
  31. package/src/llama.cpp/src/llama-mmap.cpp +78 -42
  32. package/src/llama.cpp/src/llama-mmap.h +5 -4
  33. package/src/llama.cpp/src/llama-model-loader.cpp +17 -5
  34. package/src/llama.cpp/src/llama-model-loader.h +2 -0
  35. package/src/llama.cpp/src/llama-model.cpp +225 -101
  36. package/src/llama.cpp/src/llama-quant.cpp +1 -1
  37. package/src/llama.cpp/src/llama-sampling.cpp +1 -1
  38. package/src/llama.cpp/src/llama-vocab.cpp +37 -24
  39. package/src/llama.cpp/src/llama-vocab.h +1 -0
  40. package/src/llama.cpp/src/llama.cpp +63 -27
  41. package/src/llama.cpp/src/models/exaone-moe.cpp +146 -0
  42. package/src/llama.cpp/src/models/gemma3n-iswa.cpp +13 -3
  43. package/src/llama.cpp/src/models/models.h +13 -2
  44. package/src/llama.cpp/src/models/qwen3next.cpp +198 -182
@@ -1,12 +1,27 @@
1
1
  #pragma once
2
2
 
3
3
  #include <string>
4
+ #include <vector>
4
5
 
5
6
  struct common_params_model;
6
7
 
7
- //
8
- // download functionalities
9
- //
8
+ using common_header = std::pair<std::string, std::string>;
9
+ using common_header_list = std::vector<common_header>;
10
+
11
+ struct common_remote_params {
12
+ common_header_list headers;
13
+ long timeout = 0; // in seconds, 0 means no timeout
14
+ long max_size = 0; // unlimited if 0
15
+ };
16
+
17
+ // get remote file content, returns <http_code, raw_response_body>
18
+ std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params);
19
+
20
+ // split HF repo with tag into <repo, tag>
21
+ // for example: "user/model:tag" -> <"user/model", "tag">
22
+ // if tag is not present, default to "latest"
23
+ // example: "user/model" -> <"user/model", "latest">
24
+ std::pair<std::string, std::string> common_download_split_repo_tag(const std::string & hf_repo_with_tag);
10
25
 
11
26
  struct common_cached_model_info {
12
27
  std::string manifest_path;
@@ -41,17 +56,29 @@ struct common_hf_file_res {
41
56
  common_hf_file_res common_get_hf_file(
42
57
  const std::string & hf_repo_with_tag,
43
58
  const std::string & bearer_token,
44
- bool offline);
59
+ bool offline,
60
+ const common_header_list & headers = {}
61
+ );
45
62
 
46
63
  // returns true if download succeeded
47
64
  bool common_download_model(
48
65
  const common_params_model & model,
49
66
  const std::string & bearer_token,
50
- bool offline);
67
+ bool offline,
68
+ const common_header_list & headers = {}
69
+ );
51
70
 
52
71
  // returns list of cached models
53
72
  std::vector<common_cached_model_info> common_list_cached_models();
54
73
 
74
+ // download single file from url to local path
75
+ // returns status code or -1 on error
76
+ int common_download_file_single(const std::string & url,
77
+ const std::string & path,
78
+ const std::string & bearer_token,
79
+ bool offline,
80
+ const common_header_list & headers = {});
81
+
55
82
  // resolve and download model from Docker registry
56
83
  // return local path to downloaded model file
57
84
  std::string common_docker_resolve_model(const std::string & docker);
@@ -16,6 +16,48 @@ static std::string rm_leading_dashes(const std::string & str) {
16
16
  return str.substr(pos);
17
17
  }
18
18
 
19
+ // only allow a subset of args for remote presets for security reasons
20
+ // do not add more args unless absolutely necessary
21
+ // args that output to files are strictly prohibited
22
+ static std::set<std::string> get_remote_preset_whitelist(const std::map<std::string, common_arg> & key_to_opt) {
23
+ static const std::set<std::string> allowed_options = {
24
+ "model-url",
25
+ "hf-repo",
26
+ "hf-repo-draft",
27
+ "hf-repo-v", // vocoder
28
+ "hf-file-v", // vocoder
29
+ "mmproj-url",
30
+ "pooling",
31
+ "jinja",
32
+ "batch-size",
33
+ "ubatch-size",
34
+ "cache-reuse",
35
+ "chat-template-kwargs",
36
+ "mmap",
37
+ // note: sampling params are automatically allowed by default
38
+ // negated args will be added automatically if the positive arg is specified above
39
+ };
40
+
41
+ std::set<std::string> allowed_keys;
42
+
43
+ for (const auto & it : key_to_opt) {
44
+ const std::string & key = it.first;
45
+ const common_arg & opt = it.second;
46
+ if (allowed_options.find(key) != allowed_options.end() || opt.is_sparam) {
47
+ allowed_keys.insert(key);
48
+ // also add variant keys (args without leading dashes and env vars)
49
+ for (const auto & arg : opt.get_args()) {
50
+ allowed_keys.insert(rm_leading_dashes(arg));
51
+ }
52
+ for (const auto & env : opt.get_env()) {
53
+ allowed_keys.insert(env);
54
+ }
55
+ }
56
+ }
57
+
58
+ return allowed_keys;
59
+ }
60
+
19
61
  std::vector<std::string> common_preset::to_args(const std::string & bin_path) const {
20
62
  std::vector<std::string> args;
21
63
 
@@ -121,6 +163,29 @@ void common_preset::merge(const common_preset & other) {
121
163
  }
122
164
  }
123
165
 
166
+ void common_preset::apply_to_params(common_params & params) const {
167
+ for (const auto & [opt, val] : options) {
168
+ // apply each option to params
169
+ if (opt.handler_string) {
170
+ opt.handler_string(params, val);
171
+ } else if (opt.handler_int) {
172
+ opt.handler_int(params, std::stoi(val));
173
+ } else if (opt.handler_bool) {
174
+ opt.handler_bool(params, common_arg_utils::is_truthy(val));
175
+ } else if (opt.handler_str_str) {
176
+ // not supported yet
177
+ throw std::runtime_error(string_format(
178
+ "%s: option with two values is not supported yet",
179
+ __func__
180
+ ));
181
+ } else if (opt.handler_void) {
182
+ opt.handler_void(params);
183
+ } else {
184
+ GGML_ABORT("unknown handler type");
185
+ }
186
+ }
187
+ }
188
+
124
189
  static std::map<std::string, std::map<std::string, std::string>> parse_ini_from_file(const std::string & path) {
125
190
  std::map<std::string, std::map<std::string, std::string>> parsed;
126
191
 
@@ -230,10 +295,16 @@ static std::string parse_bool_arg(const common_arg & arg, const std::string & ke
230
295
  return value;
231
296
  }
232
297
 
233
- common_preset_context::common_preset_context(llama_example ex)
298
+ common_preset_context::common_preset_context(llama_example ex, bool only_remote_allowed)
234
299
  : ctx_params(common_params_parser_init(default_params, ex)) {
235
300
  common_params_add_preset_options(ctx_params.options);
236
301
  key_to_opt = get_map_key_opt(ctx_params);
302
+
303
+ // setup allowed keys if only_remote_allowed is true
304
+ if (only_remote_allowed) {
305
+ filter_allowed_keys = true;
306
+ allowed_keys = get_remote_preset_whitelist(key_to_opt);
307
+ }
237
308
  }
238
309
 
239
310
  common_presets common_preset_context::load_from_ini(const std::string & path, common_preset & global) const {
@@ -249,7 +320,18 @@ common_presets common_preset_context::load_from_ini(const std::string & path, co
249
320
  }
250
321
  LOG_DBG("loading preset: %s\n", preset.name.c_str());
251
322
  for (const auto & [key, value] : section.second) {
323
+ if (key == "version") {
324
+ // skip version key (reserved for future use)
325
+ continue;
326
+ }
327
+
252
328
  LOG_DBG("option: %s = %s\n", key.c_str(), value.c_str());
329
+ if (filter_allowed_keys && allowed_keys.find(key) == allowed_keys.end()) {
330
+ throw std::runtime_error(string_format(
331
+ "option '%s' is not allowed in remote presets",
332
+ key.c_str()
333
+ ));
334
+ }
253
335
  if (key_to_opt.find(key) != key_to_opt.end()) {
254
336
  const auto & opt = key_to_opt.at(key);
255
337
  if (is_bool_arg(opt)) {
@@ -259,7 +341,10 @@ common_presets common_preset_context::load_from_ini(const std::string & path, co
259
341
  }
260
342
  LOG_DBG("accepted option: %s = %s\n", key.c_str(), preset.options[opt].c_str());
261
343
  } else {
262
- // TODO: maybe warn about unknown key?
344
+ throw std::runtime_error(string_format(
345
+ "option '%s' not recognized in preset '%s'",
346
+ key.c_str(), preset.name.c_str()
347
+ ));
263
348
  }
264
349
  }
265
350
 
@@ -6,6 +6,7 @@
6
6
  #include <string>
7
7
  #include <vector>
8
8
  #include <map>
9
+ #include <set>
9
10
 
10
11
  //
11
12
  // INI preset parser and writer
@@ -40,6 +41,9 @@ struct common_preset {
40
41
 
41
42
  // merge another preset into this one, overwriting existing options
42
43
  void merge(const common_preset & other);
44
+
45
+ // apply preset options to common_params
46
+ void apply_to_params(common_params & params) const;
43
47
  };
44
48
 
45
49
  // interface for multiple presets in one file
@@ -50,7 +54,12 @@ struct common_preset_context {
50
54
  common_params default_params; // unused for now
51
55
  common_params_context ctx_params;
52
56
  std::map<std::string, common_arg> key_to_opt;
53
- common_preset_context(llama_example ex);
57
+
58
+ bool filter_allowed_keys = false;
59
+ std::set<std::string> allowed_keys;
60
+
61
+ // if only_remote_allowed is true, only accept whitelisted keys
62
+ common_preset_context(llama_example ex, bool only_remote_allowed = false);
54
63
 
55
64
  // load presets from INI file
56
65
  common_presets load_from_ini(const std::string & path, common_preset & global) const;
@@ -234,6 +234,11 @@
234
234
 
235
235
  #if UINTPTR_MAX == 0xFFFFFFFF
236
236
  #define GGML_MEM_ALIGN 4
237
+ #elif defined(__EMSCRIPTEN__)
238
+ // emscripten uses max_align_t == 8, so we need GGML_MEM_ALIGN == 8 for 64-bit wasm.
239
+ // (for 32-bit wasm, the first conditional is true and GGML_MEM_ALIGN stays 4.)
240
+ // ref: https://github.com/ggml-org/llama.cpp/pull/18628
241
+ #define GGML_MEM_ALIGN 8
237
242
  #else
238
243
  #define GGML_MEM_ALIGN 16
239
244
  #endif
@@ -309,6 +309,7 @@ extern "C" {
309
309
  // Keep the booleans together to avoid misalignment during copy-by-value.
310
310
  bool vocab_only; // only load the vocabulary, no weights
311
311
  bool use_mmap; // use mmap if possible
312
+ bool use_direct_io; // use direct io, takes precedence over use_mmap
312
313
  bool use_mlock; // force system to keep model in RAM
313
314
  bool check_tensors; // validate model tensor data
314
315
  bool use_extra_bufts; // use extra buffer types (used for weight repacking)
@@ -494,7 +495,7 @@ extern "C" {
494
495
  struct llama_context_params * cparams,
495
496
  float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements
496
497
  struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements
497
- size_t margin, // margin of memory to leave per device in bytes
498
+ size_t * margins, // margins of memory to leave per device in bytes
498
499
  uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use
499
500
  enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log
500
501
 
@@ -1291,7 +1292,9 @@ extern "C" {
1291
1292
  // available samplers:
1292
1293
 
1293
1294
  LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
1294
- LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
1295
+
1296
+ /// seed == LLAMA_DEFAULT_SEED to use a random seed.
1297
+ LLAMA_API struct llama_sampler * llama_sampler_init_dist(uint32_t seed);
1295
1298
 
1296
1299
  /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
1297
1300
  /// Setting k <= 0 makes this a noop
@@ -62,6 +62,7 @@ add_library(llama
62
62
  models/ernie4-5.cpp
63
63
  models/exaone.cpp
64
64
  models/exaone4.cpp
65
+ models/exaone-moe.cpp
65
66
  models/falcon-h1.cpp
66
67
  models/falcon.cpp
67
68
  models/gemma-embedding.cpp
@@ -81,6 +81,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
81
81
  { LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" },
82
82
  { LLM_ARCH_EXAONE, "exaone" },
83
83
  { LLM_ARCH_EXAONE4, "exaone4" },
84
+ { LLM_ARCH_EXAONE_MOE, "exaone-moe" },
84
85
  { LLM_ARCH_RWKV6, "rwkv6" },
85
86
  { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
86
87
  { LLM_ARCH_RWKV7, "rwkv7" },
@@ -950,6 +951,8 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
950
951
  LLM_TENSOR_ATTN_K_NORM,
951
952
  LLM_TENSOR_ATTN_V,
952
953
  LLM_TENSOR_ATTN_OUT,
954
+ LLM_TENSOR_ATTN_QKV,
955
+ LLM_TENSOR_ATTN_GATE,
953
956
  LLM_TENSOR_FFN_NORM,
954
957
  LLM_TENSOR_FFN_GATE_INP,
955
958
  LLM_TENSOR_FFN_GATE_EXPS,
@@ -1726,6 +1729,38 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
1726
1729
  LLM_TENSOR_FFN_UP,
1727
1730
  LLM_TENSOR_FFN_POST_NORM,
1728
1731
  };
1732
+ case LLM_ARCH_EXAONE_MOE:
1733
+ return {
1734
+ LLM_TENSOR_TOKEN_EMBD,
1735
+ LLM_TENSOR_OUTPUT_NORM,
1736
+ LLM_TENSOR_OUTPUT,
1737
+ LLM_TENSOR_ROPE_FREQS,
1738
+ LLM_TENSOR_ATTN_NORM,
1739
+ LLM_TENSOR_ATTN_Q,
1740
+ LLM_TENSOR_ATTN_Q_NORM,
1741
+ LLM_TENSOR_ATTN_K,
1742
+ LLM_TENSOR_ATTN_K_NORM,
1743
+ LLM_TENSOR_ATTN_V,
1744
+ LLM_TENSOR_ATTN_OUT,
1745
+ LLM_TENSOR_FFN_NORM,
1746
+ LLM_TENSOR_FFN_GATE,
1747
+ LLM_TENSOR_FFN_DOWN,
1748
+ LLM_TENSOR_FFN_UP,
1749
+ LLM_TENSOR_FFN_GATE_INP,
1750
+ LLM_TENSOR_FFN_GATE_EXPS,
1751
+ LLM_TENSOR_FFN_DOWN_EXPS,
1752
+ LLM_TENSOR_FFN_UP_EXPS,
1753
+ LLM_TENSOR_FFN_GATE_SHEXP,
1754
+ LLM_TENSOR_FFN_UP_SHEXP,
1755
+ LLM_TENSOR_FFN_DOWN_SHEXP,
1756
+ LLM_TENSOR_FFN_EXP_PROBS_B,
1757
+ LLM_TENSOR_NEXTN_EH_PROJ,
1758
+ LLM_TENSOR_NEXTN_EMBED_TOKENS,
1759
+ LLM_TENSOR_NEXTN_ENORM,
1760
+ LLM_TENSOR_NEXTN_HNORM,
1761
+ LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
1762
+ LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
1763
+ };
1729
1764
  case LLM_ARCH_RWKV6:
1730
1765
  return {
1731
1766
  LLM_TENSOR_TOKEN_EMBD,
@@ -85,6 +85,7 @@ enum llm_arch {
85
85
  LLM_ARCH_NEMOTRON_H_MOE,
86
86
  LLM_ARCH_EXAONE,
87
87
  LLM_ARCH_EXAONE4,
88
+ LLM_ARCH_EXAONE_MOE,
88
89
  LLM_ARCH_RWKV6,
89
90
  LLM_ARCH_RWKV6QWEN2,
90
91
  LLM_ARCH_RWKV7,
@@ -57,6 +57,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
57
57
  { "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
58
58
  { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
59
59
  { "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 },
60
+ { "exaone-moe", LLM_CHAT_TEMPLATE_EXAONE_MOE },
60
61
  { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
61
62
  { "granite", LLM_CHAT_TEMPLATE_GRANITE },
62
63
  { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
@@ -137,6 +138,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
137
138
  } else if (tmpl_contains("[gMASK]<sop>")) {
138
139
  return LLM_CHAT_TEMPLATE_CHATGLM_4;
139
140
  } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
141
+ if (tmpl_contains("<|tool_declare|>")) {
142
+ return LLM_CHAT_TEMPLATE_EXAONE_MOE;
143
+ }
140
144
  return tmpl_contains("</s>") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE;
141
145
  } else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) {
142
146
  return LLM_CHAT_TEMPLATE_GLMEDGE;
@@ -576,6 +580,22 @@ int32_t llm_chat_apply_template(
576
580
  if (add_ass) {
577
581
  ss << "[|assistant|]";
578
582
  }
583
+ } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_MOE) {
584
+ for (auto message : chat) {
585
+ std::string role(message->role);
586
+ if (role == "system") {
587
+ ss << "<|system|>\n" << trim(message->content) << "<|endofturn|>\n";
588
+ } else if (role == "user") {
589
+ ss << "<|user|>\n" << trim(message->content) << "<|endofturn|>\n";
590
+ } else if (role == "assistant") {
591
+ ss << "<|assistant|>\n" << trim(message->content) << "<|endofturn|>\n";
592
+ } else if (role == "tool") {
593
+ ss << "<|tool|>\n" << trim(message->content) << "<|endofturn|>\n";
594
+ }
595
+ }
596
+ if (add_ass) {
597
+ ss << "<|assistant|>\n";
598
+ }
579
599
  } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
580
600
  // this template requires the model to have "\n\n" as EOT token
581
601
  for (size_t i = 0; i < chat.size(); i++) {
@@ -36,6 +36,7 @@ enum llm_chat_template {
36
36
  LLM_CHAT_TEMPLATE_MINICPM,
37
37
  LLM_CHAT_TEMPLATE_EXAONE_3,
38
38
  LLM_CHAT_TEMPLATE_EXAONE_4,
39
+ LLM_CHAT_TEMPLATE_EXAONE_MOE,
39
40
  LLM_CHAT_TEMPLATE_RWKV_WORLD,
40
41
  LLM_CHAT_TEMPLATE_GRANITE,
41
42
  LLM_CHAT_TEMPLATE_GIGACHAT,
@@ -96,11 +96,9 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
96
96
 
97
97
  int32_t * data = (int32_t *) pos_bucket->data;
98
98
 
99
- for (int h = 0; h < 1; ++h) {
100
- for (int j = 0; j < n_tokens; ++j) {
101
- for (int i = 0; i < n_tokens; ++i) {
102
- data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
103
- }
99
+ for (int j = 0; j < n_tokens; ++j) {
100
+ for (int i = 0; i < n_tokens; ++i) {
101
+ data[j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
104
102
  }
105
103
  }
106
104
  }
@@ -323,34 +321,32 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
323
321
  const int64_t n_tokens = ubatch->n_tokens;
324
322
 
325
323
  const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
326
- for (int h = 0; h < 1; ++h) {
327
- for (int i1 = 0; i1 < n_tokens; ++i1) {
328
- const llama_seq_id s1 = ubatch->seq_id[i1][0];
329
- const llama_pos p1 = ubatch->pos[i1];
324
+ for (int i1 = 0; i1 < n_tokens; ++i1) {
325
+ const llama_seq_id s1 = ubatch->seq_id[i1][0];
326
+ const llama_pos p1 = ubatch->pos[i1];
330
327
 
331
- const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
328
+ const uint64_t idst = i1*n_kv;
332
329
 
333
- for (int i0 = 0; i0 < n_tokens; ++i0) {
334
- const llama_seq_id s0 = ubatch->seq_id[i0][0];
335
- const llama_pos p0 = ubatch->pos[i0];
336
-
337
- // mask different sequences
338
- if (s0 != s1) {
339
- continue;
340
- }
330
+ for (int i0 = 0; i0 < n_tokens; ++i0) {
331
+ const llama_seq_id s0 = ubatch->seq_id[i0][0];
332
+ const llama_pos p0 = ubatch->pos[i0];
341
333
 
342
- // mask future tokens
343
- if (cparams.causal_attn && p0 > p1) {
344
- continue;
345
- }
334
+ // mask different sequences
335
+ if (s0 != s1) {
336
+ continue;
337
+ }
346
338
 
347
- // apply SWA if any
348
- if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
349
- continue;
350
- }
339
+ // mask future tokens
340
+ if (cparams.causal_attn && p0 > p1) {
341
+ continue;
342
+ }
351
343
 
352
- data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
344
+ // apply SWA if any
345
+ if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
346
+ continue;
353
347
  }
348
+
349
+ data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
354
350
  }
355
351
  }
356
352
  };
@@ -454,27 +450,19 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
454
450
 
455
451
  float * data = (float *) cross_kq_mask->data;
456
452
 
457
- for (int h = 0; h < 1; ++h) {
458
- for (int i = 0; i < n_tokens; ++i) {
459
- for (int j = 0; j < n_enc; ++j) {
460
- float f = -INFINITY;
453
+ for (int i = 0; i < n_tokens; ++i) {
454
+ for (int j = 0; j < n_enc; ++j) {
455
+ float f = -INFINITY;
461
456
 
462
- for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
463
- const llama_seq_id seq_id = ubatch->seq_id[i][s];
457
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
458
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
464
459
 
465
- if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
466
- f = 0.0f;
467
- }
460
+ if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
461
+ f = 0.0f;
468
462
  }
469
-
470
- data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
471
463
  }
472
- }
473
464
 
474
- for (int i = n_tokens; i < n_tokens; ++i) {
475
- for (int j = 0; j < n_enc; ++j) {
476
- data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
477
- }
465
+ data[i*n_enc + j] = f;
478
466
  }
479
467
  }
480
468
  }