@fugood/llama.node 0.3.13 → 0.3.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (139)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/lib/binding.ts +1 -1
  18. package/package.json +1 -1
  19. package/src/LlamaContext.cpp +98 -76
  20. package/src/LlamaContext.h +1 -1
  21. package/src/common.hpp +1 -2
  22. package/src/llama.cpp/.github/workflows/build.yml +60 -10
  23. package/src/llama.cpp/.github/workflows/server.yml +2 -0
  24. package/src/llama.cpp/common/CMakeLists.txt +3 -3
  25. package/src/llama.cpp/common/arg.cpp +112 -11
  26. package/src/llama.cpp/common/chat.cpp +960 -266
  27. package/src/llama.cpp/common/chat.h +135 -0
  28. package/src/llama.cpp/common/common.cpp +27 -171
  29. package/src/llama.cpp/common/common.h +27 -67
  30. package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
  31. package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
  32. package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +37 -5
  33. package/src/llama.cpp/common/ngram-cache.cpp +1 -0
  34. package/src/llama.cpp/common/sampling.cpp +45 -7
  35. package/src/llama.cpp/common/speculative.cpp +6 -5
  36. package/src/llama.cpp/common/speculative.h +1 -1
  37. package/src/llama.cpp/docs/build.md +45 -7
  38. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +3 -1
  39. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -0
  40. package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
  41. package/src/llama.cpp/examples/imatrix/imatrix.cpp +2 -3
  42. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +1 -1
  43. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  44. package/src/llama.cpp/examples/llava/clip.cpp +373 -107
  45. package/src/llama.cpp/examples/llava/clip.h +19 -3
  46. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
  47. package/src/llama.cpp/examples/llava/llava.cpp +4 -2
  48. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
  49. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -0
  50. package/src/llama.cpp/examples/main/main.cpp +73 -28
  51. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -0
  52. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -0
  53. package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
  54. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
  55. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
  56. package/src/llama.cpp/examples/run/run.cpp +110 -67
  57. package/src/llama.cpp/examples/server/server.cpp +82 -87
  58. package/src/llama.cpp/examples/server/utils.hpp +94 -107
  59. package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
  60. package/src/llama.cpp/examples/tts/tts.cpp +251 -142
  61. package/src/llama.cpp/ggml/CMakeLists.txt +13 -1
  62. package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
  63. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
  64. package/src/llama.cpp/ggml/include/ggml-cpu.h +3 -0
  65. package/src/llama.cpp/ggml/include/ggml.h +5 -1
  66. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
  67. package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
  68. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
  69. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
  70. package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
  71. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -2
  72. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
  73. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +132 -17
  74. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
  75. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
  76. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
  77. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +151 -0
  78. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1396 -386
  79. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1432 -151
  80. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +22 -0
  81. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
  82. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
  83. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
  84. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  85. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +15 -2
  86. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
  87. package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
  88. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
  89. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +235 -0
  90. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -2
  91. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
  92. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +220 -116
  93. package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
  94. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
  95. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -0
  96. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  97. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
  98. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +51 -10
  99. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
  100. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
  101. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
  102. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
  103. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
  104. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +136 -4
  105. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +308 -0
  106. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
  107. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +168 -721
  108. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -77
  109. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -0
  110. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  111. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  112. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +146 -42
  113. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +13 -3
  114. package/src/llama.cpp/ggml/src/ggml.c +8 -3
  115. package/src/llama.cpp/include/llama.h +19 -5
  116. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
  117. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
  118. package/src/llama.cpp/requirements/requirements-all.txt +1 -0
  119. package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
  120. package/src/llama.cpp/requirements.txt +1 -0
  121. package/src/llama.cpp/src/llama-arch.cpp +21 -0
  122. package/src/llama.cpp/src/llama-arch.h +1 -0
  123. package/src/llama.cpp/src/llama-chat.cpp +1 -0
  124. package/src/llama.cpp/src/llama-grammar.cpp +182 -182
  125. package/src/llama.cpp/src/llama-grammar.h +12 -3
  126. package/src/llama.cpp/src/llama-kv-cache.h +1 -0
  127. package/src/llama.cpp/src/llama-mmap.cpp +11 -1
  128. package/src/llama.cpp/src/llama-model.cpp +69 -5
  129. package/src/llama.cpp/src/llama-sampling.cpp +43 -10
  130. package/src/llama.cpp/src/llama-vocab.cpp +12 -0
  131. package/src/llama.cpp/src/llama.cpp +147 -0
  132. package/src/llama.cpp/tests/test-backend-ops.cpp +166 -110
  133. package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
  134. package/src/llama.cpp/tests/test-chat.cpp +593 -395
  135. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
  136. package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
  137. package/src/llama.cpp/Sources/llama/llama.h +0 -4
  138. package/src/llama.cpp/common/chat.hpp +0 -55
  139. package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +0 -0
package/src/llama.cpp/tests/test-chat.cpp
@@ -10,38 +10,12 @@
 #include <json.hpp>
 #include <string>
 
-#include "chat-template.hpp"
-#include "chat.hpp"
+#include "chat.h"
 #include "llama-grammar.h"
 #include "unicode.h"
 
 using json = nlohmann::ordered_json;
 
-static common_chat_msg msg_from_json(const json & message) {
-    common_chat_msg ret;
-    ret.role = "assistant";
-    if (message.contains("content") && !message.at("content").is_null()) {
-        ret.content = message.at("content");
-    }
-    if (message.contains("tool_plan")) {
-        ret.reasoning_content = message.at("tool_plan");
-    }
-    if (message.contains("reasoning_content")) {
-        ret.reasoning_content = message.at("reasoning_content");
-    }
-    auto has_tool_calls = message.contains("tool_calls");
-    if (has_tool_calls) {
-        for (const auto & tc : message.at("tool_calls")) {
-            const auto & arguments = tc.at("function").at("arguments");
-            ret.tool_calls.push_back({
-                tc.at("function").at("name").get<std::string>(),
-                arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
-                tc.contains("id") ? tc.at("id").get<std::string>() : "",
-            });
-        }
-    }
-    return ret;
-}
 
 template <class T> static void assert_equals(const T & expected, const T & actual) {
     if (expected != actual) {
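The removed msg_from_json helper converted nlohmann JSON objects into common_chat_msg; after this refactor the tests construct the typed struct directly. A minimal sketch of the replacement pattern (not code from the package; it is assembled from the field names and initializer shapes that appear later in this diff):

    // Sketch: direct construction of the typed message, replacing msg_from_json.
    #include "chat.h"

    static common_chat_msg make_assist_call_msg() {
        common_chat_msg msg;
        msg.role              = "assistant";
        msg.reasoning_content = "I'm thinking";
        // tool call fields, in order: name, serialized JSON arguments, id
        msg.tool_calls.push_back({ "special_function", "{\"arg1\": 1}", /* .id = */ "" });
        return msg;
    }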
@@ -53,7 +27,7 @@ template <class T> static void assert_equals(const T & expected, const T & actua
 }
 
 static std::string read_file(const std::string & path) {
-    std::cerr << "# Reading: " << path << std::endl << std::flush;
+    std::cerr << "# Reading: " << path << '\n' << std::flush;
     std::ifstream fs(path, std::ios_base::binary);
     if (!fs.is_open()) {
         fs = std::ifstream("../" + path, std::ios_base::binary);
@@ -66,10 +40,14 @@ static std::string read_file(const std::string & path) {
     fs.seekg(0);
     std::string out;
     out.resize(static_cast<size_t>(size));
-    fs.read(&out[0], static_cast<std::streamsize>(size));
+    fs.read(out.data(), static_cast<std::streamsize>(size));
     return out;
 }
 
+static common_chat_templates_ptr read_templates(const std::string & path) {
+    return common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, read_file(path)));
+}
+
 static std::unique_ptr<llama_grammar> build_grammar(const std::string & grammar_str) {
     return std::unique_ptr<llama_grammar>(
         llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0));
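The new read_templates helper wraps common_chat_templates_init in the RAII common_chat_templates_ptr; call sites then pass the raw handle via .get(). A usage sketch assembled only from calls that appear elsewhere in this diff (the template path is one of the repo's test fixtures):

    // Usage sketch; every call here appears elsewhere in this diff.
    auto tmpls = read_templates("models/templates/google-gemma-2-2b-it.jinja");

    common_chat_msg user;
    user.role    = "user";
    user.content = "Hello, world!";

    common_chat_templates_inputs inputs;
    inputs.messages = { user };

    // Renders the prompt and reports the detected chat format.
    auto params = common_chat_templates_apply(tmpls.get(), inputs);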
@@ -90,110 +68,102 @@ static bool match_string(const std::string & input, llama_grammar * grammar) {
         }
     }
 
-    for (const auto & stack : stacks_cur) {
-        if (stack.empty()) {
-            // An empty stack means that the grammar has been completed
-            return true;
-        }
+    if (std::any_of(stacks_cur.begin(), stacks_cur.end(), [](const auto & stack) { return stack.empty(); })) {
+        // An empty stack means that the grammar has been completed
+        return true;
     }
 
     return false;
 }
 
-// Dumps `{"a": 1}` as `"{\"a\": 1}"`, unlike nlohmann::json::dump which would dump it as `"{\"a\":1}"`.
-static std::string dump(const json & j) {
-    return minja::Value(j).dump(-1, /* to_json= */ true);
-}
-
 static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
     assert_equals(expected.role, actual.role);
     assert_equals(expected.content, actual.content);
+    assert_equals(expected.content_parts.size(), actual.content_parts.size());
+    for (size_t i = 0; i < expected.content_parts.size(); i++) {
+        const auto & expected_part = expected.content_parts[i];
+        const auto & actual_part   = actual.content_parts[i];
+        assert_equals(expected_part.type, actual_part.type);
+        assert_equals(expected_part.text, actual_part.text);
+    }
     assert_equals(expected.reasoning_content, actual.reasoning_content);
     assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
     for (size_t i = 0; i < expected.tool_calls.size(); i++) {
         const auto & expected_tool_call = expected.tool_calls[i];
         const auto & actual_tool_call   = actual.tool_calls[i];
         assert_equals(expected_tool_call.name, actual_tool_call.name);
-        assert_equals(dump(json::parse(expected_tool_call.arguments)), dump(json::parse(actual_tool_call.arguments)));
+        assert_equals(json::parse(expected_tool_call.arguments).dump(), json::parse(actual_tool_call.arguments).dump());
         assert_equals(expected_tool_call.id, actual_tool_call.id);
     }
 }
 
-const auto special_function_tool = json::parse(R"({
-  "type": "function",
-  "function": {
-    "name": "special_function",
-    "description": "I'm special",
-    "parameters": {
-      "type": "object",
-      "properties": {
-        "arg1": {
-          "type": "integer",
-          "description": "The arg."
-        }
-      },
-      "required": ["arg1"]
-    }
-  }
-})");
-const auto python_tool = json::parse(R"({
-  "type": "function",
-  "function": {
-    "name": "python",
-    "description": "an ipython interpreter",
-    "parameters": {
-      "type": "object",
-      "properties": {
-        "code": {
-          "type": "string",
-          "description": "Python code to execute."
-        }
-      },
-      "required": ["code"]
-    }
-  }
-})");
-const auto code_interpreter_tool = json::parse(R"({
-  "type": "function",
-  "function": {
-    "name": "code_interpreter",
-    "description": "an ipython interpreter",
-    "parameters": {
-      "type": "object",
-      "properties": {
-        "code": {
-          "type": "string",
-          "description": "Python code to execute."
-        }
-      },
-      "required": ["code"]
-    }
-  }
-})");
-const json tools           = { special_function_tool, python_tool };
-const json llama_3_1_tools = { special_function_tool, code_interpreter_tool };
+common_chat_tool special_function_tool {
+    /* .name = */ "special_function",
+    /* .description = */ "I'm special",
+    /* .parameters = */ R"({
+        "type": "object",
+        "properties": {
+            "arg1": {
+                "type": "integer",
+                "description": "The arg."
+            }
+        },
+        "required": ["arg1"]
+    })",
+};
+common_chat_tool python_tool {
+    /* .name = */ "python",
+    /* .description = */ "an ipython interpreter",
+    /* .parameters = */ R"({
+        "type": "object",
+        "properties": {
+            "code": {
+                "type": "string",
+                "description": "Python code to execute."
+            }
+        },
+        "required": ["code"]
+    })",
+};
+common_chat_tool code_interpreter_tool {
+    /* .name = */ "code_interpreter",
+    /* .description = */ "an ipython interpreter",
+    /* .parameters = */ R"({
+        "type": "object",
+        "properties": {
+            "code": {
+                "type": "string",
+                "description": "Python code to execute."
+            }
+        },
+        "required": ["code"]
+    })",
+};
+std::vector<common_chat_tool> tools           { special_function_tool, python_tool };
+std::vector<common_chat_tool> llama_3_1_tools { special_function_tool, code_interpreter_tool };
 
 struct delta_data {
     std::string        delta;
     common_chat_params params;
 };
 
-static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
-                             const json & user_message, const json & delta_message, const json & tools,
-                             const json & tool_choice,
+static delta_data init_delta(const struct common_chat_templates * tmpls, const std::vector<std::string> & end_tokens,
+                             const common_chat_msg & user_message,
+                             const common_chat_msg & delta_message,
+                             const std::vector<common_chat_tool> & tools,
+                             const common_chat_tool_choice & tool_choice,
                              bool think = false) {
-    common_chat_inputs inputs;
+    common_chat_templates_inputs inputs;
     inputs.parallel_tool_calls = true;
-    inputs.messages = json::array();
     inputs.messages.push_back(user_message);
     inputs.tools = tools;
     inputs.tool_choice = tool_choice;
     inputs.extract_reasoning = think;
-    auto params_prefix = common_chat_params_init(tmpl, inputs);
+    auto params_prefix = common_chat_templates_apply(tmpls, inputs);
 
     inputs.messages.push_back(delta_message);
     inputs.add_generation_prompt = false;
-    auto params_full = common_chat_params_init(tmpl, inputs);
+    auto params_full = common_chat_templates_apply(tmpls, inputs);
 
     std::string prefix = params_prefix.prompt;
     std::string full   = params_full.prompt;
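init_delta renders the conversation twice, once without and once with the delta message, and then strips the shared prefix so that only the text the template emitted for the new message remains. The subtraction itself falls outside this hunk; the following is a hypothetical sketch of the idea (delta_of is not a function in the package):

    // Hypothetical sketch of the prefix-stripping idea behind init_delta.
    static std::string delta_of(const std::string & prefix, const std::string & full) {
        // The full prompt is expected to start with the prefix prompt; the
        // remainder is what the template emitted for the delta message.
        if (full.compare(0, prefix.size(), prefix) == 0) {
            return full.substr(prefix.size());
        }
        return full;  // fall back to the full prompt if the assumption breaks
    }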
@@ -234,30 +204,29 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
   gets the diff, removes any end tokens and parses the result w/ the grammar, checking that
   the parsed message is the same as the test_message
 */
-static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
-                          const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
+static void test_templates(const struct common_chat_templates * tmpls, const std::vector<std::string> & end_tokens,
+                           const common_chat_msg & test_message,
+                           const std::vector<common_chat_tool> & tools = {},
+                           const std::string & expected_delta = "",
                            bool expect_grammar_triggered = true,
                            bool test_grammar_if_triggered = true,
                            bool think = false) {
-    common_chat_msg expected_msg = msg_from_json(test_message);
-
-    auto user_message = json{
-        { "role",    "user" },
-        { "content", "Hello, world!" }
-    };
+    common_chat_msg user_message;
+    user_message.role = "user";
+    user_message.content = "Hello, world!";
 
-    for (const auto & tool_choice : json({ "auto", "required" })) {
-        auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice, think);
+    for (const auto & tool_choice : std::vector<common_chat_tool_choice> {COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED}) {
+        auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice, think);
         if (!expected_delta.empty()) {
             assert_equals(expected_delta, data.delta);
         }
 
         if (expect_grammar_triggered) {
             const auto msg = common_chat_parse(data.delta, data.params.format);
-            assert_msg_equals(expected_msg, msg);
+            assert_msg_equals(test_message, msg);
         }
 
-        if (!expected_msg.tool_calls.empty()) {
+        if (!test_message.tool_calls.empty()) {
             GGML_ASSERT(!data.params.grammar.empty());
         }
         if (!data.params.grammar.empty()) {
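The comment at the top of this hunk describes the round trip that test_templates performs: render the delta for a test message, parse it back, and compare against the input. Condensed to its core, using only calls visible in this diff:

    // Core of the round trip performed by test_templates (calls as in this diff).
    auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools,
                           COMMON_CHAT_TOOL_CHOICE_AUTO);
    const auto msg = common_chat_parse(data.delta, data.params.format);
    assert_msg_equals(test_message, msg);  // parsed message must match the input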
@@ -268,12 +237,35 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
268
237
  auto earliest_trigger_pos = std::string::npos;
269
238
  auto constrained = data.delta;
270
239
  for (const auto & trigger : data.params.grammar_triggers) {
271
- auto pos = constrained.find(trigger.word);
272
- if (pos == std::string::npos) {
273
- continue;
240
+ size_t pos = std::string::npos;
241
+ std::smatch match;
242
+ switch (trigger.type) {
243
+ case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
244
+ {
245
+ const auto & word = trigger.value;
246
+ pos = constrained.find(word);
247
+ break;
248
+ }
249
+ case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
250
+ {
251
+ const auto & pattern = trigger.value;
252
+ if (std::regex_search(constrained, match, std::regex(pattern))) {
253
+ pos = match.position();
254
+ }
255
+ break;
256
+ }
257
+ case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
258
+ {
259
+ const auto & pattern = trigger.value;
260
+ if (std::regex_search(constrained, match, std::regex(pattern)) && match.position() == 0) {
261
+ pos = 0;
262
+ }
263
+ break;
264
+ }
265
+ default:
266
+ throw std::runtime_error("Unknown trigger type");
274
267
  }
275
- if (pos > 0 && trigger.at_start) {
276
- fprintf(stderr, "Trigger %s not at start of message, skipping:\n\n%s\n\n", trigger.word.c_str(), constrained.c_str());
268
+ if (pos == std::string::npos) {
277
269
  continue;
278
270
  }
279
271
  if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
@@ -291,252 +283,361 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
291
283
 
292
284
  if (grammar_triggered && test_grammar_if_triggered && !match_string(constrained, grammar.get())) {
293
285
  throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
294
- "\n\nGrammar: " + data.params.grammar);
286
+ "\n\nConstrained: " + constrained +
287
+ "\n\nGrammar: " + data.params.grammar);
295
288
  }
296
289
  }
297
290
  }
298
291
  }
299
292
 
300
- static void test_template_output_parsers() {
301
- json message_user {
302
- { "role", "user" },
303
- { "content", "Hey there!" },
304
- };
305
- json message_assist {
306
- { "role", "assistant" },
307
- { "content", "Hello, world!\nWhat's up?" },
308
- };
309
- json message_assist_thoughts_unparsed_think {
310
- { "role", "assistant" },
311
- { "content", "<think>I'm thinking</think>Hello, world!\nWhat's up?" },
312
- };
313
- json message_assist_thoughts_unparsed_r7b {
314
- { "role", "assistant" },
315
- { "content", "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?" },
316
- };
317
- json message_assist_thoughts {
318
- { "role", "assistant" },
319
- { "content", "Hello, world!\nWhat's up?" },
320
- { "reasoning_content", "I'm thinking" },
321
- };
322
- json tool_calls = json::array({{
323
- { "type", "function" },
324
- { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
325
- }});
326
-
327
- json message_assist_call {
328
- { "role", "assistant"},
329
- { "content", {}},
330
- { "tool_calls", {
331
- {
332
- { "type", "function" },
333
- { "function", {
334
- { "name", "special_function" },
335
- { "arguments", "{\"arg1\": 1}" },
336
- }},
337
- },
338
- }},
339
- };
340
- json message_assist_call_thoughts = {
341
- { "role", "assistant" },
342
- { "content", nullptr },
343
- { "reasoning_content", "I'm\nthinking" },
344
- { "tool_calls", {
345
- {
346
- { "type", "function" },
347
- { "function", {
348
- { "name", "special_function" },
349
- { "arguments", "{\"arg1\": 1}" },
350
- }},
351
- },
352
- }},
353
- };
354
- json message_assist_call_thoughts_unparsed = {
355
- { "role", "assistant" },
356
- { "content", "<think>I'm\nthinking</think>" },
357
- { "tool_calls", {
358
- {
359
- { "type", "function" },
360
- { "function", {
361
- { "name", "special_function" },
362
- { "arguments", "{\"arg1\": 1}" },
363
- }},
364
- },
365
- }},
366
- };
367
- json message_assist_call_id {
368
- { "role", "assistant"},
369
- { "content", {}},
370
- { "tool_calls", {
371
- {
372
- { "type", "function" },
373
- { "function", {
374
- { "name", "special_function" },
375
- { "arguments", "{\"arg1\": 1}" },
376
- }},
377
- {"id", "123456789"},
378
- },
379
- }},
380
- { "role", "assistant" },
381
- { "content", {} },
382
- { "tool_calls", tool_calls }
383
- };
384
- json message_assist_call_idx {
385
- { "role", "assistant"},
386
- { "content", {}},
387
- { "tool_calls", {
388
- {
389
- { "type", "function" },
390
- { "function", {
391
- { "name", "special_function" },
392
- { "arguments", "{\"arg1\": 1}" },
393
- }},
394
- // Index of the tool call in the tool_calls array
395
- {"id", "0"},
396
- },
397
- }},
398
- { "role", "assistant" },
399
- { "content", {} },
400
- { "tool_calls", tool_calls }
401
- };
402
- json message_assist_call_tool_plan_idx = message_assist_call_idx;
403
- message_assist_call_tool_plan_idx["tool_plan"] = "I'm thinking";
404
-
405
- auto python_message_assist_call = json{
406
- { "role", "assistant" },
407
- { "content", {} },
408
- { "tool_calls", json{ {
409
- { "type", "function" },
410
- { "function",
411
- {
412
- { "name", "python" },
413
- { "arguments",
414
- {
415
- { "code", "print('hey')" },
416
- } },
417
- } },
418
- } } }
293
+ const common_chat_msg message_user {
294
+ "user",
295
+ "Hey there!",
296
+ /* .content_parts = */ {},
297
+ /* .tool_calls = */ {},
298
+ /* .reasoning_content = */ "",
299
+ /* .tool_name = */ "",
300
+ /* .tool_call_id = */ "",
301
+ };
302
+
303
+ const common_chat_msg message_user_parts {
304
+ "user",
305
+ /* .content = */ "",
306
+ /* .content_parts = */ {
307
+ { "text", "Hey" },
308
+ { "text", "there" },
309
+ },
310
+ /* .tool_calls = */ {},
311
+ /* .reasoning_content = */ "",
312
+ /* .tool_name = */ "",
313
+ /* .tool_call_id = */ "",
314
+ };
315
+ const common_chat_msg message_assist {
316
+ "assistant",
317
+ "Hello, world!\nWhat's up?",
318
+ /* .content_parts = */ {},
319
+ /* .tool_calls = */ {},
320
+ /* .reasoning_content = */ "",
321
+ /* .tool_name = */ "",
322
+ /* .tool_call_id = */ "",
323
+ };
324
+ const common_chat_msg message_assist_thoughts_unparsed_think {
325
+ "assistant",
326
+ "<think>I'm thinking</think>Hello, world!\nWhat's up?",
327
+ /* .content_parts = */ {},
328
+ /* .tool_calls = */ {},
329
+ /* .reasoning_content = */ "",
330
+ /* .tool_name = */ "",
331
+ /* .tool_call_id = */ "",
332
+ };
333
+ const common_chat_msg message_assist_thoughts_unparsed_r7b {
334
+ "assistant",
335
+ "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?",
336
+ /* .content_parts = */ {},
337
+ /* .tool_calls = */ {},
338
+ /* .reasoning_content = */ "",
339
+ /* .tool_name = */ "",
340
+ /* .tool_call_id = */ "",
341
+ };
342
+ const common_chat_msg message_assist_thoughts {
343
+ "assistant",
344
+ "Hello, world!\nWhat's up?",
345
+ /* .content_parts = */ {},
346
+ /* .tool_calls = */ {},
347
+ /* .reasoning_content = */ "I'm thinking",
348
+ /* .tool_name = */ "",
349
+ /* .tool_call_id = */ "",
350
+ };
351
+ const std::vector<common_chat_tool_call> tool_calls {
352
+ { "special_function", "{\"arg1\": 1}", /* .id = */ "" },
353
+ };
354
+ const std::vector<common_chat_tool_call> tool_calls_idx {
355
+ { "special_function", "{\"arg1\": 1}", /* .id = */ "0" },
356
+ };
357
+ const std::vector<common_chat_tool_call> tool_calls_id {
358
+ { "special_function", "{\"arg1\": 1}", /* .id = */ "123456789" },
359
+ };
360
+
361
+ const common_chat_msg message_assist_call {
362
+ "assistant",
363
+ "",
364
+ /* .content_parts = */ {},
365
+ tool_calls,
366
+ /* .reasoning_content = */ "",
367
+ /* .tool_name = */ "",
368
+ /* .tool_call_id = */ "",
369
+ };
370
+ const common_chat_msg message_assist_call_thoughts = {
371
+ "assistant",
372
+ /* .content = */ "",
373
+ /* .content_parts = */ {},
374
+ tool_calls,
375
+ /* .reasoning_content = */ "I'm\nthinking",
376
+ /* .tool_name = */ "",
377
+ /* .tool_call_id = */ "",
378
+ };
379
+ const common_chat_msg message_assist_call_thoughts_unparsed = {
380
+ "assistant",
381
+ /* .content = */ "<think>I'm\nthinking</think>",
382
+ /* .content_parts = */ {},
383
+ tool_calls,
384
+ /* .reasoning_content = */ "",
385
+ /* .tool_name = */ "",
386
+ /* .tool_call_id = */ "",
387
+ };
388
+ const common_chat_msg message_assist_call_id {
389
+ "assistant",
390
+ "",
391
+ /* .content_parts = */ {},
392
+ tool_calls_id,
393
+ /* .reasoning_content = */ "",
394
+ /* .tool_name = */ "",
395
+ /* .tool_call_id = */ "",
396
+ };
397
+ const common_chat_msg message_assist_call_idx {
398
+ "assistant",
399
+ "",
400
+ /* .content_parts = */ {},
401
+ tool_calls_idx,
402
+ /* .reasoning_content = */ "",
403
+ /* .tool_name = */ "",
404
+ /* .tool_call_id = */ "",
405
+ };
406
+ const common_chat_msg message_assist_call_python {
407
+ "assistant",
408
+ "",
409
+ /* .content_parts = */ {},
410
+ { { "python", "{\"code\": \"print('hey')\"}", /* .id = */ "" } },
411
+ /* .reasoning_content = */ "",
412
+ /* .tool_name = */ "",
413
+ /* .tool_call_id = */ "",
414
+ };
415
+ const common_chat_msg message_assist_call_code_interpreter {
416
+ "assistant",
417
+ "",
418
+ /* .content_parts = */ {},
419
+ { { "code_interpreter", "{\"code\": \"print('hey')\"}", /* .id = */ "" } },
420
+ /* .reasoning_content = */ "",
421
+ /* .tool_name = */ "",
422
+ /* .tool_call_id = */ "",
423
+ };
424
+
425
+ static void test_msgs_oaicompat_json_conversion() {
426
+ std::vector<common_chat_msg> msgs{
427
+ message_user,
428
+ message_user_parts,
429
+ message_assist_call,
430
+ message_assist_call_thoughts,
431
+ message_assist_call_thoughts_unparsed,
432
+ message_assist_call_id,
433
+ message_assist_call_idx,
434
+ message_assist_call_python,
435
+ message_assist_call_code_interpreter,
419
436
  };
420
- auto code_interpreter_message_assist_call = json{
421
- { "role", "assistant" },
422
- { "content", {} },
423
- { "tool_calls", json{ {
424
- { "type", "function" },
425
- { "function",
426
- {
427
- { "name", "code_interpreter" },
428
- { "arguments",
429
- {
430
- { "code", "print('hey')" },
431
- } },
432
- } },
433
- } } }
437
+ for (const auto & msg : msgs) {
438
+ auto oai_json = common_chat_msgs_to_json_oaicompat<json>({msg});
439
+ auto msgs2 = common_chat_msgs_parse_oaicompat(oai_json);
440
+ assert_equals((size_t) 1, msgs2.size());
441
+ auto msg2 = msgs2[0];
442
+ assert_msg_equals(msg, msg2);
443
+ }
444
+ assert_equals(
445
+ std::string(
446
+ "[\n"
447
+ " {\n"
448
+ " \"role\": \"user\",\n"
449
+ " \"content\": [\n"
450
+ " {\n"
451
+ " \"type\": \"text\",\n"
452
+ " \"text\": \"Hey\"\n"
453
+ " },\n"
454
+ " {\n"
455
+ " \"type\": \"text\",\n"
456
+ " \"text\": \"there\"\n"
457
+ " }\n"
458
+ " ]\n"
459
+ " }\n"
460
+ "]"
461
+ ),
462
+ common_chat_msgs_to_json_oaicompat<json>({message_user_parts}).dump(2));
463
+
464
+ assert_equals(
465
+ std::string(
466
+ "[\n"
467
+ " {\n"
468
+ " \"role\": \"assistant\",\n"
469
+ " \"content\": null,\n"
470
+ " \"tool_calls\": [\n"
471
+ " {\n"
472
+ " \"type\": \"function\",\n"
473
+ " \"function\": {\n"
474
+ " \"name\": \"python\",\n"
475
+ " \"arguments\": \"{\\\"code\\\": \\\"print('hey')\\\"}\"\n"
476
+ " }\n"
477
+ " }\n"
478
+ " ]\n"
479
+ " }\n"
480
+ "]"
481
+ ),
482
+ common_chat_msgs_to_json_oaicompat<json>({message_assist_call_python}).dump(2));
483
+
484
+ auto res = common_chat_msgs_parse_oaicompat(json::parse("[{\"role\": \"assistant\", \"tool_calls\": []}]"));
485
+ assert_equals<size_t>(1, res.size());
486
+ assert_equals<std::string>(res[0].role, "assistant");
487
+ assert_equals(true, res[0].content.empty());
488
+ assert_equals(true, res[0].tool_calls.empty());
489
+
490
+ try {
491
+ common_chat_msgs_parse_oaicompat(json::parse("[{\"role\": \"assistant\"}]"));
492
+ throw std::runtime_error("Expected exception");
493
+ } catch (const std::exception & e) {
494
+ if (std::string(e.what()).find("'content'") == std::string::npos) {
495
+ throw std::runtime_error("Expected exception about missing 'content'");
496
+ }
497
+ }
498
+ }
499
+
500
+ static void test_tools_oaicompat_json_conversion() {
501
+ std::vector<common_chat_tool> tools{
502
+ special_function_tool,
503
+ python_tool,
504
+ code_interpreter_tool,
434
505
  };
435
506
 
436
- common_chat_inputs inputs_no_tools;
437
- inputs_no_tools.messages = json::array({message_user});
507
+ for (const auto & tool : tools) {
508
+ auto oai_json = common_chat_tools_to_json_oaicompat<json>({tool});
509
+ auto tools2 = common_chat_tools_parse_oaicompat(oai_json);
510
+ assert_equals((size_t) 1, tools2.size());
511
+ auto tool2 = tools2[0];
512
+ assert_equals(tool.name, tool2.name);
513
+ assert_equals(tool.description, tool2.description);
514
+ assert_equals(json::parse(tool.parameters).dump(2), json::parse(tool2.parameters).dump(2));
515
+ }
516
+
517
+ assert_equals(
518
+ std::string(
519
+ "[\n"
520
+ " {\n"
521
+ " \"type\": \"function\",\n"
522
+ " \"function\": {\n"
523
+ " \"name\": \"special_function\",\n"
524
+ " \"description\": \"I'm special\",\n"
525
+ " \"parameters\": {\n"
526
+ " \"type\": \"object\",\n"
527
+ " \"properties\": {\n"
528
+ " \"arg1\": {\n"
529
+ " \"type\": \"integer\",\n"
530
+ " \"description\": \"The arg.\"\n"
531
+ " }\n"
532
+ " },\n"
533
+ " \"required\": [\n"
534
+ " \"arg1\"\n"
535
+ " ]\n"
536
+ " }\n"
537
+ " }\n"
538
+ " }\n"
539
+ "]"
540
+ ),
541
+ common_chat_tools_to_json_oaicompat<json>({special_function_tool}).dump(2));
542
+ }
543
+
544
+ static void test_template_output_parsers() {
545
+
546
+ common_chat_templates_inputs inputs_no_tools;
547
+ inputs_no_tools.messages = {message_user};
438
548
  inputs_no_tools.extract_reasoning = false;
439
549
 
440
- common_chat_inputs inputs_no_tools_think;
441
- inputs_no_tools_think.messages = json::array({message_user});
550
+ common_chat_templates_inputs inputs_no_tools_think;
551
+ inputs_no_tools_think.messages = {message_user};
442
552
  inputs_no_tools_think.extract_reasoning = true;
443
553
 
444
- common_chat_inputs inputs_tools;
445
- inputs_tools.messages = json::array({message_user});
446
- inputs_tools.tools = json::array({special_function_tool});
554
+ common_chat_templates_inputs inputs_tools;
555
+ inputs_tools.messages = {message_user};
556
+ inputs_tools.tools = {special_function_tool};
447
557
  inputs_tools.extract_reasoning = false;
448
558
 
449
- common_chat_inputs inputs_tools_think;
450
- inputs_tools_think.messages = json::array({message_user});
451
- inputs_tools_think.tools = json::array({special_function_tool});
559
+ common_chat_templates_inputs inputs_tools_think;
560
+ inputs_tools_think.messages = {message_user};
561
+ inputs_tools_think.tools = {special_function_tool};
452
562
  inputs_tools_think.extract_reasoning = true;
453
563
 
454
- common_chat_inputs inputs_tools_builtin;
455
- inputs_tools_builtin.messages = json::array({message_user});
456
- inputs_tools_builtin.tools = json::array({python_tool});
564
+ common_chat_templates_inputs inputs_tools_builtin;
565
+ inputs_tools_builtin.messages = {message_user};
566
+ inputs_tools_builtin.tools = {python_tool};
457
567
  inputs_tools_builtin.extract_reasoning = false;
458
568
 
459
569
  {
460
570
  // Not supported yet
461
- const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"), "<s>", "</s>");
462
- assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
571
+ auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja");
572
+ assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
463
573
  }
464
574
  {
465
- const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "<s>", "</s>");
575
+ auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja");
466
576
  std::vector<std::string> end_tokens{ "<|END_OF_TURN_TOKEN|>" };
467
577
 
468
- assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_no_tools).format);
469
- assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format);
470
- assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format);
578
+ assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
579
+ assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
580
+ assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
471
581
 
472
- assert_msg_equals(msg_from_json(message_assist),
582
+ assert_msg_equals(message_assist,
473
583
  common_chat_parse(
474
584
  "Hello, world!\nWhat's up?",
475
585
  COMMON_CHAT_FORMAT_COMMAND_R7B));
476
- assert_msg_equals(msg_from_json(message_assist),
586
+ assert_msg_equals(message_assist,
477
587
  common_chat_parse(
478
588
  "Hello, world!\nWhat's up?<|END_RESPONSE|>",
479
589
  COMMON_CHAT_FORMAT_COMMAND_R7B));
480
- assert_msg_equals(msg_from_json(message_assist),
590
+ assert_msg_equals(message_assist,
481
591
  common_chat_parse(
482
592
  "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
483
593
  COMMON_CHAT_FORMAT_COMMAND_R7B));
484
- assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b),
594
+ assert_msg_equals(message_assist_thoughts_unparsed_r7b,
485
595
  common_chat_parse(
486
596
  "<|START_THINKING|>I'm thinking<|END_THINKING|>"
487
597
  "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
488
598
  COMMON_CHAT_FORMAT_COMMAND_R7B));
489
- assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b),
599
+ assert_msg_equals(message_assist_thoughts_unparsed_r7b,
490
600
  common_chat_parse(
491
601
  "<|START_THINKING|>I'm thinking<|END_THINKING|>"
492
602
  "Hello, world!\nWhat's up?<|END_RESPONSE|>",
493
603
  COMMON_CHAT_FORMAT_COMMAND_R7B));
494
604
 
495
- assert_msg_equals(msg_from_json(message_assist_thoughts),
605
+ assert_msg_equals(message_assist_thoughts,
496
606
  common_chat_parse(
497
607
  "<|START_THINKING|>I'm thinking<|END_THINKING|>"
498
608
  "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
499
609
  COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING));
500
610
 
501
- test_template(tmpl, end_tokens, message_assist_call_idx, tools,
611
+ test_templates(tmpls.get(), end_tokens, message_assist_call_idx, tools,
502
612
  "<|START_THINKING|><|END_THINKING|>"
503
613
  "<|START_ACTION|>[\n"
504
614
  " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
505
615
  "]<|END_ACTION|>");
506
- test_template(tmpl, end_tokens, message_assist_call_tool_plan_idx, tools,
507
- "<|START_THINKING|>I'm thinking<|END_THINKING|>"
508
- "<|START_ACTION|>[\n"
509
- " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
510
- "]<|END_ACTION|>",
511
- /* expect_grammar_triggered= */ true,
512
- /* test_grammar_if_triggered= */ true,
513
- /* think= */ true);
514
- test_template(tmpl, end_tokens, message_assist, tools,
616
+ test_templates(tmpls.get(), end_tokens, message_assist, tools,
515
617
  "<|START_RESPONSE|>Hello, world!\n"
516
618
  "What's up?<|END_RESPONSE|>",
517
619
  /* expect_grammar_triggered= */ false);
518
620
  }
519
621
  {
520
- const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
622
+ auto tmpls = read_templates("models/templates/google-gemma-2-2b-it.jinja");
521
623
  std::vector<std::string> end_tokens{ "<end_of_turn>" };
522
624
 
523
- assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
524
- assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
625
+ assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
626
+ assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
525
627
  assert_equals(COMMON_CHAT_FORMAT_GENERIC,
526
- common_chat_params_init(
527
- common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"),
528
- "<s>", "</s>"),
628
+ common_chat_templates_apply(
629
+ read_templates("models/templates/microsoft-Phi-3.5-mini-instruct.jinja").get(),
529
630
  inputs_tools)
530
631
  .format);
531
632
 
532
633
  // Generic tool calls doesn't generate / parse content-only messages symmetrically.
533
634
 
534
- assert_msg_equals(msg_from_json(message_assist),
635
+ assert_msg_equals(message_assist,
535
636
  common_chat_parse("{\n"
536
637
  " \"response\": \"Hello, world!\\nWhat's up?\"\n"
537
638
  "}",
538
- common_chat_params_init(tmpl, inputs_tools).format));
539
- test_template(tmpl, end_tokens, message_assist_call_id, tools,
639
+ common_chat_templates_apply(tmpls.get(), inputs_tools).format));
640
+ test_templates(tmpls.get(), end_tokens, message_assist_call_id, tools,
540
641
  "{\n"
541
642
  " \"tool_calls\": [\n"
542
643
  " {\n"
@@ -550,143 +651,233 @@ static void test_template_output_parsers() {
550
651
  "}");
551
652
  }
552
653
  {
553
- const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>",
554
- "</s>");
654
+ auto tmpls = read_templates("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja");
555
655
  std::vector<std::string> end_tokens{ "</s>" };
556
656
 
557
- assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);
657
+ assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
558
658
 
559
- test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
560
- test_template(
561
- tmpl, end_tokens, message_assist_call_id, tools,
659
+ test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
660
+ test_templates(
661
+ tmpls.get(), end_tokens, message_assist_call_id, tools,
562
662
  "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
563
663
  }
564
664
  {
565
- const common_chat_template tmpl(
566
- read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
665
+ auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja");
567
666
  std::vector<std::string> end_tokens{ "<|im_end|>" };
568
667
 
569
- assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(tmpl, inputs_tools).format);
668
+ assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
570
669
  assert_equals(
571
670
  COMMON_CHAT_FORMAT_HERMES_2_PRO,
572
- common_chat_params_init(
573
- common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"),
574
- "<s>", "</s>"),
671
+ common_chat_templates_apply(
672
+ read_templates("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja").get(),
575
673
  inputs_tools)
576
674
  .format);
577
675
  assert_equals(
578
676
  COMMON_CHAT_FORMAT_HERMES_2_PRO,
579
- common_chat_params_init(
580
- common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"),
677
+ common_chat_templates_apply(
678
+ read_templates("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja").get(),
581
679
  inputs_tools)
582
680
  .format);
583
681
 
584
- test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
585
- test_template(tmpl, end_tokens, message_assist_call, tools,
682
+ // Test parsing
683
+ assert_msg_equals(message_assist_call, common_chat_parse(
684
+ "<tool_call>\n"
685
+ "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
686
+ "</tool_call>",
687
+ COMMON_CHAT_FORMAT_HERMES_2_PRO));
688
+ assert_msg_equals(message_assist_call, common_chat_parse(
689
+ "<function=special_function>{\"arg1\": 1}</function>",
690
+ COMMON_CHAT_FORMAT_HERMES_2_PRO));
691
+ assert_msg_equals(message_assist_call, common_chat_parse(
692
+ "<function name=\"special_function\">\n"
693
+ "{\"arg1\": 1}\n"
694
+ "</function>",
695
+ COMMON_CHAT_FORMAT_HERMES_2_PRO));
696
+ assert_msg_equals(message_assist_call, common_chat_parse(
697
+ "<tool>\n"
698
+ " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
699
+ "</tool>",
700
+ COMMON_CHAT_FORMAT_HERMES_2_PRO));
701
+ assert_msg_equals(message_assist_call, common_chat_parse(
702
+ "<tools>\n"
703
+ " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
704
+ "</tools>",
705
+ COMMON_CHAT_FORMAT_HERMES_2_PRO));
706
+ assert_msg_equals(message_assist_call, common_chat_parse(
707
+ "<response>\n"
708
+ " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
709
+ "</response>",
710
+ COMMON_CHAT_FORMAT_HERMES_2_PRO));
711
+ assert_msg_equals(message_assist_call, common_chat_parse(
712
+ "```xml\n"
713
+ "<response>\n"
714
+ " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
715
+ "</response>\n"
716
+ "```",
717
+ COMMON_CHAT_FORMAT_HERMES_2_PRO));
718
+ assert_msg_equals(message_assist_call, common_chat_parse(
719
+ "```xml\n"
720
+ " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
721
+ "```",
722
+ COMMON_CHAT_FORMAT_HERMES_2_PRO));
723
+ assert_msg_equals(message_assist_call, common_chat_parse(
724
+ "```\n"
725
+ " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
726
+ "```",
727
+ COMMON_CHAT_FORMAT_HERMES_2_PRO));
728
+ assert_msg_equals(message_assist_call, common_chat_parse(
729
+ "```\n"
730
+ "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
731
+ "```",
732
+ COMMON_CHAT_FORMAT_HERMES_2_PRO));
733
+ assert_msg_equals(message_assist_call, common_chat_parse(
734
+ "```json\n"
735
+ " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
736
+ "```",
737
+ COMMON_CHAT_FORMAT_HERMES_2_PRO));
738
+ assert_msg_equals(message_assist_call, common_chat_parse(
739
+ "```json\n"
740
+ "\n"
741
+ " <function_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}} \n"
742
+ " </function_call> \n"
743
+ "``` ",
744
+ COMMON_CHAT_FORMAT_HERMES_2_PRO));
745
+ assert_msg_equals(message_assist_call, common_chat_parse(
746
+ "<json>\n"
747
+ " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
748
+ "</json>",
749
+ COMMON_CHAT_FORMAT_HERMES_2_PRO));
750
+ assert_msg_equals(message_assist_call, common_chat_parse(
751
+ "<xml>\n"
752
+ " {\n"
753
+ " \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}\n"
754
+ " }\n"
755
+ "</xml>",
756
+ COMMON_CHAT_FORMAT_HERMES_2_PRO));
757
+ assert_msg_equals(message_assist_call, common_chat_parse(
758
+ "<JSON>\n"
759
+ " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
760
+ "</JSON>",
761
+ COMMON_CHAT_FORMAT_HERMES_2_PRO));
762
+ assert_msg_equals(message_assist_call, common_chat_parse(
763
+ "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
764
+ COMMON_CHAT_FORMAT_HERMES_2_PRO));
765
+ assert_msg_equals(message_assist_call, common_chat_parse(
766
+ "{\n \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
767
+ COMMON_CHAT_FORMAT_HERMES_2_PRO));
768
+
769
+ assert_msg_equals(message_assist_thoughts_unparsed_think,
770
+ common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
771
+ COMMON_CHAT_FORMAT_HERMES_2_PRO));
772
+ assert_msg_equals(message_assist_thoughts_unparsed_think,
773
+ common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
774
+ COMMON_CHAT_FORMAT_HERMES_2_PRO));
775
+ assert_msg_equals(message_assist_thoughts,
776
+ common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
777
+ COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING));
778
+ assert_msg_equals(message_assist_thoughts,
779
+ common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
780
+ COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING));
781
+
782
+ test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
783
+ test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
586
784
  "<tool_call>\n"
587
785
  "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
588
786
  "</tool_call>");
589
- test_template(tmpl, end_tokens, python_message_assist_call, tools,
787
+ test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools,
590
788
  "<tool_call>\n"
591
789
  "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
592
790
  "</tool_call>");
593
791
  }
594
792
  {
595
- const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>",
596
- "</s>");
793
+ auto tmpls = read_templates("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja");
597
794
  std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
598
795
 
599
- assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
796
+ assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
600
797
  assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
601
- common_chat_params_init(tmpl, inputs_tools_builtin).format);
798
+ common_chat_templates_apply(tmpls.get(), inputs_tools_builtin).format);
602
799
  assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
603
- common_chat_params_init(
604
- common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"),
605
- "<s>", "</s>"),
800
+ common_chat_templates_apply(
801
+ read_templates("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja").get(),
606
802
  inputs_tools_builtin)
607
803
  .format);
608
804
 
609
- // test_template(tmpl, end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false);
610
- test_template(tmpl, end_tokens, code_interpreter_message_assist_call, llama_3_1_tools,
805
+ // test_templates(tmpls.get(), end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false);
806
+ test_templates(tmpls.get(), end_tokens, message_assist_call_code_interpreter, llama_3_1_tools,
611
807
  "<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
612
- test_template(tmpl, end_tokens, python_message_assist_call, tools,
808
+ test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools,
613
809
  "<|python_tag|>python.call(code=\"print('hey')\")");
614
- test_template(tmpl, end_tokens, message_assist_call, tools,
810
+ test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
615
811
  "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
616
812
  }
617
813
  {
618
- const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>",
619
- "</s>");
814
+ auto tmpls = read_templates("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja");
620
815
  std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
621
816
 
622
- assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
817
+ assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
623
818
 
624
- test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
625
- test_template(tmpl, end_tokens, message_assist_call, tools,
819
+ test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
820
+ test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
626
821
  "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
627
822
  }
628
823
  {
629
- const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>",
630
- "</s>");
824
+ auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.1.jinja");
631
825
  std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
632
826
 
633
827
  assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
634
- common_chat_params_init(tmpl, inputs_tools).format);
828
+ common_chat_templates_apply(tmpls.get(), inputs_tools).format);
635
829
 
636
- test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
637
- test_template(tmpl, end_tokens, message_assist_call, tools,
830
+ test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
831
+ test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
638
832
  "<function=special_function>{\"arg1\": 1}</function>");
639
833
  }
640
834
  {
641
- const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>",
642
- "</s>");
835
+ auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.2.jinja");
643
836
  std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
644
837
 
645
- assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format);
646
- assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format);
838
+ assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
839
+ assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
647
840
 
648
- test_template(tmpl, end_tokens, message_assist, {},
841
+ test_templates(tmpls.get(), end_tokens, message_assist, {},
649
842
  "all\n"
650
843
  "Hello, world!\n"
651
844
  "What's up?",
652
845
  /* expect_grammar_triggered= */ false);
653
- test_template(tmpl, end_tokens, message_assist_call, tools,
846
+ test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
654
847
  "special_function\n"
655
848
  "{\"arg1\": 1}");
656
849
  }
657
850
  {
658
- const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>",
659
- "</s>");
851
+ auto tmpls = read_templates("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja");
660
852
  std::vector<std::string> end_tokens{ "<|eot_id|>" };
661
853
 
662
- assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);
854
+ assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
663
855
 
664
- test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
665
- test_template(tmpl, end_tokens, message_assist_call, tools,
856
+ test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
857
+ test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
666
858
  " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
667
859
  }
  {
  // Original DeepSeek R1 template. Leaves <|tool▁calls▁begin|> and others unclosed. Our logic fixes the prompt.
- const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"),
- "<s>", "</s>");
+ auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja");
  std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };

- assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
- assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format);
+ assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+ assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);

- test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
- test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
- assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think),
+ test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+ test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+ assert_msg_equals(message_assist_thoughts_unparsed_think,
  common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
  COMMON_CHAT_FORMAT_DEEPSEEK_R1));
- assert_msg_equals(msg_from_json(message_assist_thoughts),
+ assert_msg_equals(message_assist_thoughts,
  common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
  COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
- assert_msg_equals(msg_from_json(message_assist_thoughts),
+ assert_msg_equals(message_assist_thoughts,
  // Latest template update (as of 20250209) adds a trailing <think>\n if add_generation_prompt is true.
  common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
  COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
- // test_template(tmpl, end_tokens, message_assist_call, tools,
+ // test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
  // "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
  // "```json\n"
  // "{\"arg1\": 1}\n"
@@ -697,23 +888,22 @@ static void test_template_output_parsers() {
  }
  {
  // Replacement DeepSeek R1 template. Makes the Distill Qwen 7B/32B models happy to call tools and all.
- const common_chat_template tmpl(read_file("models/templates/llama-cpp-deepseek-r1.jinja"),
- "<s>", "</s>");
+ auto tmpls = read_templates("models/templates/llama-cpp-deepseek-r1.jinja");
  std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };

- assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
- assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format);
+ assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+ assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);

- test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
- test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
- assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think),
+ test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+ test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+ assert_msg_equals(message_assist_thoughts_unparsed_think,
  common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
  COMMON_CHAT_FORMAT_DEEPSEEK_R1));
- assert_msg_equals(msg_from_json(message_assist_thoughts),
+ assert_msg_equals(message_assist_thoughts,
  common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
  COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));

- assert_msg_equals(msg_from_json(message_assist_call_thoughts_unparsed),
+ assert_msg_equals(message_assist_call_thoughts_unparsed,
  common_chat_parse(
  "<think>I'm\nthinking</think>\n\n"
  "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
@@ -721,7 +911,7 @@ static void test_template_output_parsers() {
  "{\"arg1\": 1}\n"
  "```<|tool▁call▁end|><|tool▁calls▁end|>",
  COMMON_CHAT_FORMAT_DEEPSEEK_R1));
- assert_msg_equals(msg_from_json(message_assist_call_thoughts),
+ assert_msg_equals(message_assist_call_thoughts,
  common_chat_parse(
  "<think>I'm\nthinking</think>\n\n"
  "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
@@ -729,7 +919,7 @@ static void test_template_output_parsers() {
  "{\"arg1\": 1}\n"
  "```<|tool▁call▁end|><|tool▁calls▁end|>",
  COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
- test_template(tmpl, end_tokens, message_assist_call, tools,
+ test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
  "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
  "```json\n"
  "{\"arg1\": 1}\n"
@@ -738,38 +928,46 @@ static void test_template_output_parsers() {
  }

  int main(int argc, char ** argv) {
+ // try {
  #ifndef _WIN32
- if (argc > 1) {
- common_chat_inputs inputs;
- inputs.messages = {
- { { "role", "user" }, { "content", "Hey" } }
- };
- inputs.tools = json::array({ special_function_tool });
-
- std::cout << "| Template | Format |\n";
- std::cout << "|----------|--------|\n";
-
- for (int i = 1; i < argc; i++) {
- try {
- std::string path = argv[i];
- if (path.rfind(".jinja") != path.size() - 6) {
- std::cerr << "Skipping non-jinja file: " << path << std::endl;
- continue;
+ if (argc > 1) {
+ common_chat_templates_inputs inputs;
+ common_chat_msg msg;
+ msg.role = "user";
+ msg.content = "Hey";
+ inputs.messages = {msg};
+ inputs.tools = { special_function_tool };
+
+ std::cout << "| Template | Format |\n";
+ std::cout << "|----------|--------|\n";
+
+ for (int i = 1; i < argc; i++) {
+ try {
+ std::string path = argv[i];
+ if (path.rfind(".jinja") != path.size() - 6) {
+ std::cerr << "Skipping non-jinja file: " << path << '\n';
+ continue;
+ }
+ auto tmpls = read_templates(path);
+ auto parts = string_split(path, "/");
+ auto name = parts[parts.size() - 1];
+ auto format = common_chat_format_name(common_chat_templates_apply(tmpls.get(), inputs).format);
+ std::cout << "| " << name << " | " << format << " |\n";
+ } catch (const std::exception & e) {
+ std::cerr << "Failed to process " << argv[i] << ": " << e.what() << '\n';
  }
- common_chat_template tmpl(read_file(path), "", "");
- auto parts = string_split(path, "/");
- auto name = parts[parts.size() - 1];
- auto format = common_chat_format_name(common_chat_params_init(tmpl, inputs).format);
- std::cout << "| " << name << " | " << format << " |\n";
- } catch (const std::exception & e) {
- std::cerr << "Failed to process " << argv[i] << ": " << e.what() << std::endl;
  }
- }
- } else
+ } else
  #endif
- {
- test_template_output_parsers();
- std::cout << "\n[chat] All tests passed!" << std::endl;
- }
- return 0;
+ {
+ test_msgs_oaicompat_json_conversion();
+ test_tools_oaicompat_json_conversion();
+ test_template_output_parsers();
+ std::cout << "\n[chat] All tests passed!" << '\n';
+ }
+ return 0;
+ // } catch (const std::exception & e) {
+ // std::cerr << "Error: " << e.what() << '\n';
+ // return 1;
+ // }
  }
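The rewritten `main` also gathers the new input types in one place: messages are `common_chat_msg` structs instead of inline JSON objects, and `tools` is a typed list instead of a `json::array`. Condensed from the added lines above into a single flow; the template path reuses one named earlier in this diff, and nothing here goes beyond calls the diff itself shows:

```cpp
// Build typed inputs, load a template, apply it, and report the detected format.
common_chat_templates_inputs inputs;

common_chat_msg msg;
msg.role    = "user";
msg.content = "Hey";
inputs.messages = { msg };                    // structs replace the old raw-JSON messages
inputs.tools    = { special_function_tool };  // typed list replaces json::array(...)

auto tmpls  = read_templates("models/templates/llama-cpp-deepseek-r1.jinja");
auto format = common_chat_format_name(common_chat_templates_apply(tmpls.get(), inputs).format);
std::cout << "| llama-cpp-deepseek-r1.jinja | " << format << " |\n";
```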