@fugood/llama.node 0.3.9 → 0.3.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/lib/binding.js +2 -2
  18. package/lib/binding.ts +47 -8
  19. package/lib/index.js +21 -1
  20. package/lib/index.ts +31 -1
  21. package/package.json +12 -3
  22. package/src/LlamaCompletionWorker.cpp +33 -6
  23. package/src/LlamaCompletionWorker.h +3 -1
  24. package/src/LlamaContext.cpp +336 -28
  25. package/src/LlamaContext.h +2 -0
  26. package/src/common.hpp +19 -2
  27. package/src/llama.cpp/.github/workflows/build.yml +289 -107
  28. package/src/llama.cpp/.github/workflows/close-issue.yml +1 -1
  29. package/src/llama.cpp/.github/workflows/docker.yml +2 -1
  30. package/src/llama.cpp/.github/workflows/server.yml +25 -2
  31. package/src/llama.cpp/CMakeLists.txt +10 -19
  32. package/src/llama.cpp/cmake/build-info.cmake +1 -1
  33. package/src/llama.cpp/common/CMakeLists.txt +32 -0
  34. package/src/llama.cpp/common/arg.cpp +66 -16
  35. package/src/llama.cpp/common/chat-template.hpp +515 -0
  36. package/src/llama.cpp/common/chat.cpp +966 -0
  37. package/src/llama.cpp/common/chat.hpp +52 -0
  38. package/src/llama.cpp/common/common.cpp +159 -36
  39. package/src/llama.cpp/common/common.h +56 -14
  40. package/src/llama.cpp/common/json-schema-to-grammar.cpp +46 -66
  41. package/src/llama.cpp/common/json-schema-to-grammar.h +15 -1
  42. package/src/llama.cpp/common/llguidance.cpp +270 -0
  43. package/src/llama.cpp/common/log.cpp +1 -10
  44. package/src/llama.cpp/common/log.h +10 -0
  45. package/src/llama.cpp/common/minja.hpp +2868 -0
  46. package/src/llama.cpp/common/sampling.cpp +22 -1
  47. package/src/llama.cpp/common/sampling.h +3 -0
  48. package/src/llama.cpp/docs/build.md +54 -9
  49. package/src/llama.cpp/examples/export-lora/export-lora.cpp +12 -2
  50. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +1 -1
  51. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  52. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +59 -0
  53. package/src/llama.cpp/examples/llava/clip.cpp +133 -14
  54. package/src/llama.cpp/examples/llava/clip.h +2 -0
  55. package/src/llama.cpp/examples/llava/llava.cpp +22 -8
  56. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +9 -1
  57. package/src/llama.cpp/examples/main/main.cpp +26 -25
  58. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +136 -137
  59. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +18 -4
  60. package/src/llama.cpp/examples/run/run.cpp +224 -69
  61. package/src/llama.cpp/examples/server/server.cpp +252 -81
  62. package/src/llama.cpp/examples/server/utils.hpp +73 -21
  63. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +6 -4
  64. package/src/llama.cpp/examples/simple-cmake-pkg/CMakeLists.txt +11 -0
  65. package/src/llama.cpp/ggml/CMakeLists.txt +78 -1
  66. package/src/llama.cpp/ggml/include/ggml.h +1 -1
  67. package/src/llama.cpp/ggml/src/CMakeLists.txt +21 -4
  68. package/src/llama.cpp/ggml/src/ggml-alloc.c +1 -13
  69. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +91 -78
  70. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +7 -7
  71. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -1
  72. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  73. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +46 -0
  74. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +16 -1
  75. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +1 -1
  76. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +28 -8
  77. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +5 -7
  78. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +33 -23
  79. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.hpp +1 -5
  80. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +323 -121
  81. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +13 -3
  82. package/src/llama.cpp/ggml/src/ggml.c +23 -13
  83. package/src/llama.cpp/include/llama.h +14 -1
  84. package/src/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +112 -0
  85. package/src/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.out +46 -0
  86. package/src/llama.cpp/src/CMakeLists.txt +1 -1
  87. package/src/llama.cpp/src/llama-arch.cpp +7 -2
  88. package/src/llama.cpp/src/llama-arch.h +3 -1
  89. package/src/llama.cpp/src/llama-chat.cpp +11 -2
  90. package/src/llama.cpp/src/llama-chat.h +1 -0
  91. package/src/llama.cpp/src/llama-grammar.cpp +86 -6
  92. package/src/llama.cpp/src/llama-grammar.h +22 -1
  93. package/src/llama.cpp/src/llama-mmap.cpp +1 -0
  94. package/src/llama.cpp/src/llama-model-loader.cpp +1 -1
  95. package/src/llama.cpp/src/llama-model.cpp +76 -6
  96. package/src/llama.cpp/src/llama-sampling.cpp +47 -4
  97. package/src/llama.cpp/src/llama-vocab.cpp +10 -4
  98. package/src/llama.cpp/src/llama.cpp +181 -123
  99. package/src/llama.cpp/tests/CMakeLists.txt +4 -0
  100. package/src/llama.cpp/tests/test-backend-ops.cpp +158 -57
  101. package/src/llama.cpp/tests/test-chat-template.cpp +154 -31
  102. package/src/llama.cpp/tests/test-chat.cpp +607 -0
  103. package/src/llama.cpp/tests/test-grammar-integration.cpp +2 -2
  104. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +1140 -0
  105. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +1 -1
  106. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +0 -32
@@ -45,6 +45,7 @@ extern "C" {
45
45
  #endif
46
46
 
47
47
  #include <stddef.h> /* For size_t. */
48
+ #include <stdlib.h>
48
49
 
49
50
  extern const char *linenoiseEditMore;
50
51
 
@@ -69,10 +70,23 @@ struct linenoiseState {
69
70
  int history_index; /* The history index we are currently editing. */
70
71
  };
71
72
 
72
- typedef struct linenoiseCompletions {
73
- size_t len;
74
- char **cvec;
75
- } linenoiseCompletions;
73
+ struct linenoiseCompletions {
74
+ size_t len = 0;
75
+ char ** cvec = nullptr;
76
+ bool to_free = true;
77
+
78
+ ~linenoiseCompletions() {
79
+ if (!to_free) {
80
+ return;
81
+ }
82
+
83
+ for (size_t i = 0; i < len; ++i) {
84
+ free(cvec[i]);
85
+ }
86
+
87
+ free(cvec);
88
+ }
89
+ };
76
90
 
77
91
  /* Non blocking API. */
78
92
  int linenoiseEditStart(struct linenoiseState *l, int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt);
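Note that the completions container above changes from a plain C typedef into a C++ struct that releases its own entries in the destructor, so callers no longer free cvec by hand. A minimal standalone sketch of the same RAII pattern (the struct mirrors the header above; the main() usage is illustrative only, not the package's code):

    #include <cstddef>
    #include <cstdlib>
    #include <cstring>

    // Illustrative copy of the new self-freeing completions container.
    struct linenoiseCompletions {
        size_t len = 0;
        char ** cvec = nullptr;
        bool to_free = true; // set to false when ownership is handed elsewhere

        ~linenoiseCompletions() {
            if (!to_free) {
                return;
            }
            for (size_t i = 0; i < len; ++i) {
                free(cvec[i]);
            }
            free(cvec);
        }
    };

    int main() {
        const char * text = "example completion";

        linenoiseCompletions lc;
        lc.cvec    = (char **) malloc(sizeof(char *));
        lc.cvec[0] = (char *) malloc(strlen(text) + 1);
        memcpy(lc.cvec[0], text, strlen(text) + 1);
        lc.len     = 1;
        // no explicit cleanup: the destructor frees both the entries and the array
    }
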
package/src/llama.cpp/examples/run/run.cpp

@@ -24,14 +24,16 @@
 #include <string>
 #include <vector>
 
+#include "chat-template.hpp"
 #include "common.h"
 #include "json.hpp"
 #include "linenoise.cpp/linenoise.h"
 #include "llama-cpp.h"
+#include "log.h"
 
 #if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32)
 [[noreturn]] static void sigint_handler(int) {
-    printf("\n\033[0m");
+    printf("\n" LOG_COL_DEFAULT);
     exit(0); // not ideal, but it's the only way to guarantee exit in all cases
 }
 #endif
@@ -64,6 +66,13 @@ static int printe(const char * fmt, ...) {
     return ret;
 }
 
+static std::string strftime_fmt(const char * fmt, const std::tm & tm) {
+    std::ostringstream oss;
+    oss << std::put_time(&tm, fmt);
+
+    return oss.str();
+}
+
 class Opt {
   public:
     int init(int argc, const char ** argv) {
@@ -105,6 +114,7 @@ class Opt {
     llama_model_params model_params;
     std::string model_;
     std::string user;
+    bool use_jinja = false;
     int context_size = -1, ngl = -1;
     float temperature = -1;
     bool verbose = false;
@@ -145,7 +155,8 @@
             if (handle_option_with_value(argc, argv, i, context_size) == 1) {
                 return 1;
             }
-        } else if (options_parsing && (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--ngl") == 0)) {
+        } else if (options_parsing &&
+                   (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "-ngl") == 0 || strcmp(argv[i], "--ngl") == 0)) {
             if (handle_option_with_value(argc, argv, i, ngl) == 1) {
                 return 1;
             }
@@ -156,6 +167,8 @@
         } else if (options_parsing &&
                    (parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) {
             verbose = true;
+        } else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
+            use_jinja = true;
         } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
             help = true;
             return 0;
@@ -176,6 +189,10 @@
             }
         }
 
+        if (model_.empty()){
+            return 1;
+        }
+
         return 0;
     }
 
@@ -190,7 +207,7 @@
            "Options:\n"
            "  -c, --context-size <value>\n"
            "      Context size (default: %d)\n"
-           "  -n, --ngl <value>\n"
+           "  -n, -ngl, --ngl <value>\n"
            "      Number of GPU layers (default: %d)\n"
            "  --temp <value>\n"
            "      Temperature (default: %.1f)\n"
@@ -314,6 +331,10 @@
   public:
     int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
              const bool progress, std::string * response_str = nullptr) {
+        if (std::filesystem::exists(output_file)) {
+            return 0;
+        }
+
        std::string output_file_partial;
        curl = curl_easy_init();
        if (!curl) {
@@ -341,7 +362,11 @@
        data.file_size = set_resume_point(output_file_partial);
        set_progress_options(progress, data);
        set_headers(headers);
-        perform(url);
+        CURLcode res = perform(url);
+        if (res != CURLE_OK){
+            printe("Fetching resource '%s' failed: %s\n", url.c_str(), curl_easy_strerror(res));
+            return 1;
+        }
        if (!output_file.empty()) {
            std::filesystem::rename(output_file_partial, output_file);
        }
@@ -406,16 +431,12 @@
        }
    }
 
-    void perform(const std::string & url) {
-        CURLcode res;
+    CURLcode perform(const std::string & url) {
        curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
        curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
        curl_easy_setopt(curl, CURLOPT_DEFAULT_PROTOCOL, "https");
        curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L);
-        res = curl_easy_perform(curl);
-        if (res != CURLE_OK) {
-            printe("curl_easy_perform() failed: %s\n", curl_easy_strerror(res));
-        }
+        return curl_easy_perform(curl);
    }
 
    static std::string human_readable_time(double seconds) {
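The refactor above makes perform() return the CURLcode instead of printing inside the helper, so init() can report which URL failed and abort the download. A standalone libcurl sketch of the same call-site error handling (the URL is illustrative):

    #include <cstdio>
    #include <curl/curl.h>

    // Let the caller decide how to report a failed transfer.
    static CURLcode perform(CURL * curl, const char * url) {
        curl_easy_setopt(curl, CURLOPT_URL, url);
        curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
        curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L);
        return curl_easy_perform(curl);
    }

    int main() {
        curl_global_init(CURL_GLOBAL_DEFAULT);
        CURL * curl = curl_easy_init();
        if (!curl) {
            return 1;
        }

        const char * url = "https://example.com/model.gguf"; // illustrative URL
        const CURLcode res = perform(curl, url);
        if (res != CURLE_OK) {
            fprintf(stderr, "Fetching resource '%s' failed: %s\n", url, curl_easy_strerror(res));
        }

        curl_easy_cleanup(curl);
        curl_global_cleanup();
        return res == CURLE_OK ? 0 : 1;
    }
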
@@ -553,13 +574,14 @@ class LlamaData {
         }
 
         sampler = initialize_sampler(opt);
+
         return 0;
     }
 
   private:
 #ifdef LLAMA_USE_CURL
-    int download(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
-                 const bool progress, std::string * response_str = nullptr) {
+    int download(const std::string & url, const std::string & output_file, const bool progress,
+                 const std::vector<std::string> & headers = {}, std::string * response_str = nullptr) {
         HttpClient http;
         if (http.init(url, headers, output_file, progress, response_str)) {
             return 1;
@@ -568,48 +590,85 @@ class LlamaData {
         return 0;
     }
 #else
-    int download(const std::string &, const std::vector<std::string> &, const std::string &, const bool,
+    int download(const std::string &, const std::string &, const bool, const std::vector<std::string> & = {},
                  std::string * = nullptr) {
         printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
+
         return 1;
     }
 #endif
 
-    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    // Helper function to handle model tag extraction and URL construction
+    std::pair<std::string, std::string> extract_model_and_tag(std::string & model, const std::string & base_url) {
+        std::string model_tag = "latest";
+        const size_t colon_pos = model.find(':');
+        if (colon_pos != std::string::npos) {
+            model_tag = model.substr(colon_pos + 1);
+            model = model.substr(0, colon_pos);
+        }
+
+        std::string url = base_url + model + "/manifests/" + model_tag;
+
+        return { model, url };
+    }
+
+    // Helper function to download and parse the manifest
+    int download_and_parse_manifest(const std::string & url, const std::vector<std::string> & headers,
+                                    nlohmann::json & manifest) {
+        std::string manifest_str;
+        int ret = download(url, "", false, headers, &manifest_str);
+        if (ret) {
+            return ret;
+        }
+
+        manifest = nlohmann::json::parse(manifest_str);
+
+        return 0;
+    }
+
+    int huggingface_dl(std::string & model, const std::string & bn) {
         // Find the second occurrence of '/' after protocol string
         size_t pos = model.find('/');
         pos = model.find('/', pos + 1);
+        std::string hfr, hff;
+        std::vector<std::string> headers = { "User-Agent: llama-cpp", "Accept: application/json" };
+        std::string url;
+
         if (pos == std::string::npos) {
-            return 1;
+            auto [model_name, manifest_url] = extract_model_and_tag(model, "https://huggingface.co/v2/");
+            hfr = model_name;
+
+            nlohmann::json manifest;
+            int ret = download_and_parse_manifest(manifest_url, headers, manifest);
+            if (ret) {
+                return ret;
+            }
+
+            hff = manifest["ggufFile"]["rfilename"];
+        } else {
+            hfr = model.substr(0, pos);
+            hff = model.substr(pos + 1);
         }
 
-        const std::string hfr = model.substr(0, pos);
-        const std::string hff = model.substr(pos + 1);
-        const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
-        return download(url, headers, bn, true);
+        url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
+
+        return download(url, bn, true, headers);
     }
 
-    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int ollama_dl(std::string & model, const std::string & bn) {
+        const std::vector<std::string> headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" };
         if (model.find('/') == std::string::npos) {
             model = "library/" + model;
         }
 
-        std::string model_tag = "latest";
-        size_t colon_pos = model.find(':');
-        if (colon_pos != std::string::npos) {
-            model_tag = model.substr(colon_pos + 1);
-            model = model.substr(0, colon_pos);
-        }
-
-        std::string manifest_url = "https://registry.ollama.ai/v2/" + model + "/manifests/" + model_tag;
-        std::string manifest_str;
-        const int ret = download(manifest_url, headers, "", false, &manifest_str);
+        auto [model_name, manifest_url] = extract_model_and_tag(model, "https://registry.ollama.ai/v2/");
+        nlohmann::json manifest;
+        int ret = download_and_parse_manifest(manifest_url, {}, manifest);
         if (ret) {
            return ret;
        }
 
-        nlohmann::json manifest = nlohmann::json::parse(manifest_str);
-        std::string layer;
+        std::string layer;
        for (const auto & l : manifest["layers"]) {
            if (l["mediaType"] == "application/vnd.ollama.image.model") {
                layer = l["digest"];
@@ -617,8 +676,67 @@
             }
         }
 
-        std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
-        return download(blob_url, headers, bn, true);
+        std::string blob_url = "https://registry.ollama.ai/v2/" + model_name + "/blobs/" + layer;
+
+        return download(blob_url, bn, true, headers);
+    }
+
+    int github_dl(const std::string & model, const std::string & bn) {
+        std::string repository = model;
+        std::string branch = "main";
+        const size_t at_pos = model.find('@');
+        if (at_pos != std::string::npos) {
+            repository = model.substr(0, at_pos);
+            branch = model.substr(at_pos + 1);
+        }
+
+        const std::vector<std::string> repo_parts = string_split(repository, "/");
+        if (repo_parts.size() < 3) {
+            printe("Invalid GitHub repository format\n");
+            return 1;
+        }
+
+        const std::string & org = repo_parts[0];
+        const std::string & project = repo_parts[1];
+        std::string url = "https://raw.githubusercontent.com/" + org + "/" + project + "/" + branch;
+        for (size_t i = 2; i < repo_parts.size(); ++i) {
+            url += "/" + repo_parts[i];
+        }
+
+        return download(url, bn, true);
+    }
+
+    int s3_dl(const std::string & model, const std::string & bn) {
+        const size_t slash_pos = model.find('/');
+        if (slash_pos == std::string::npos) {
+            return 1;
+        }
+
+        const std::string bucket = model.substr(0, slash_pos);
+        const std::string key = model.substr(slash_pos + 1);
+        const char * access_key = std::getenv("AWS_ACCESS_KEY_ID");
+        const char * secret_key = std::getenv("AWS_SECRET_ACCESS_KEY");
+        if (!access_key || !secret_key) {
+            printe("AWS credentials not found in environment\n");
+            return 1;
+        }
+
+        // Generate AWS Signature Version 4 headers
+        // (Implementation requires HMAC-SHA256 and date handling)
+        // Get current timestamp
+        const time_t now = time(nullptr);
+        const tm tm = *gmtime(&now);
+        const std::string date = strftime_fmt("%Y%m%d", tm);
+        const std::string datetime = strftime_fmt("%Y%m%dT%H%M%SZ", tm);
+        const std::vector<std::string> headers = {
+            "Authorization: AWS4-HMAC-SHA256 Credential=" + std::string(access_key) + "/" + date +
+                "/us-east-1/s3/aws4_request",
+            "x-amz-content-sha256: UNSIGNED-PAYLOAD", "x-amz-date: " + datetime
+        };
+
+        const std::string url = "https://" + bucket + ".s3.amazonaws.com/" + key;
+
+        return download(url, bn, true, headers);
     }
 
     std::string basename(const std::string & path) {
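The github_dl helper added above maps a reference of the form org/project/path[@branch] onto a raw.githubusercontent.com URL, defaulting the branch to main. A standalone sketch of that mapping (string_split comes from common.h in the real code, so a minimal stand-in is used here; the example references are illustrative):

    #include <cstddef>
    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    // Minimal stand-in for common.h's string_split.
    static std::vector<std::string> split(const std::string & s, char sep) {
        std::vector<std::string> parts;
        std::string part;
        std::istringstream ss(s);
        while (std::getline(ss, part, sep)) {
            parts.push_back(part);
        }
        return parts;
    }

    // Mirrors github_dl's URL construction; returns an empty string on bad input.
    static std::string github_raw_url(const std::string & model) {
        std::string repository = model;
        std::string branch = "main";
        const size_t at_pos = model.find('@');
        if (at_pos != std::string::npos) {
            repository = model.substr(0, at_pos);
            branch     = model.substr(at_pos + 1);
        }

        const std::vector<std::string> repo_parts = split(repository, '/');
        if (repo_parts.size() < 3) {
            return "";
        }

        std::string url = "https://raw.githubusercontent.com/" + repo_parts[0] + "/" + repo_parts[1] + "/" + branch;
        for (size_t i = 2; i < repo_parts.size(); ++i) {
            url += "/" + repo_parts[i];
        }
        return url;
    }

    int main() {
        std::cout << github_raw_url("org/repo/models/model.gguf@dev") << "\n";
        std::cout << github_raw_url("org/repo/models/model.gguf") << "\n"; // branch defaults to main
    }
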
@@ -630,37 +748,44 @@ class LlamaData {
         return path.substr(pos + 1);
     }
 
-    int remove_proto(std::string & model_) {
-        const std::string::size_type pos = model_.find("://");
+    int rm_until_substring(std::string & model_, const std::string & substring) {
+        const std::string::size_type pos = model_.find(substring);
         if (pos == std::string::npos) {
             return 1;
         }
 
-        model_ = model_.substr(pos + 3); // Skip past "://"
+        model_ = model_.substr(pos + substring.size()); // Skip past the substring
         return 0;
     }
 
     int resolve_model(std::string & model_) {
         int ret = 0;
         if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) {
-            remove_proto(model_);
+            rm_until_substring(model_, "://");
 
             return ret;
         }
 
-        const std::string bn = basename(model_);
-        const std::vector<std::string> headers = { "--header",
-                                                   "Accept: application/vnd.docker.distribution.manifest.v2+json" };
-        if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
-            remove_proto(model_);
-            ret = huggingface_dl(model_, headers, bn);
-        } else if (string_starts_with(model_, "ollama://")) {
-            remove_proto(model_);
-            ret = ollama_dl(model_, headers, bn);
-        } else if (string_starts_with(model_, "https://")) {
-            download(model_, headers, bn, true);
-        } else {
-            ret = ollama_dl(model_, headers, bn);
+        const std::string bn = basename(model_);
+        if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://") ||
+            string_starts_with(model_, "hf.co/")) {
+            rm_until_substring(model_, "hf.co/");
+            rm_until_substring(model_, "://");
+            ret = huggingface_dl(model_, bn);
+        } else if ((string_starts_with(model_, "https://") || string_starts_with(model_, "http://")) &&
+                   !string_starts_with(model_, "https://ollama.com/library/")) {
+            ret = download(model_, bn, true);
+        } else if (string_starts_with(model_, "github:") || string_starts_with(model_, "github://")) {
+            rm_until_substring(model_, "github:");
+            rm_until_substring(model_, "://");
+            ret = github_dl(model_, bn);
+        } else if (string_starts_with(model_, "s3://")) {
+            rm_until_substring(model_, "://");
+            ret = s3_dl(model_, bn);
+        } else { // ollama:// or nothing
+            rm_until_substring(model_, "ollama.com/library/");
+            rm_until_substring(model_, "://");
+            ret = ollama_dl(model_, bn);
         }
 
         model_ = bn;
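After local file:// paths and existing files are handled, resolve_model now dispatches on the reference's prefix: hf://, huggingface:// and hf.co/ go through the Hugging Face manifest path, plain http(s):// URLs are fetched directly (unless they point at ollama.com/library/), github: pulls from raw.githubusercontent.com, s3:// signs a request with AWS credentials from the environment, and anything else falls back to the Ollama registry. A minimal sketch of that dispatch order, with stub names standing in for the real downloaders:

    #include <iostream>
    #include <string>

    static bool starts_with(const std::string & s, const std::string & prefix) {
        return s.rfind(prefix, 0) == 0;
    }

    // Mirrors the branch order of resolve_model above; returns which downloader would run.
    static std::string pick_source(const std::string & model) {
        if (starts_with(model, "hf://") || starts_with(model, "huggingface://") || starts_with(model, "hf.co/")) {
            return "huggingface";
        }
        if ((starts_with(model, "https://") || starts_with(model, "http://")) &&
            !starts_with(model, "https://ollama.com/library/")) {
            return "direct download";
        }
        if (starts_with(model, "github:") || starts_with(model, "github://")) {
            return "github";
        }
        if (starts_with(model, "s3://")) {
            return "s3";
        }
        return "ollama"; // ollama://, ollama.com/library/, or a bare model name
    }

    int main() {
        std::cout << pick_source("hf://org/some-model-GGUF") << "\n"; // huggingface
        std::cout << pick_source("s3://bucket/model.gguf") << "\n";   // s3
        std::cout << pick_source("smollm:135m") << "\n";              // ollama
    }
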
@@ -713,13 +838,39 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 }
 
 // Function to apply the chat template and resize `formatted` if needed
-static int apply_chat_template(LlamaData & llama_data, const bool append) {
+static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
+    if (use_jinja) {
+        json messages = json::array();
+        for (const auto & msg : llama_data.messages) {
+            messages.push_back({
+                {"role", msg.role},
+                {"content", msg.content},
+            });
+        }
+        try {
+            minja::chat_template_inputs tmpl_inputs;
+            tmpl_inputs.messages = messages;
+            tmpl_inputs.add_generation_prompt = append;
+
+            minja::chat_template_options tmpl_opts;
+            tmpl_opts.use_bos_token = false;
+            tmpl_opts.use_eos_token = false;
+
+            auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
+            llama_data.fmtted.resize(result.size() + 1);
+            memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
+            return result.size();
+        } catch (const std::exception & e) {
+            printe("failed to render the chat template: %s\n", e.what());
+            return -1;
+        }
+    }
     int result = llama_chat_apply_template(
-        llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
+        tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
+        result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
                                            llama_data.messages.size(), append, llama_data.fmtted.data(),
                                            llama_data.fmtted.size());
     }
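When --jinja is set, the prompt above is rendered through the minja-based common_chat_template rather than llama_chat_apply_template, which supports the full Jinja syntax of modern chat templates. A hedged standalone sketch of that rendering path, assuming the vendored common/chat-template.hpp and common/minja.hpp are on the include path and that the template constructor takes the template source plus BOS/EOS strings; the template text itself is illustrative, not a real model template:

    #include <iostream>
    #include <string>

    #include "chat-template.hpp" // vendored under common/ in this release
    #include "json.hpp"

    int main() {
        // Illustrative Jinja-style template; real templates come from the model's GGUF metadata.
        const std::string source =
            "{% for m in messages %}<|{{ m['role'] }}|>{{ m['content'] }}\n{% endfor %}"
            "{% if add_generation_prompt %}<|assistant|>{% endif %}";

        minja::chat_template tmpl(source, /* bos_token */ "", /* eos_token */ "");

        minja::chat_template_inputs inputs;
        inputs.messages = nlohmann::ordered_json::array({
            { { "role", "user" }, { "content", "Hello!" } },
        });
        inputs.add_generation_prompt = true;

        minja::chat_template_options opts;
        opts.use_bos_token = false;
        opts.use_eos_token = false;

        std::cout << tmpl.apply(inputs, opts) << std::endl;
    }
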
@@ -729,10 +880,12 @@ static int apply_chat_template(LlamaData & llama_data, const bool append) {
 
 // Function to tokenize the prompt
 static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
-                           std::vector<llama_token> & prompt_tokens) {
-    const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
+                           std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
+    const bool is_first = llama_get_kv_cache_used_cells(llama_data.context.get()) == 0;
+
+    const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
     prompt_tokens.resize(n_prompt_tokens);
-    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
+    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first,
                        true) < 0) {
         printe("failed to tokenize the prompt\n");
         return -1;
@@ -746,7 +899,7 @@ static int check_context_size(const llama_context_ptr & ctx, const llama_batch &
     const int n_ctx = llama_n_ctx(ctx.get());
     const int n_ctx_used = llama_get_kv_cache_used_cells(ctx.get());
     if (n_ctx_used + batch.n_tokens > n_ctx) {
-        printf("\033[0m\n");
+        printf(LOG_COL_DEFAULT "\n");
         printe("context size exceeded\n");
         return 1;
     }
@@ -778,7 +931,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
     const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get());
 
     std::vector<llama_token> tokens;
-    if (tokenize_prompt(vocab, prompt, tokens) < 0) {
+    if (tokenize_prompt(vocab, prompt, tokens, llama_data) < 0) {
         return 1;
     }
 
@@ -809,7 +962,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
         batch = llama_batch_get_one(&new_token_id, 1);
     }
 
-    printf("\033[0m");
+    printf(LOG_COL_DEFAULT);
     return 0;
 }
 
@@ -818,7 +971,7 @@ static int read_user_input(std::string & user_input) {
 #ifdef WIN32
     printf(
         "\r%*s"
-        "\r\033[0m%s",
+        "\r" LOG_COL_DEFAULT "%s",
         get_terminal_width(), " ", prompt_prefix);
 
     std::getline(std::cin, user_input);
@@ -855,7 +1008,7 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
                              const bool stdout_a_terminal) {
     // Set response color
     if (stdout_a_terminal) {
-        printf("\033[33m");
+        printf(LOG_COL_YELLOW);
     }
 
     if (generate(llama_data, prompt, response)) {
@@ -864,13 +1017,13 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
     }
 
     // End response with color reset and newline
-    printf("\n%s", stdout_a_terminal ? "\033[0m" : "");
+    printf("\n%s", stdout_a_terminal ? LOG_COL_DEFAULT : "");
     return 0;
 }
 
 // Helper function to apply the chat template and handle errors
-static int apply_chat_template_with_error_handling(LlamaData & llama_data, const bool append, int & output_length) {
-    const int new_len = apply_chat_template(llama_data, append);
+static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
+    const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
     if (new_len < 0) {
         printe("failed to apply the chat template\n");
         return -1;
@@ -929,9 +1082,11 @@ static int get_user_input(std::string & user_input, const std::string & user) {
 }
 
 // Main chat loop function
-static int chat_loop(LlamaData & llama_data, const std::string & user) {
+static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
+    auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), "");
+    GGML_ASSERT(chat_templates.template_default);
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
@@ -942,7 +1097,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {
 
         add_message("user", user.empty() ? user_input : user, llama_data);
         int new_len;
-        if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) {
+        if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) {
            return 1;
        }
 
@@ -957,7 +1112,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {
        }
 
        add_message("assistant", response, llama_data);
-        if (apply_chat_template_with_error_handling(llama_data, false, prev_len) < 0) {
+        if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) {
            return 1;
        }
    }
@@ -1017,7 +1172,7 @@ int main(int argc, const char ** argv) {
        return 1;
    }
 
-    if (chat_loop(llama_data, opt.user)) {
+    if (chat_loop(llama_data, opt.user, opt.use_jinja)) {
        return 1;
    }