@fugood/llama.node 0.3.13 → 0.3.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (139)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/lib/binding.ts +1 -1
  18. package/package.json +1 -1
  19. package/src/LlamaContext.cpp +98 -76
  20. package/src/LlamaContext.h +1 -1
  21. package/src/common.hpp +1 -2
  22. package/src/llama.cpp/.github/workflows/build.yml +60 -10
  23. package/src/llama.cpp/.github/workflows/server.yml +2 -0
  24. package/src/llama.cpp/common/CMakeLists.txt +3 -3
  25. package/src/llama.cpp/common/arg.cpp +112 -11
  26. package/src/llama.cpp/common/chat.cpp +960 -266
  27. package/src/llama.cpp/common/chat.h +135 -0
  28. package/src/llama.cpp/common/common.cpp +27 -171
  29. package/src/llama.cpp/common/common.h +27 -67
  30. package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
  31. package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
  32. package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +37 -5
  33. package/src/llama.cpp/common/ngram-cache.cpp +1 -0
  34. package/src/llama.cpp/common/sampling.cpp +45 -7
  35. package/src/llama.cpp/common/speculative.cpp +6 -5
  36. package/src/llama.cpp/common/speculative.h +1 -1
  37. package/src/llama.cpp/docs/build.md +45 -7
  38. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +3 -1
  39. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -0
  40. package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
  41. package/src/llama.cpp/examples/imatrix/imatrix.cpp +2 -3
  42. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +1 -1
  43. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  44. package/src/llama.cpp/examples/llava/clip.cpp +373 -107
  45. package/src/llama.cpp/examples/llava/clip.h +19 -3
  46. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
  47. package/src/llama.cpp/examples/llava/llava.cpp +4 -2
  48. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
  49. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -0
  50. package/src/llama.cpp/examples/main/main.cpp +73 -28
  51. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -0
  52. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -0
  53. package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
  54. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
  55. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
  56. package/src/llama.cpp/examples/run/run.cpp +110 -67
  57. package/src/llama.cpp/examples/server/server.cpp +82 -87
  58. package/src/llama.cpp/examples/server/utils.hpp +94 -107
  59. package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
  60. package/src/llama.cpp/examples/tts/tts.cpp +251 -142
  61. package/src/llama.cpp/ggml/CMakeLists.txt +13 -1
  62. package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
  63. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
  64. package/src/llama.cpp/ggml/include/ggml-cpu.h +3 -0
  65. package/src/llama.cpp/ggml/include/ggml.h +5 -1
  66. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
  67. package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
  68. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
  69. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
  70. package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
  71. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -2
  72. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
  73. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +132 -17
  74. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
  75. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
  76. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
  77. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +151 -0
  78. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1396 -386
  79. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1432 -151
  80. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +22 -0
  81. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
  82. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
  83. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
  84. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  85. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +15 -2
  86. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
  87. package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
  88. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
  89. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +235 -0
  90. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -2
  91. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
  92. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +220 -116
  93. package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
  94. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
  95. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -0
  96. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  97. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
  98. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +51 -10
  99. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
  100. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
  101. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
  102. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
  103. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
  104. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +136 -4
  105. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +308 -0
  106. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
  107. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +168 -721
  108. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -77
  109. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -0
  110. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  111. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  112. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +146 -42
  113. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +13 -3
  114. package/src/llama.cpp/ggml/src/ggml.c +8 -3
  115. package/src/llama.cpp/include/llama.h +19 -5
  116. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
  117. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
  118. package/src/llama.cpp/requirements/requirements-all.txt +1 -0
  119. package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
  120. package/src/llama.cpp/requirements.txt +1 -0
  121. package/src/llama.cpp/src/llama-arch.cpp +21 -0
  122. package/src/llama.cpp/src/llama-arch.h +1 -0
  123. package/src/llama.cpp/src/llama-chat.cpp +1 -0
  124. package/src/llama.cpp/src/llama-grammar.cpp +182 -182
  125. package/src/llama.cpp/src/llama-grammar.h +12 -3
  126. package/src/llama.cpp/src/llama-kv-cache.h +1 -0
  127. package/src/llama.cpp/src/llama-mmap.cpp +11 -1
  128. package/src/llama.cpp/src/llama-model.cpp +69 -5
  129. package/src/llama.cpp/src/llama-sampling.cpp +43 -10
  130. package/src/llama.cpp/src/llama-vocab.cpp +12 -0
  131. package/src/llama.cpp/src/llama.cpp +147 -0
  132. package/src/llama.cpp/tests/test-backend-ops.cpp +166 -110
  133. package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
  134. package/src/llama.cpp/tests/test-chat.cpp +593 -395
  135. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
  136. package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
  137. package/src/llama.cpp/Sources/llama/llama.h +0 -4
  138. package/src/llama.cpp/common/chat.hpp +0 -55
  139. /package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +0 -0
package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h
@@ -47,27 +47,27 @@ extern "C" {
  #include <stddef.h> /* For size_t. */
  #include <stdlib.h>
 
- extern const char *linenoiseEditMore;
+ extern const char * linenoiseEditMore;
 
  /* The linenoiseState structure represents the state during line editing.
  * We pass this state to functions implementing specific editing
  * functionalities. */
  struct linenoiseState {
- int in_completion; /* The user pressed TAB and we are now in completion
+ int in_completion; /* The user pressed TAB and we are now in completion
  * mode, so input is handled by completeLine(). */
- size_t completion_idx; /* Index of next completion to propose. */
- int ifd; /* Terminal stdin file descriptor. */
- int ofd; /* Terminal stdout file descriptor. */
- char *buf; /* Edited line buffer. */
- size_t buflen; /* Edited line buffer size. */
- const char *prompt; /* Prompt to display. */
- size_t plen; /* Prompt length. */
- size_t pos; /* Current cursor position. */
- size_t oldpos; /* Previous refresh cursor position. */
- size_t len; /* Current edited line length. */
- size_t cols; /* Number of columns in terminal. */
- size_t oldrows; /* Rows used by last refrehsed line (multiline mode) */
- int history_index; /* The history index we are currently editing. */
+ size_t completion_idx; /* Index of next completion to propose. */
+ int ifd; /* Terminal stdin file descriptor. */
+ int ofd; /* Terminal stdout file descriptor. */
+ char * buf; /* Edited line buffer. */
+ size_t buflen; /* Edited line buffer size. */
+ const char * prompt; /* Prompt to display. */
+ size_t plen; /* Prompt length. */
+ size_t pos; /* Current cursor position. */
+ size_t oldcolpos; /* Previous refresh cursor column position. */
+ size_t len; /* Current edited line length. */
+ size_t cols; /* Number of columns in terminal. */
+ size_t oldrows; /* Rows used by last refreshed line (multiline mode) */
+ int history_index; /* The history index we are currently editing. */
  };
 
  struct linenoiseCompletions {
@@ -89,19 +89,20 @@ struct linenoiseCompletions {
  };
 
  /* Non blocking API. */
- int linenoiseEditStart(struct linenoiseState *l, int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt);
- const char *linenoiseEditFeed(struct linenoiseState *l);
- void linenoiseEditStop(struct linenoiseState *l);
- void linenoiseHide(struct linenoiseState *l);
- void linenoiseShow(struct linenoiseState *l);
+ int linenoiseEditStart(struct linenoiseState * l, int stdin_fd, int stdout_fd, char * buf, size_t buflen,
+ const char * prompt);
+ const char * linenoiseEditFeed(struct linenoiseState * l);
+ void linenoiseEditStop(struct linenoiseState * l);
+ void linenoiseHide(struct linenoiseState * l);
+ void linenoiseShow(struct linenoiseState * l);
 
  /* Blocking API. */
- const char *linenoise(const char *prompt);
- void linenoiseFree(void *ptr);
+ const char * linenoise(const char * prompt);
+ void linenoiseFree(void * ptr);
 
  /* Completion API. */
  typedef void(linenoiseCompletionCallback)(const char *, linenoiseCompletions *);
- typedef const char*(linenoiseHintsCallback)(const char *, int *color, int *bold);
+ typedef const char *(linenoiseHintsCallback) (const char *, int * color, int * bold);
  typedef void(linenoiseFreeHintsCallback)(const char *);
  void linenoiseSetCompletionCallback(linenoiseCompletionCallback *);
  void linenoiseSetHintsCallback(linenoiseHintsCallback *);
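For context on how the non-blocking API above is driven: an edit is started once, then fed whenever the input descriptor is readable, until linenoiseEditFeed returns something other than the linenoiseEditMore sentinel (the convention in upstream linenoise's asynchronous example). A minimal sketch under those assumptions; the select() loop, -1 default descriptors, buffer size, and prompt text are illustrative and not part of this diff:

    #include <stdio.h>
    #include <sys/select.h>
    #include "linenoise.cpp/linenoise.h"   /* in-tree include path used by run.cpp */

    int main(void) {
        struct linenoiseState ls;
        char buf[1024];
        /* -1 fds: fall back to the default stdin/stdout descriptors */
        linenoiseEditStart(&ls, -1, -1, buf, sizeof(buf), "prompt> ");
        for (;;) {
            fd_set readfds;
            FD_ZERO(&readfds);
            FD_SET(ls.ifd, &readfds);
            if (select(ls.ifd + 1, &readfds, NULL, NULL, NULL) < 0) {
                break; /* interrupted or failed; stop editing */
            }
            const char * line = linenoiseEditFeed(&ls);
            if (line == linenoiseEditMore) {
                continue; /* keep feeding bytes until a full line is available */
            }
            linenoiseEditStop(&ls);
            if (line != NULL) {
                printf("you typed: %s\n", line);
                linenoiseHistoryAdd(line);
                linenoiseFree((void *) line); /* release the returned line */
            }
            break;
        }
        return 0;
    }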
@@ -109,10 +110,10 @@ void linenoiseSetFreeHintsCallback(linenoiseFreeHintsCallback *);
  void linenoiseAddCompletion(linenoiseCompletions *, const char *);
 
  /* History API. */
- int linenoiseHistoryAdd(const char *line);
+ int linenoiseHistoryAdd(const char * line);
  int linenoiseHistorySetMaxLen(int len);
- int linenoiseHistorySave(const char *filename);
- int linenoiseHistoryLoad(const char *filename);
+ int linenoiseHistorySave(const char * filename);
+ int linenoiseHistoryLoad(const char * filename);
 
  /* Other utilities. */
  void linenoiseClearScreen(void);
@@ -121,6 +122,14 @@ void linenoisePrintKeyCodes(void);
  void linenoiseMaskModeEnable(void);
  void linenoiseMaskModeDisable(void);
 
+ /* Encoding functions. */
+ typedef size_t(linenoisePrevCharLen)(const char * buf, size_t buf_len, size_t pos, size_t * col_len);
+ typedef size_t(linenoiseNextCharLen)(const char * buf, size_t buf_len, size_t pos, size_t * col_len);
+ typedef size_t(linenoiseReadCode)(int fd, char * buf, size_t buf_len, int * c);
+
+ void linenoiseSetEncodingFunctions(linenoisePrevCharLen * prevCharLenFunc, linenoiseNextCharLen * nextCharLenFunc,
+ linenoiseReadCode * readCodeFunc);
+
  #ifdef __cplusplus
  }
  #endif
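These new hooks let the host application override how character lengths, column widths, and key codes are decoded, which is the extension point multi-byte input (e.g. UTF-8) plugs into. A minimal single-byte sketch of the three callback shapes, registered through linenoiseSetEncodingFunctions; the function names and the return-0-on-failure convention are assumptions for illustration, not part of this diff:

    #include <unistd.h>
    #include "linenoise.cpp/linenoise.h"

    /* Treat every byte as one character that occupies one terminal column. */
    static size_t prev_char_len_ascii(const char * buf, size_t buf_len, size_t pos, size_t * col_len) {
        (void) buf; (void) buf_len; (void) pos;
        if (col_len) { *col_len = 1; }
        return 1;
    }

    static size_t next_char_len_ascii(const char * buf, size_t buf_len, size_t pos, size_t * col_len) {
        (void) buf; (void) buf_len; (void) pos;
        if (col_len) { *col_len = 1; }
        return 1;
    }

    /* Read one byte from the terminal and report it as the key code. */
    static size_t read_code_ascii(int fd, char * buf, size_t buf_len, int * c) {
        if (buf_len == 0) { return 0; }
        ssize_t n = read(fd, buf, 1);
        if (n <= 0) { return 0; }
        *c = (unsigned char) buf[0];
        return (size_t) n;
    }

    static void use_single_byte_encoding(void) {
        linenoiseSetEncodingFunctions(prev_char_len_ascii, next_char_len_ascii, read_code_ascii);
    }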
package/src/llama.cpp/examples/run/run.cpp
@@ -24,7 +24,7 @@
  #include <string>
  #include <vector>
 
- #include "chat-template.hpp"
+ #include "chat.h"
  #include "common.h"
  #include "json.hpp"
  #include "linenoise.cpp/linenoise.h"
@@ -113,6 +113,7 @@ class Opt {
  llama_context_params ctx_params;
  llama_model_params model_params;
  std::string model_;
+ std::string chat_template_file;
  std::string user;
  bool use_jinja = false;
  int context_size = -1, ngl = -1;
@@ -148,6 +149,16 @@
  return 0;
  }
 
+ int handle_option_with_value(int argc, const char ** argv, int & i, std::string & option_value) {
+ if (i + 1 >= argc) {
+ return 1;
+ }
+
+ option_value = argv[++i];
+
+ return 0;
+ }
+
  int parse(int argc, const char ** argv) {
  bool options_parsing = true;
  for (int i = 1, positional_args_i = 0; i < argc; ++i) {
@@ -169,6 +180,11 @@
  verbose = true;
  } else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
  use_jinja = true;
+ } else if (options_parsing && strcmp(argv[i], "--chat-template-file") == 0){
+ if (handle_option_with_value(argc, argv, i, chat_template_file) == 1) {
+ return 1;
+ }
+ use_jinja = true;
  } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
  help = true;
  return 0;
@@ -207,6 +223,11 @@
  "Options:\n"
  " -c, --context-size <value>\n"
  " Context size (default: %d)\n"
+ " --chat-template-file <path>\n"
+ " Path to the file containing the chat template to use with the model.\n"
+ " Only supports jinja templates and implicitly sets the --jinja flag.\n"
+ " --jinja\n"
+ " Use jinja templating for the chat template of the model\n"
  " -n, -ngl, --ngl <value>\n"
  " Number of GPU layers (default: %d)\n"
  " --temp <value>\n"
@@ -261,13 +282,12 @@ static int get_terminal_width() {
  #endif
  }
 
- #ifdef LLAMA_USE_CURL
  class File {
  public:
  FILE * file = nullptr;
 
  FILE * open(const std::string & filename, const char * mode) {
- file = fopen(filename.c_str(), mode);
+ file = ggml_fopen(filename.c_str(), mode);
 
  return file;
  }
@@ -303,6 +323,20 @@ class File {
  return 0;
  }
 
+ std::string to_string() {
+ fseek(file, 0, SEEK_END);
+ const size_t size = ftell(file);
+ fseek(file, 0, SEEK_SET);
+ std::string out;
+ out.resize(size);
+ const size_t read_size = fread(&out[0], 1, size, file);
+ if (read_size != size) {
+ printe("Error reading file: %s", strerror(errno));
+ }
+
+ return out;
+ }
+
  ~File() {
  if (fd >= 0) {
  # ifdef _WIN32
@@ -327,6 +361,7 @@ class File {
  # endif
  };
 
+ #ifdef LLAMA_USE_CURL
  class HttpClient {
  public:
  int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
@@ -557,7 +592,7 @@ class LlamaData {
  llama_model_ptr model;
  llama_sampler_ptr sampler;
  llama_context_ptr context;
- std::vector<llama_chat_message> messages;
+ std::vector<llama_chat_message> messages; // TODO: switch to common_chat_msg
  std::list<std::string> msg_strs;
  std::vector<char> fmtted;
 
@@ -834,44 +869,23 @@ static void add_message(const char * role, const std::string & text, LlamaData &
  }
 
  // Function to apply the chat template and resize `formatted` if needed
- static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
- if (use_jinja) {
- json messages = json::array();
- for (const auto & msg : llama_data.messages) {
- messages.push_back({
- {"role", msg.role},
- {"content", msg.content},
- });
- }
- try {
- minja::chat_template_inputs tmpl_inputs;
- tmpl_inputs.messages = messages;
- tmpl_inputs.add_generation_prompt = append;
-
- minja::chat_template_options tmpl_opts;
- tmpl_opts.use_bos_token = false;
- tmpl_opts.use_eos_token = false;
-
- auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
- llama_data.fmtted.resize(result.size() + 1);
- memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
- return result.size();
- } catch (const std::exception & e) {
- printe("failed to render the chat template: %s\n", e.what());
- return -1;
- }
- }
- int result = llama_chat_apply_template(
- tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
- append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
- if (append && result > static_cast<int>(llama_data.fmtted.size())) {
- llama_data.fmtted.resize(result);
- result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
- llama_data.messages.size(), append, llama_data.fmtted.data(),
- llama_data.fmtted.size());
- }
-
- return result;
+ static int apply_chat_template(const struct common_chat_templates * tmpls, LlamaData & llama_data, const bool append, bool use_jinja) {
+ common_chat_templates_inputs inputs;
+ for (const auto & msg : llama_data.messages) {
+ common_chat_msg cmsg;
+ cmsg.role = msg.role;
+ cmsg.content = msg.content;
+ inputs.messages.push_back(cmsg);
+ }
+ inputs.add_generation_prompt = append;
+ inputs.use_jinja = use_jinja;
+
+ auto chat_params = common_chat_templates_apply(tmpls, inputs);
+ // TODO: use other params for tool calls.
+ auto result = chat_params.prompt;
+ llama_data.fmtted.resize(result.size() + 1);
+ memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
+ return result.size();
  }
 
  // Function to tokenize the prompt
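For readers tracking the API change above: the new flow initializes the templates once from the model, fills a common_chat_templates_inputs with common_chat_msg entries, and takes the rendered prompt from the result of common_chat_templates_apply. A minimal standalone sketch of that flow, assuming the constness and two-argument init shown in this diff; the helper name, messages, and omitted error handling are illustrative:

    #include <string>

    #include "chat.h"
    #include "llama.h"

    // Render a two-message conversation with the model's built-in chat template.
    static std::string render_chat_prompt(const llama_model * model, bool use_jinja) {
        // Empty override -> use the chat template embedded in the GGUF.
        common_chat_templates_ptr tmpls = common_chat_templates_init(model, "");

        common_chat_templates_inputs inputs;

        common_chat_msg system_msg;
        system_msg.role    = "system";
        system_msg.content = "You are a helpful assistant.";
        inputs.messages.push_back(system_msg);

        common_chat_msg user_msg;
        user_msg.role    = "user";
        user_msg.content = "Hello!";
        inputs.messages.push_back(user_msg);

        inputs.add_generation_prompt = true;   // append the assistant turn header
        inputs.use_jinja             = use_jinja;

        auto chat_params = common_chat_templates_apply(tmpls.get(), inputs);
        return chat_params.prompt;             // rendered text, ready to tokenize
    }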
@@ -963,7 +977,8 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
  }
 
  static int read_user_input(std::string & user_input) {
- static const char * prompt_prefix = "> ";
+ static const char * prompt_prefix_env = std::getenv("LLAMA_PROMPT_PREFIX");
+ static const char * prompt_prefix = prompt_prefix_env ? prompt_prefix_env : "> ";
  #ifdef WIN32
  printf("\r" LOG_CLR_TO_EOL LOG_COL_DEFAULT "%s", prompt_prefix);
 
@@ -1015,8 +1030,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
  }
 
  // Helper function to apply the chat template and handle errors
- static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
- const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
+ static int apply_chat_template_with_error_handling(const common_chat_templates * tmpls, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
+ const int new_len = apply_chat_template(tmpls, llama_data, append, use_jinja);
  if (new_len < 0) {
  printe("failed to apply the chat template\n");
  return -1;
@@ -1074,40 +1089,68 @@ static int get_user_input(std::string & user_input, const std::string & user) {
  return 0;
  }
 
+ // Reads a chat template file to be used
+ static std::string read_chat_template_file(const std::string & chat_template_file) {
+ File file;
+ if (!file.open(chat_template_file, "r")) {
+ printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
+ return "";
+ }
+
+ return file.to_string();
+ }
+
+ static int process_user_message(const Opt & opt, const std::string & user_input, LlamaData & llama_data,
+ const common_chat_templates_ptr & chat_templates, int & prev_len,
+ const bool stdout_a_terminal) {
+ add_message("user", opt.user.empty() ? user_input : opt.user, llama_data);
+ int new_len;
+ if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, opt.use_jinja) < 0) {
+ return 1;
+ }
+
+ std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
+ std::string response;
+ if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
+ return 1;
+ }
+
+ if (!opt.user.empty()) {
+ return 2;
+ }
+
+ add_message("assistant", response, llama_data);
+ if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, opt.use_jinja) < 0) {
+ return 1;
+ }
+
+ return 0;
+ }
+
  // Main chat loop function
- static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
+ static int chat_loop(LlamaData & llama_data, const Opt & opt) {
  int prev_len = 0;
  llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
- auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), "");
- GGML_ASSERT(chat_templates.template_default);
+ std::string chat_template;
+ if (!opt.chat_template_file.empty()) {
+ chat_template = read_chat_template_file(opt.chat_template_file);
+ }
+
+ common_chat_templates_ptr chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template);
  static const bool stdout_a_terminal = is_stdout_a_terminal();
  while (true) {
  // Get user input
  std::string user_input;
- if (get_user_input(user_input, user) == 1) {
+ if (get_user_input(user_input, opt.user) == 1) {
  return 0;
  }
 
- add_message("user", user.empty() ? user_input : user, llama_data);
- int new_len;
- if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) {
- return 1;
- }
-
- std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
- std::string response;
- if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
+ const int ret = process_user_message(opt, user_input, llama_data, chat_templates, prev_len, stdout_a_terminal);
+ if (ret == 1) {
  return 1;
- }
-
- if (!user.empty()) {
+ } else if (ret == 2) {
  break;
  }
-
- add_message("assistant", response, llama_data);
- if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) {
- return 1;
- }
  }
 
  return 0;
@@ -1165,7 +1208,7 @@ int main(int argc, const char ** argv) {
  return 1;
  }
 
- if (chat_loop(llama_data, opt.user, opt.use_jinja)) {
+ if (chat_loop(llama_data, opt)) {
  return 1;
  }