@fugood/llama.node 1.4.13 → 1.4.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -157,6 +157,10 @@ static std::string read_etag(const std::string & path) {
157
157
  return none;
158
158
  }
159
159
 
160
// An HTTP exchange counts as successful for any 2xx or 3xx status code.
// Every other value - including the -1 sentinel the download helpers return
// on error - is treated as a failure.
static bool is_http_status_ok(int status) {
    const bool below_range = status < 200;
    const bool above_range = status >= 400;
    return !(below_range || above_range);
}
163
+
160
164
  #ifdef LLAMA_USE_CURL
161
165
 
162
166
  //
@@ -306,11 +310,14 @@ static bool common_download_head(CURL * curl,
306
310
  }
307
311
 
308
312
  // download one single file from remote URL to local path
309
- static bool common_download_file_single_online(const std::string & url,
313
+ // returns status code or -1 on error
314
+ static int common_download_file_single_online(const std::string & url,
310
315
  const std::string & path,
311
- const std::string & bearer_token) {
316
+ const std::string & bearer_token,
317
+ const common_header_list & custom_headers) {
312
318
  static const int max_attempts = 3;
313
319
  static const int retry_delay_seconds = 2;
320
+
314
321
  for (int i = 0; i < max_attempts; ++i) {
315
322
  std::string etag;
316
323
 
@@ -330,6 +337,11 @@ static bool common_download_file_single_online(const std::string & url,
330
337
  common_load_model_from_url_headers headers;
331
338
  curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
332
339
  curl_slist_ptr http_headers;
340
+
341
+ for (const auto & h : custom_headers) {
342
+ std::string s = h.first + ": " + h.second;
343
+ http_headers.ptr = curl_slist_append(http_headers.ptr, s.c_str());
344
+ }
333
345
  const bool was_perform_successful = common_download_head(curl.get(), http_headers, url, bearer_token);
334
346
  if (!was_perform_successful) {
335
347
  head_request_ok = false;
@@ -365,7 +377,7 @@ static bool common_download_file_single_online(const std::string & url,
365
377
  LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
366
378
  if (remove(path.c_str()) != 0) {
367
379
  LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
368
- return false;
380
+ return -1;
369
381
  }
370
382
  }
371
383
 
@@ -374,14 +386,14 @@ static bool common_download_file_single_online(const std::string & url,
374
386
  if (std::filesystem::exists(path_temporary)) {
375
387
  if (remove(path_temporary.c_str()) != 0) {
376
388
  LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
377
- return false;
389
+ return -1;
378
390
  }
379
391
  }
380
392
 
381
393
  if (std::filesystem::exists(path)) {
382
394
  if (remove(path.c_str()) != 0) {
383
395
  LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
384
- return false;
396
+ return -1;
385
397
  }
386
398
  }
387
399
  }
@@ -408,23 +420,27 @@ static bool common_download_file_single_online(const std::string & url,
408
420
 
409
421
  long http_code = 0;
410
422
  curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
411
- if (http_code < 200 || http_code >= 400) {
423
+
424
+ int status = static_cast<int>(http_code);
425
+ if (!is_http_status_ok(http_code)) {
412
426
  LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
413
- return false;
427
+ return status; // TODO: maybe only return on certain codes
414
428
  }
415
429
 
416
430
  if (rename(path_temporary.c_str(), path.c_str()) != 0) {
417
431
  LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
418
- return false;
432
+ return -1;
419
433
  }
434
+
435
+ return static_cast<int>(http_code);
420
436
  } else {
421
437
  LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
422
- }
423
438
 
424
- break;
439
+ return 304; // Not Modified - fake cached response
440
+ }
425
441
  }
426
442
 
427
- return true;
443
+ return -1; // max attempts reached
428
444
  }
429
445
 
430
446
  std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params) {
@@ -454,8 +470,10 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
454
470
  curl_easy_setopt(curl.get(), CURLOPT_MAXFILESIZE, params.max_size);
455
471
  }
456
472
  http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
473
+
457
474
  for (const auto & header : params.headers) {
458
- http_headers.ptr = curl_slist_append(http_headers.ptr, header.c_str());
475
+ std::string header_ = header.first + ": " + header.second;
476
+ http_headers.ptr = curl_slist_append(http_headers.ptr, header_.c_str());
459
477
  }
460
478
  curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
461
479
 
@@ -617,9 +635,11 @@ static bool common_pull_file(httplib::Client & cli,
617
635
  }
618
636
 
619
637
  // download one single file from remote URL to local path
620
- static bool common_download_file_single_online(const std::string & url,
638
+ // returns status code or -1 on error
639
+ static int common_download_file_single_online(const std::string & url,
621
640
  const std::string & path,
622
- const std::string & bearer_token) {
641
+ const std::string & bearer_token,
642
+ const common_header_list & custom_headers) {
623
643
  static const int max_attempts = 3;
624
644
  static const int retry_delay_seconds = 2;
625
645
 
@@ -629,6 +649,9 @@ static bool common_download_file_single_online(const std::string & url,
629
649
  if (!bearer_token.empty()) {
630
650
  default_headers.insert({"Authorization", "Bearer " + bearer_token});
631
651
  }
652
+ for (const auto & h : custom_headers) {
653
+ default_headers.emplace(h.first, h.second);
654
+ }
632
655
  cli.set_default_headers(default_headers);
633
656
 
634
657
  const bool file_exists = std::filesystem::exists(path);
@@ -647,8 +670,10 @@ static bool common_download_file_single_online(const std::string & url,
647
670
  LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1);
648
671
  if (file_exists) {
649
672
  LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str());
650
- return true;
673
+ return 304; // 304 Not Modified - fake cached response
651
674
  }
675
+ return head->status; // cannot use cached file, return raw status code
676
+ // TODO: maybe retry only on certain codes
652
677
  }
653
678
 
654
679
  std::string etag;
@@ -680,12 +705,12 @@ static bool common_download_file_single_online(const std::string & url,
680
705
  if (file_exists) {
681
706
  if (!should_download_from_scratch) {
682
707
  LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
683
- return true;
708
+ return 304; // 304 Not Modified - fake cached response
684
709
  }
685
710
  LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
686
711
  if (remove(path.c_str()) != 0) {
687
712
  LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
688
- return false;
713
+ return -1;
689
714
  }
690
715
  }
691
716
 
@@ -697,7 +722,7 @@ static bool common_download_file_single_online(const std::string & url,
697
722
  existing_size = std::filesystem::file_size(path_temporary);
698
723
  } else if (remove(path_temporary.c_str()) != 0) {
699
724
  LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
700
- return false;
725
+ return -1;
701
726
  }
702
727
  }
703
728
 
@@ -718,15 +743,16 @@ static bool common_download_file_single_online(const std::string & url,
718
743
 
719
744
  if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
720
745
  LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
721
- return false;
746
+ return -1;
722
747
  }
723
748
  if (!etag.empty()) {
724
749
  write_etag(path, etag);
725
750
  }
726
- break;
751
+
752
+ return head->status; // TODO: use actual GET status?
727
753
  }
728
754
 
729
- return true;
755
+ return -1; // max attempts reached
730
756
  }
731
757
 
732
758
  std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url,
@@ -734,13 +760,9 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
734
760
  auto [cli, parts] = common_http_client(url);
735
761
 
736
762
  httplib::Headers headers = {{"User-Agent", "llama-cpp"}};
763
+
737
764
  for (const auto & header : params.headers) {
738
- size_t pos = header.find(':');
739
- if (pos != std::string::npos) {
740
- headers.emplace(header.substr(0, pos), header.substr(pos + 1));
741
- } else {
742
- headers.emplace(header, "");
743
- }
765
+ headers.emplace(header.first, header.second);
744
766
  }
745
767
 
746
768
  if (params.timeout > 0) {
@@ -769,32 +791,45 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
769
791
 
770
792
  #if defined(LLAMA_USE_CURL) || defined(LLAMA_USE_HTTPLIB)
771
793
 
772
- static bool common_download_file_single(const std::string & url,
773
- const std::string & path,
774
- const std::string & bearer_token,
775
- bool offline) {
794
+ int common_download_file_single(const std::string & url,
795
+ const std::string & path,
796
+ const std::string & bearer_token,
797
+ bool offline,
798
+ const common_header_list & headers) {
776
799
  if (!offline) {
777
- return common_download_file_single_online(url, path, bearer_token);
800
+ return common_download_file_single_online(url, path, bearer_token, headers);
778
801
  }
779
802
 
780
803
  if (!std::filesystem::exists(path)) {
781
804
  LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str());
782
- return false;
805
+ return -1;
783
806
  }
784
807
 
785
808
  LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
786
- return true;
809
+ return 304; // Not Modified - fake cached response
787
810
  }
788
811
 
789
812
  // download multiple files from remote URLs to local paths
790
813
  // the input is a vector of pairs <url, path>
791
- static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls, const std::string & bearer_token, bool offline) {
814
+ static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls,
815
+ const std::string & bearer_token,
816
+ bool offline,
817
+ const common_header_list & headers) {
792
818
  // Prepare download in parallel
793
819
  std::vector<std::future<bool>> futures_download;
820
+ futures_download.reserve(urls.size());
821
+
794
822
  for (auto const & item : urls) {
795
- futures_download.push_back(std::async(std::launch::async, [bearer_token, offline](const std::pair<std::string, std::string> & it) -> bool {
796
- return common_download_file_single(it.first, it.second, bearer_token, offline);
797
- }, item));
823
+ futures_download.push_back(
824
+ std::async(
825
+ std::launch::async,
826
+ [&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool {
827
+ const int http_status = common_download_file_single(it.first, it.second, bearer_token, offline, headers);
828
+ return is_http_status_ok(http_status);
829
+ },
830
+ item
831
+ )
832
+ );
798
833
  }
799
834
 
800
835
  // Wait for all downloads to complete
@@ -807,17 +842,18 @@ static bool common_download_file_multiple(const std::vector<std::pair<std::strin
807
842
  return true;
808
843
  }
809
844
 
810
- bool common_download_model(
811
- const common_params_model & model,
812
- const std::string & bearer_token,
813
- bool offline) {
845
+ bool common_download_model(const common_params_model & model,
846
+ const std::string & bearer_token,
847
+ bool offline,
848
+ const common_header_list & headers) {
814
849
  // Basic validation of the model.url
815
850
  if (model.url.empty()) {
816
851
  LOG_ERR("%s: invalid model url\n", __func__);
817
852
  return false;
818
853
  }
819
854
 
820
- if (!common_download_file_single(model.url, model.path, bearer_token, offline)) {
855
+ const int http_status = common_download_file_single(model.url, model.path, bearer_token, offline, headers);
856
+ if (!is_http_status_ok(http_status)) {
821
857
  return false;
822
858
  }
823
859
 
@@ -876,13 +912,16 @@ bool common_download_model(
876
912
  }
877
913
 
878
914
  // Download in parallel
879
- common_download_file_multiple(urls, bearer_token, offline);
915
+ common_download_file_multiple(urls, bearer_token, offline, headers);
880
916
  }
881
917
 
882
918
  return true;
883
919
  }
884
920
 
885
- common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token, bool offline) {
921
+ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag,
922
+ const std::string & bearer_token,
923
+ bool offline,
924
+ const common_header_list & custom_headers) {
886
925
  auto parts = string_split<std::string>(hf_repo_with_tag, ':');
887
926
  std::string tag = parts.size() > 1 ? parts.back() : "latest";
888
927
  std::string hf_repo = parts[0];
@@ -893,10 +932,10 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons
893
932
  std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag;
894
933
 
895
934
  // headers
896
- std::vector<std::string> headers;
897
- headers.push_back("Accept: application/json");
935
+ common_header_list headers = custom_headers;
936
+ headers.push_back({"Accept", "application/json"});
898
937
  if (!bearer_token.empty()) {
899
- headers.push_back("Authorization: Bearer " + bearer_token);
938
+ headers.push_back({"Authorization", "Bearer " + bearer_token});
900
939
  }
901
940
  // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
902
941
  // User-Agent header is already set in common_remote_get_content, no need to set it here
@@ -952,7 +991,7 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons
952
991
  } else if (res_code == 401) {
953
992
  throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
954
993
  } else {
955
- throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
994
+ throw std::runtime_error(string_format("error from HF API (%s), response code: %ld, data: %s", url.c_str(), res_code, res_str.c_str()));
956
995
  }
957
996
 
958
997
  // check response
@@ -1031,9 +1070,10 @@ std::string common_docker_resolve_model(const std::string & docker) {
1031
1070
  const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo;
1032
1071
  std::string manifest_url = url_prefix + "/manifests/" + tag;
1033
1072
  common_remote_params manifest_params;
1034
- manifest_params.headers.push_back("Authorization: Bearer " + token);
1035
- manifest_params.headers.push_back(
1036
- "Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");
1073
+ manifest_params.headers.push_back({"Authorization", "Bearer " + token});
1074
+ manifest_params.headers.push_back({"Accept",
1075
+ "application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json"
1076
+ });
1037
1077
  auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
1038
1078
  if (manifest_res.first != 200) {
1039
1079
  throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
@@ -1070,7 +1110,8 @@ std::string common_docker_resolve_model(const std::string & docker) {
1070
1110
  std::string local_path = fs_get_cache_file(model_filename);
1071
1111
 
1072
1112
  const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
1073
- if (!common_download_file_single(blob_url, local_path, token, false)) {
1113
+ const int http_status = common_download_file_single(blob_url, local_path, token, false, {});
1114
+ if (!is_http_status_ok(http_status)) {
1074
1115
  throw std::runtime_error("Failed to download Docker Model");
1075
1116
  }
1076
1117
 
@@ -1084,11 +1125,11 @@ std::string common_docker_resolve_model(const std::string & docker) {
1084
1125
 
1085
1126
  #else
1086
1127
 
1087
- common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool) {
1128
+ common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool, const common_header_list &) {
1088
1129
  throw std::runtime_error("download functionality is not enabled in this build");
1089
1130
  }
1090
1131
 
1091
- bool common_download_model(const common_params_model &, const std::string &, bool) {
1132
+ bool common_download_model(const common_params_model &, const std::string &, bool, const common_header_list &) {
1092
1133
  throw std::runtime_error("download functionality is not enabled in this build");
1093
1134
  }
1094
1135
 
@@ -1096,6 +1137,14 @@ std::string common_docker_resolve_model(const std::string &) {
1096
1137
  throw std::runtime_error("download functionality is not enabled in this build");
1097
1138
  }
1098
1139
 
1140
+ int common_download_file_single(const std::string &,
1141
+ const std::string &,
1142
+ const std::string &,
1143
+ bool,
1144
+ const common_header_list &) {
1145
+ throw std::runtime_error("download functionality is not enabled in this build");
1146
+ }
1147
+
1099
1148
  #endif // LLAMA_USE_CURL || LLAMA_USE_HTTPLIB
1100
1149
 
1101
1150
  std::vector<common_cached_model_info> common_list_cached_models() {
@@ -1,12 +1,21 @@
1
1
  #pragma once
2
2
 
3
3
  #include <string>
4
+ #include <vector>
4
5
 
5
6
  struct common_params_model;
6
7
 
7
- //
8
- // download functionalities
9
- //
8
// a single HTTP header, stored as a <name, value> pair
using common_header      = std::pair<std::string, std::string>;
// list of HTTP headers
using common_header_list = std::vector<common_header>;

// request options accepted by common_remote_get_content
struct common_remote_params {
    common_header_list headers;      // headers to send with the request
    long               timeout{0};   // in seconds, 0 means no timeout
    long               max_size{0};  // unlimited if 0
};
16
+
17
+ // get remote file content, returns <http_code, raw_response_body>
18
+ std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params);
10
19
 
11
20
  struct common_cached_model_info {
12
21
  std::string manifest_path;
@@ -41,17 +50,29 @@ struct common_hf_file_res {
41
50
  common_hf_file_res common_get_hf_file(
42
51
  const std::string & hf_repo_with_tag,
43
52
  const std::string & bearer_token,
44
- bool offline);
53
+ bool offline,
54
+ const common_header_list & headers = {}
55
+ );
45
56
 
46
57
  // returns true if download succeeded
47
58
  bool common_download_model(
48
59
  const common_params_model & model,
49
60
  const std::string & bearer_token,
50
- bool offline);
61
+ bool offline,
62
+ const common_header_list & headers = {}
63
+ );
51
64
 
52
65
  // returns list of cached models
53
66
  std::vector<common_cached_model_info> common_list_cached_models();
54
67
 
68
+ // download single file from url to local path
69
+ // returns status code or -1 on error
70
+ int common_download_file_single(const std::string & url,
71
+ const std::string & path,
72
+ const std::string & bearer_token,
73
+ bool offline,
74
+ const common_header_list & headers = {});
75
+
55
76
  // resolve and download model from Docker registry
56
77
  // return local path to downloaded model file
57
78
  std::string common_docker_resolve_model(const std::string & docker);
@@ -16,6 +16,46 @@ static std::string rm_leading_dashes(const std::string & str) {
16
16
  return str.substr(pos);
17
17
  }
18
18
 
19
+ // only allow a subset of args for remote presets for security reasons
20
+ // do not add more args unless absolutely necessary
21
+ // args that output to files are strictly prohibited
22
+ static std::set<std::string> get_remote_preset_whitelist(const std::map<std::string, common_arg> & key_to_opt) {
23
+ static const std::set<std::string> allowed_options = {
24
+ "model-url",
25
+ "hf-repo",
26
+ "hf-repo-draft",
27
+ "hf-repo-v", // vocoder
28
+ "hf-file-v", // vocoder
29
+ "mmproj-url",
30
+ "pooling",
31
+ "jinja",
32
+ "batch-size",
33
+ "ubatch-size",
34
+ "cache-reuse",
35
+ // note: sampling params are automatically allowed by default
36
+ // negated args will be added automatically
37
+ };
38
+
39
+ std::set<std::string> allowed_keys;
40
+
41
+ for (const auto & it : key_to_opt) {
42
+ const std::string & key = it.first;
43
+ const common_arg & opt = it.second;
44
+ if (allowed_options.find(key) != allowed_options.end() || opt.is_sparam) {
45
+ allowed_keys.insert(key);
46
+ // also add variant keys (args without leading dashes and env vars)
47
+ for (const auto & arg : opt.get_args()) {
48
+ allowed_keys.insert(rm_leading_dashes(arg));
49
+ }
50
+ for (const auto & env : opt.get_env()) {
51
+ allowed_keys.insert(env);
52
+ }
53
+ }
54
+ }
55
+
56
+ return allowed_keys;
57
+ }
58
+
19
59
  std::vector<std::string> common_preset::to_args(const std::string & bin_path) const {
20
60
  std::vector<std::string> args;
21
61
 
@@ -121,6 +161,29 @@ void common_preset::merge(const common_preset & other) {
121
161
  }
122
162
  }
123
163
 
164
+ void common_preset::apply_to_params(common_params & params) const {
165
+ for (const auto & [opt, val] : options) {
166
+ // apply each option to params
167
+ if (opt.handler_string) {
168
+ opt.handler_string(params, val);
169
+ } else if (opt.handler_int) {
170
+ opt.handler_int(params, std::stoi(val));
171
+ } else if (opt.handler_bool) {
172
+ opt.handler_bool(params, common_arg_utils::is_truthy(val));
173
+ } else if (opt.handler_str_str) {
174
+ // not supported yet
175
+ throw std::runtime_error(string_format(
176
+ "%s: option with two values is not supported yet",
177
+ __func__
178
+ ));
179
+ } else if (opt.handler_void) {
180
+ opt.handler_void(params);
181
+ } else {
182
+ GGML_ABORT("unknown handler type");
183
+ }
184
+ }
185
+ }
186
+
124
187
  static std::map<std::string, std::map<std::string, std::string>> parse_ini_from_file(const std::string & path) {
125
188
  std::map<std::string, std::map<std::string, std::string>> parsed;
126
189
 
@@ -230,10 +293,16 @@ static std::string parse_bool_arg(const common_arg & arg, const std::string & ke
230
293
  return value;
231
294
  }
232
295
 
233
- common_preset_context::common_preset_context(llama_example ex)
296
+ common_preset_context::common_preset_context(llama_example ex, bool only_remote_allowed)
234
297
  : ctx_params(common_params_parser_init(default_params, ex)) {
235
298
  common_params_add_preset_options(ctx_params.options);
236
299
  key_to_opt = get_map_key_opt(ctx_params);
300
+
301
+ // setup allowed keys if only_remote_allowed is true
302
+ if (only_remote_allowed) {
303
+ filter_allowed_keys = true;
304
+ allowed_keys = get_remote_preset_whitelist(key_to_opt);
305
+ }
237
306
  }
238
307
 
239
308
  common_presets common_preset_context::load_from_ini(const std::string & path, common_preset & global) const {
@@ -250,6 +319,12 @@ common_presets common_preset_context::load_from_ini(const std::string & path, co
250
319
  LOG_DBG("loading preset: %s\n", preset.name.c_str());
251
320
  for (const auto & [key, value] : section.second) {
252
321
  LOG_DBG("option: %s = %s\n", key.c_str(), value.c_str());
322
+ if (filter_allowed_keys && allowed_keys.find(key) == allowed_keys.end()) {
323
+ throw std::runtime_error(string_format(
324
+ "option '%s' is not allowed in remote presets",
325
+ key.c_str()
326
+ ));
327
+ }
253
328
  if (key_to_opt.find(key) != key_to_opt.end()) {
254
329
  const auto & opt = key_to_opt.at(key);
255
330
  if (is_bool_arg(opt)) {
@@ -6,6 +6,7 @@
6
6
  #include <string>
7
7
  #include <vector>
8
8
  #include <map>
9
+ #include <set>
9
10
 
10
11
  //
11
12
  // INI preset parser and writer
@@ -40,6 +41,9 @@ struct common_preset {
40
41
 
41
42
  // merge another preset into this one, overwriting existing options
42
43
  void merge(const common_preset & other);
44
+
45
+ // apply preset options to common_params
46
+ void apply_to_params(common_params & params) const;
43
47
  };
44
48
 
45
49
  // interface for multiple presets in one file
@@ -50,7 +54,12 @@ struct common_preset_context {
50
54
  common_params default_params; // unused for now
51
55
  common_params_context ctx_params;
52
56
  std::map<std::string, common_arg> key_to_opt;
53
- common_preset_context(llama_example ex);
57
+
58
+ bool filter_allowed_keys = false;
59
+ std::set<std::string> allowed_keys;
60
+
61
+ // if only_remote_allowed is true, only accept whitelisted keys
62
+ common_preset_context(llama_example ex, bool only_remote_allowed = false);
54
63
 
55
64
  // load presets from INI file
56
65
  common_presets load_from_ini(const std::string & path, common_preset & global) const;
@@ -234,6 +234,11 @@
234
234
 
235
235
  #if UINTPTR_MAX == 0xFFFFFFFF
236
236
  #define GGML_MEM_ALIGN 4
237
+ #elif defined(__EMSCRIPTEN__)
238
+ // emscripten uses max_align_t == 8, so we need GGML_MEM_ALIGN == 8 for 64-bit wasm.
239
+ // (for 32-bit wasm, the first conditional is true and GGML_MEM_ALIGN stays 4.)
240
+ // ref: https://github.com/ggml-org/llama.cpp/pull/18628
241
+ #define GGML_MEM_ALIGN 8
237
242
  #else
238
243
  #define GGML_MEM_ALIGN 16
239
244
  #endif
@@ -309,6 +309,7 @@ extern "C" {
309
309
  // Keep the booleans together to avoid misalignment during copy-by-value.
310
310
  bool vocab_only; // only load the vocabulary, no weights
311
311
  bool use_mmap; // use mmap if possible
312
+ bool use_direct_io; // use direct io, takes precedence over use_mmap
312
313
  bool use_mlock; // force system to keep model in RAM
313
314
  bool check_tensors; // validate model tensor data
314
315
  bool use_extra_bufts; // use extra buffer types (used for weight repacking)
@@ -494,7 +495,7 @@ extern "C" {
494
495
  struct llama_context_params * cparams,
495
496
  float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements
496
497
  struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements
497
- size_t margin, // margin of memory to leave per device in bytes
498
+ size_t * margins, // margins of memory to leave per device in bytes
498
499
  uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use
499
500
  enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log
500
501
 
@@ -1291,7 +1292,9 @@ extern "C" {
1291
1292
  // available samplers:
1292
1293
 
1293
1294
  LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
1294
- LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
1295
+
1296
+ /// seed == LLAMA_DEFAULT_SEED to use a random seed.
1297
+ LLAMA_API struct llama_sampler * llama_sampler_init_dist(uint32_t seed);
1295
1298
 
1296
1299
  /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
1297
1300
  /// Setting k <= 0 makes this a noop