@fugood/llama.node 1.4.7 → 1.4.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. package/lib/binding.ts +8 -0
  2. package/package.json +15 -15
  3. package/scripts/llama.cpp.patch +23 -24
  4. package/src/LlamaContext.cpp +4 -2
  5. package/src/llama.cpp/common/CMakeLists.txt +2 -0
  6. package/src/llama.cpp/common/arg.cpp +470 -223
  7. package/src/llama.cpp/common/arg.h +43 -2
  8. package/src/llama.cpp/common/chat-peg-parser.cpp +16 -2
  9. package/src/llama.cpp/common/chat.cpp +140 -0
  10. package/src/llama.cpp/common/common.cpp +130 -67
  11. package/src/llama.cpp/common/common.h +44 -17
  12. package/src/llama.cpp/common/console.cpp +98 -18
  13. package/src/llama.cpp/common/console.h +30 -8
  14. package/src/llama.cpp/common/download.cpp +69 -25
  15. package/src/llama.cpp/common/json-schema-to-grammar.cpp +132 -3
  16. package/src/llama.cpp/common/json-schema-to-grammar.h +20 -0
  17. package/src/llama.cpp/common/log.cpp +5 -0
  18. package/src/llama.cpp/common/log.h +1 -0
  19. package/src/llama.cpp/common/peg-parser.cpp +1 -1
  20. package/src/llama.cpp/common/preset.cpp +206 -0
  21. package/src/llama.cpp/common/preset.h +32 -0
  22. package/src/llama.cpp/common/sampling.cpp +67 -54
  23. package/src/llama.cpp/common/sampling.h +8 -0
  24. package/src/llama.cpp/ggml/CMakeLists.txt +4 -0
  25. package/src/llama.cpp/ggml/include/ggml-alloc.h +9 -0
  26. package/src/llama.cpp/ggml/include/ggml-backend.h +1 -0
  27. package/src/llama.cpp/ggml/include/ggml-cpu.h +1 -0
  28. package/src/llama.cpp/ggml/include/ggml.h +7 -8
  29. package/src/llama.cpp/ggml/src/CMakeLists.txt +3 -0
  30. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +4 -0
  31. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +285 -0
  32. package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +28 -0
  33. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -45
  34. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  35. package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +288 -1
  36. package/src/llama.cpp/ggml/src/ggml-cpu/repack.h +8 -0
  37. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +41 -1
  38. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +125 -22
  39. package/src/llama.cpp/include/llama.h +18 -1
  40. package/src/llama.cpp/src/llama-arch.cpp +1890 -2248
  41. package/src/llama.cpp/src/llama-arch.h +9 -2
  42. package/src/llama.cpp/src/llama-batch.cpp +12 -2
  43. package/src/llama.cpp/src/llama-batch.h +4 -2
  44. package/src/llama.cpp/src/llama-context.cpp +93 -23
  45. package/src/llama.cpp/src/llama-context.h +8 -2
  46. package/src/llama.cpp/src/llama-graph.cpp +84 -16
  47. package/src/llama.cpp/src/llama-graph.h +17 -4
  48. package/src/llama.cpp/src/llama-hparams.cpp +6 -0
  49. package/src/llama.cpp/src/llama-hparams.h +5 -1
  50. package/src/llama.cpp/src/llama-impl.cpp +4 -0
  51. package/src/llama.cpp/src/llama-kv-cache.cpp +90 -42
  52. package/src/llama.cpp/src/llama-kv-cache.h +19 -2
  53. package/src/llama.cpp/src/llama-memory-hybrid.cpp +1 -1
  54. package/src/llama.cpp/src/llama-mmap.cpp +123 -28
  55. package/src/llama.cpp/src/llama-mmap.h +5 -1
  56. package/src/llama.cpp/src/llama-model-loader.cpp +58 -13
  57. package/src/llama.cpp/src/llama-model-loader.h +2 -0
  58. package/src/llama.cpp/src/llama-model.cpp +110 -49
  59. package/src/llama.cpp/src/llama-model.h +1 -0
  60. package/src/llama.cpp/src/llama-quant.cpp +1 -1
  61. package/src/llama.cpp/src/llama-sampling.cpp +16 -0
  62. package/src/llama.cpp/src/llama-vocab.cpp +2 -1
  63. package/src/llama.cpp/src/llama.cpp +665 -1
  64. package/src/llama.cpp/src/models/deepseek2.cpp +9 -5
  65. package/src/llama.cpp/src/models/glm4-moe.cpp +28 -11
  66. package/src/llama.cpp/src/models/glm4.cpp +27 -4
  67. package/src/llama.cpp/src/models/models.h +5 -5
  68. package/src/llama.cpp/src/models/nemotron-h.cpp +35 -6
  69. package/src/llama.cpp/src/models/qwen2.cpp +12 -3
  70. package/src/llama.cpp/src/models/qwen3next.cpp +81 -266
package/src/llama.cpp/common/arg.h
@@ -3,8 +3,10 @@
  #include "common.h"
 
  #include <set>
+ #include <map>
  #include <string>
  #include <vector>
+ #include <cstring>
 
  //
  // CLI argument parsing
@@ -14,6 +16,7 @@ struct common_arg {
  std::set<enum llama_example> examples = {LLAMA_EXAMPLE_COMMON};
  std::set<enum llama_example> excludes = {};
  std::vector<const char *> args;
+ std::vector<const char *> args_neg; // for negated args like --no-xxx
  const char * value_hint = nullptr; // help text or example for arg value
  const char * value_hint_2 = nullptr; // for second arg value
  const char * env = nullptr;
@@ -23,6 +26,9 @@ struct common_arg {
  void (*handler_string) (common_params & params, const std::string &) = nullptr;
  void (*handler_str_str)(common_params & params, const std::string &, const std::string &) = nullptr;
  void (*handler_int) (common_params & params, int) = nullptr;
+ void (*handler_bool) (common_params & params, bool) = nullptr;
+
+ common_arg() = default;
 
  common_arg(
  const std::initializer_list<const char *> & args,
@@ -44,6 +50,13 @@ struct common_arg {
  void (*handler)(common_params & params)
  ) : args(args), help(help), handler_void(handler) {}
 
+ common_arg(
+ const std::initializer_list<const char *> & args,
+ const std::initializer_list<const char *> & args_neg,
+ const std::string & help,
+ void (*handler)(common_params & params, bool)
+ ) : args(args), args_neg(args_neg), help(help), handler_bool(handler) {}
+
  // support 2 values for arg
  common_arg(
  const std::initializer_list<const char *> & args,
@@ -61,9 +74,33 @@ struct common_arg {
  bool is_exclude(enum llama_example ex);
  bool get_value_from_env(std::string & output) const;
  bool has_value_from_env() const;
- std::string to_string();
+ std::string to_string() const;
+
+ // for using as key in std::map
+ bool operator<(const common_arg& other) const {
+ if (args.empty() || other.args.empty()) {
+ return false;
+ }
+ return strcmp(args[0], other.args[0]) < 0;
+ }
+ bool operator==(const common_arg& other) const {
+ if (args.empty() || other.args.empty()) {
+ return false;
+ }
+ return strcmp(args[0], other.args[0]) == 0;
+ }
+
+ // get all args and env vars (including negated args/env)
+ std::vector<std::string> get_args() const;
+ std::vector<std::string> get_env() const;
  };
 
+ namespace common_arg_utils {
+ bool is_truthy(const std::string & value);
+ bool is_falsey(const std::string & value);
+ bool is_autoy(const std::string & value);
+ }
+
  struct common_params_context {
  enum llama_example ex = LLAMA_EXAMPLE_COMMON;
  common_params & params;
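For context, the new constructor above pairs a positive spelling with a negated one and routes both through handler_bool. A minimal sketch of a registration that could use it (the flag name, help text, and common_params field are hypothetical, not taken from this diff):

    #include "arg.h"

    // hypothetical flag: --example-feature / --no-example-feature
    auto opt = common_arg(
        {"--example-feature"},
        {"--no-example-feature"},
        "enable the example feature (default: enabled)",
        [](common_params & params, bool value) {
            params.example_feature = value; // hypothetical field, for illustration only
        }
    );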
@@ -76,7 +113,11 @@ struct common_params_context {
  // if one argument has invalid value, it will automatically display usage of the specific argument (and not the full usage message)
  bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
 
- // function to be used by test-arg-parser
+ // parse input arguments from CLI into a map
+ // TODO: support repeated args in the future
+ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map);
+
+ // initialize argument parser context - used by test-arg-parser and preset
  common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
 
  struct common_remote_params {
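A sketch of how the new map-based parse could be consumed (hypothetical usage, not from this diff, assumed inside a main(argc, argv)-style function); it relies on the operator< defined above so that common_arg can serve as a std::map key:

    std::map<common_arg, std::string> arg_map;
    if (common_params_to_map(argc, argv, LLAMA_EXAMPLE_COMMON, arg_map)) {
        for (const auto & [arg, value] : arg_map) {
            if (!arg.args.empty()) {
                // print the primary argument name and the raw value it received
                std::printf("%s = %s\n", arg.args[0], value.c_str());
            }
        }
    }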
package/src/llama.cpp/common/chat-peg-parser.cpp
@@ -1,8 +1,17 @@
  #include "chat-peg-parser.h"
 
- static std::string_view trim_trailing_space(std::string_view sv) {
+ #include <nlohmann/json.hpp>
+
+ using json = nlohmann::ordered_json;
+
+ static std::string_view trim_trailing_space(std::string_view sv, int max = -1) {
+ int count = 0;
  while (!sv.empty() && std::isspace(static_cast<unsigned char>(sv.back()))) {
+ if (max != -1 && count <= max) {
+ break;
+ }
  sv.remove_suffix(1);
+ count++;
  }
  return sv;
  }
@@ -89,7 +98,7 @@ void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
 
  if (is_arg_string && current_tool) {
  // Serialize to JSON, but exclude the end quote
- std::string dumped = json(node.text).dump();
+ std::string dumped = json(trim_trailing_space(node.text)).dump();
  current_tool->arguments += dumped.substr(0, dumped.size() - 1);
  needs_closing_quote = true;
  }
@@ -97,6 +106,7 @@ void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
  if (is_arg_close && current_tool) {
  if (needs_closing_quote) {
  current_tool->arguments += "\"";
+ needs_closing_quote = false;
  }
  }
 
@@ -105,6 +115,10 @@ void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
  }
 
  if (is_tool_close && current_tool) {
+ if (needs_closing_quote) {
+ current_tool->arguments += "\"";
+ needs_closing_quote = false;
+ }
  current_tool->arguments += "}";
  }
  }
package/src/llama.cpp/common/chat.cpp
@@ -698,6 +698,25 @@ static void foreach_function(const json & tools, const std::function<void(const
  }
  }
 
+ static void foreach_parameter(const json & function, const std::function<void(const std::string &, const json &, bool)> & fn) {
+ if (!function.contains("parameters") || !function.at("parameters").is_object()) {
+ return;
+ }
+ const auto & params = function.at("parameters");
+ if (!params.contains("properties") || !params.at("properties").is_object()) {
+ return;
+ }
+ const auto & props = params.at("properties");
+ std::set<std::string> required;
+ if (params.contains("required") && params.at("required").is_array()) {
+ params.at("required").get_to(required);
+ }
+ for (const auto & [name, prop] : props.items()) {
+ bool is_required = (required.find(name) != required.end());
+ fn(name, prop, is_required);
+ }
+ }
+
  static std::string apply(
  const common_chat_template & tmpl,
  const struct templates_params & inputs,
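To make the traversal concrete, here is a hypothetical tool definition (not from this diff) and the order in which foreach_parameter would report it: every entry under parameters.properties is visited, and is_required reflects membership in parameters.required.

    // hypothetical get_weather tool, for illustration only
    json function = json::parse(R"({
        "name": "get_weather",
        "parameters": {
            "type": "object",
            "properties": {
                "city":  { "type": "string" },
                "units": { "type": "string" }
            },
            "required": ["city"]
        }
    })");

    foreach_parameter(function, [](const std::string & name, const json & prop, bool is_required) {
        // visited as ("city", {...}, true) and then ("units", {...}, false)
    });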
@@ -1396,6 +1415,123 @@ static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_
  return data;
  }
 
+ static common_chat_params common_chat_params_init_nemotron_v3(const common_chat_template & tmpl, const struct templates_params & inputs) {
+ common_chat_params data;
+
+ data.prompt = apply(tmpl, inputs);
+ data.format = COMMON_CHAT_FORMAT_PEG_CONSTRUCTED;
+
+ // Handle thinking tags appropriately based on inputs.enable_thinking
+ if (string_ends_with(data.prompt, "<think>\n")) {
+ if (!inputs.enable_thinking) {
+ data.prompt += "</think>";
+ } else {
+ data.thinking_forced_open = true;
+ }
+ }
+
+ data.preserved_tokens = {
+ "<think>",
+ "</think>",
+ "<tool_call>",
+ "</tool_call>",
+ };
+
+ auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
+ auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
+ auto include_grammar = true;
+
+ auto parser = build_chat_peg_constructed_parser([&](auto & p) {
+ auto reasoning = p.eps();
+ if (inputs.enable_thinking && extract_reasoning) {
+ auto reasoning_content = p.reasoning(p.until("</think>")) + ("</think>" | p.end());
+ if (data.thinking_forced_open) {
+ reasoning = reasoning_content;
+ }
+ }
+
+ // Response format parser
+ if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
+ return reasoning << p.content(p.schema(p.json(), "response-format", inputs.json_schema));
+ }
+
+ // Tool call parser
+ if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
+ auto tool_choice = p.choice();
+ foreach_function(inputs.tools, [&](const json & tool) {
+ const auto & function = tool.at("function");
+ std::string name = function.at("name");
+ auto parameters = function.at("parameters");
+
+ auto schema_info = common_schema_info();
+ schema_info.resolve_refs(parameters);
+
+ auto tool_open = "<function=" + p.tool_name(p.literal(name)) + ">\n";
+ auto tool_close = p.literal("</function>\n");
+ auto args = p.sequence();
+ auto arg_string = p.rule("xml-arg-string", p.until_one_of({
+ "\n</parameter>",
+ "\n<parameter=",
+ "\n</function>"
+ }));
+
+ foreach_parameter(function, [&](const auto & param_name, const json & param_schema, bool is_required) {
+ auto rule_name = "tool-" + name + "-arg-" + param_name;
+
+ auto arg_open = "<parameter=" + p.tool_arg_name(p.literal(param_name)) + ">\n";
+ auto arg_close = p.literal("</parameter>\n");
+ auto arg_value = p.eps();
+
+ if (schema_info.resolves_to_string(param_schema)) {
+ arg_value = p.tool_arg_string_value(arg_string) + "\n";
+ } else {
+ arg_value = p.tool_arg_json_value(p.schema(p.json(), rule_name + "-schema", param_schema));
+ }
+
+ // Model may or my not close with </parameter>
+ auto arg_rule = p.rule(rule_name, p.tool_arg_open(arg_open) + arg_value + p.optional(p.tool_arg_close(arg_close)));
+ args += p.repeat(arg_rule, /* min = */ is_required ? 1 : 0, /* max = */ 1);
+ });
+
+ tool_choice |= p.rule("tool-" + name, p.tool_open(tool_open) + args + p.tool_close(tool_close));
+ });
+
+ auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0;
+ auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
+ auto tool_call = p.rule("tool-call", "<tool_call>\n" + tool_choice + "</tool_call>" + p.space());
+ auto tool_calls = p.trigger_rule("tool-call-root", p.repeat(tool_call, /* min = */ min_calls, /* max = */ max_calls));
+
+ return reasoning << p.content(p.until("<tool_call>")) << tool_calls;
+ }
+
+ // Content only parser
+ include_grammar = false;
+ return reasoning << p.content(p.rest());
+ });
+
+ data.parser = parser.save();
+
+ if (include_grammar) {
+ data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
+
+ data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+ foreach_function(inputs.tools, [&](const json & tool) {
+ const auto & function = tool.at("function");
+ auto schema = function.at("parameters");
+ builder.resolve_refs(schema);
+ });
+ parser.build_grammar(builder, data.grammar_lazy);
+ });
+
+ data.grammar_triggers = {
+ {COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<tool_call>"}
+ };
+ }
+
+ return data;
+ }
+
+
  static common_chat_params common_chat_params_init_apertus(const common_chat_template & tmpl, const struct templates_params & inputs) {
  common_chat_params data;
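Read together, the literals in the hunk above describe the text shape this parser accepts. A hypothetical completion for a get_weather tool (illustrative, not from this diff) would look roughly like:

    <think>optional reasoning</think>
    optional free-form content
    <tool_call>
    <function=get_weather>
    <parameter=city>
    Paris
    </parameter>
    </function>
    </tool_call>

String-typed parameter values are captured verbatim up to the next tag, other parameter types are validated against their JSON schema, and the closing </parameter> tag is tolerated as optional because the rule wraps it in p.optional().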
 
@@ -2521,6 +2657,10 @@ static common_chat_params common_chat_templates_apply_jinja(
  src.find("<function=") != std::string::npos &&
  src.find("<parameters>") != std::string::npos &&
  src.find("<parameter=") != std::string::npos) {
+ // Nemotron 3 Nano 30B A3B
+ if (src.find("<think>") != std::string::npos) {
+ return common_chat_params_init_nemotron_v3(tmpl, params);
+ }
  return common_chat_params_init_qwen3_coder_xml(tmpl, params);
  }
 
package/src/llama.cpp/common/common.cpp
@@ -1013,31 +1013,40 @@ bool tty_can_use_colors() {
  // Model utils
  //
 
- static inline void common_init_sampler_from_model(
+ // TODO: move to common/sampling
+ static void common_init_sampler_from_model(
  const llama_model * model,
  common_params_sampling & sparams) {
 
  const uint64_t config = sparams.user_sampling_config;
 
  auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
- if (config & user_config) return;
+ if (config & user_config) {
+ return;
+ }
 
  char buf[64] = {0};
  if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
  char * end = nullptr;
  int32_t v = strtol(buf, &end, 10);
- if (end && end != buf) dst = v;
+ if (end && end != buf) {
+ dst = v;
+ }
  }
  };
 
  auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
- if (config & user_config) return;
+ if (config & user_config) {
+ return;
+ }
 
  char buf[128] = {0};
  if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
  char * end = nullptr;
  float v = strtof(buf, &end);
- if (end && end != buf) dst = v;
+ if (end && end != buf) {
+ dst = v;
+ }
  }
  };
 
@@ -1065,31 +1074,125 @@ static inline void common_init_sampler_from_model(
  get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
  }
 
- struct common_init_result common_init_from_params(common_params & params) {
- common_init_result iparams;
+ struct common_init_result::impl {
+ impl() = default;
+ ~impl() = default;
+
+ llama_model_ptr model;
+ llama_context_ptr context;
+
+ std::vector<llama_adapter_lora_ptr> lora;
+
+ std::vector<common_sampler_ptr> samplers;
+ };
+
+ common_init_result::common_init_result(common_params & params) :
+ pimpl(new impl{}) {
  auto mparams = common_model_params_to_llama(params);
+ auto cparams = common_context_params_to_llama(params);
+
+ if (params.fit_params) {
+ LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__);
+ llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
+ params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
+ params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
+ }
 
  llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
  if (model == NULL) {
- LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
- __func__, params.model.path.c_str());
- return iparams;
+ return;
  }
 
- common_init_sampler_from_model(model, params.sampling);
+ pimpl->model.reset(model);
 
  const llama_vocab * vocab = llama_model_get_vocab(model);
 
- auto cparams = common_context_params_to_llama(params);
+ // updates params.sampling
+ // TODO: fix naming
+ common_init_sampler_from_model(model, params.sampling);
+
+ if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
+ LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
+ params.sampling.ignore_eos = false;
+ }
+
+ // initialize once
+ for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
+ if (llama_vocab_is_eog(vocab, i)) {
+ LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
+ params.sampling.logit_bias_eog.push_back({i, -INFINITY});
+ }
+ }
+
+ if (params.sampling.ignore_eos) {
+ // add EOG biases to the active set of logit biases
+ params.sampling.logit_bias.insert(
+ params.sampling.logit_bias.end(),
+ params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
+ }
+
+ //if (params.sampling.penalty_last_n == -1) {
+ // LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+ // params.sampling.penalty_last_n = llama_n_ctx(lctx);
+ //}
+
+ //if (params.sampling.dry_penalty_last_n == -1) {
+ // LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+ // params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
+ //}
+
+ pimpl->samplers.resize(cparams.n_seq_max);
+
+ for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
+ pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
+ }
 
  llama_context * lctx = llama_init_from_model(model, cparams);
  if (lctx == NULL) {
- LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
- __func__, params.model.path.c_str());
- llama_model_free(model);
- return iparams;
+ LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
+ return;
+ }
+
+ pimpl->context.reset(lctx);
+ }
+
+ llama_model * common_init_result::model() {
+ return pimpl->model.get();
+ }
+
+ llama_context * common_init_result::context() {
+ return pimpl->context.get();
+ }
+
+ common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
+ return pimpl->samplers[seq_id].get();
+ }
+
+ std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
+ return pimpl->lora;
+ }
+
+ void common_init_result::free_context() {
+ pimpl->context.reset();
+ }
+
+ common_init_result_ptr common_init_from_params(common_params & params) {
+ common_init_result_ptr res(new common_init_result(params));
+
+ llama_model * model = res->model();
+ if (model == NULL) {
+ LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
+ return res;
+ }
+
+ llama_context * lctx = res->context();
+ if (lctx == NULL) {
+ LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
+ return res;
  }
 
+ const llama_vocab * vocab = llama_model_get_vocab(model);
+
  if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
  LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
  params.ctx_shift = false;
@@ -1101,10 +1204,7 @@ struct common_init_result common_init_from_params(common_params & params) {
 
  const auto cvec = common_control_vector_load(params.control_vectors);
  if (cvec.n_embd == -1) {
- llama_free(lctx);
- llama_model_free(model);
-
- return iparams;
+ return res;
  }
 
  int err = llama_apply_adapter_cvec(
@@ -1115,10 +1215,7 @@ struct common_init_result common_init_from_params(common_params & params) {
  params.control_vector_layer_start,
  params.control_vector_layer_end);
  if (err) {
- llama_free(lctx);
- llama_model_free(model);
-
- return iparams;
+ return res;
  }
  }
 
@@ -1142,10 +1239,7 @@ struct common_init_result common_init_from_params(common_params & params) {
  }
 
  if (!ok) {
- llama_free(lctx);
- llama_model_free(model);
-
- return iparams;
+ return res;
  }
  }
 
@@ -1155,9 +1249,7 @@ struct common_init_result common_init_from_params(common_params & params) {
  lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
  if (lora == nullptr) {
  LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
- llama_free(lctx);
- llama_model_free(model);
- return iparams;
+ return res;
  }
 
  char buf[1024];
@@ -1166,43 +1258,13 @@ struct common_init_result common_init_from_params(common_params & params) {
  la.task_name = buf;
  llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
  la.prompt_prefix = buf;
- iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
+ res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters
  }
 
  if (!params.lora_init_without_apply) {
  common_set_adapter_lora(lctx, params.lora_adapters);
  }
 
- if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
- LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
- params.sampling.ignore_eos = false;
- }
-
- // initialize once
- for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
- if (llama_vocab_is_eog(vocab, i)) {
- LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
- params.sampling.logit_bias_eog.push_back({i, -INFINITY});
- }
- }
-
- if (params.sampling.ignore_eos) {
- // add EOG biases to the active set of logit biases
- params.sampling.logit_bias.insert(
- params.sampling.logit_bias.end(),
- params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
- }
-
- if (params.sampling.penalty_last_n == -1) {
- LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
- params.sampling.penalty_last_n = llama_n_ctx(lctx);
- }
-
- if (params.sampling.dry_penalty_last_n == -1) {
- LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
- params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
- }
-
  if (params.warmup) {
  LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
 
@@ -1241,12 +1303,11 @@
  llama_set_warmup(lctx, false);
  }
 
- iparams.model.reset(model);
- iparams.context.reset(lctx);
-
- return iparams;
+ return res;
  }
 
+ common_init_result::~common_init_result() = default;
+
  std::string get_model_endpoint() {
  const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
  // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
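With this refactor the model, context, LoRA adapters, and per-sequence samplers are owned by common_init_result's pimpl, which is why the explicit llama_free/llama_model_free calls disappear from the error paths above. A minimal caller sketch against the new accessors (hypothetical usage, not from this diff):

    common_init_result_ptr init = common_init_from_params(params);

    llama_model *   model = init->model();
    llama_context * lctx  = init->context();
    if (model == nullptr || lctx == nullptr) {
        return 1; // anything that was created is released when init goes out of scope
    }

    common_sampler * smpl = init->sampler(0); // one sampler per sequence, indexed by llama_seq_id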
@@ -1255,7 +1316,9 @@ std::string get_model_endpoint() {
  std::string model_endpoint = "https://huggingface.co/";
  if (endpoint_env) {
  model_endpoint = endpoint_env;
- if (model_endpoint.back() != '/') model_endpoint += '/';
+ if (model_endpoint.back() != '/') {
+ model_endpoint += '/';
+ }
  }
  return model_endpoint;
  }
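For reference, the resulting behavior of get_model_endpoint() (illustrative values, not from this diff): the endpoint is read from MODEL_ENDPOINT, with HF_ENDPOINT still honored for backward compatibility, a trailing slash is appended when missing, and the default remains the Hugging Face URL.

    // no endpoint variables set:
    std::string a = get_model_endpoint(); // "https://huggingface.co/"

    // MODEL_ENDPOINT=https://mirror.example.com (hypothetical host):
    std::string b = get_model_endpoint(); // "https://mirror.example.com/"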