@fugood/llama.node 0.3.12 → 0.3.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/lib/binding.ts +2 -1
  18. package/package.json +1 -1
  19. package/src/LlamaCompletionWorker.cpp +14 -0
  20. package/src/LlamaContext.cpp +110 -79
  21. package/src/LlamaContext.h +1 -1
  22. package/src/common.hpp +1 -2
  23. package/src/llama.cpp/.github/workflows/build.yml +95 -13
  24. package/src/llama.cpp/.github/workflows/docker.yml +2 -0
  25. package/src/llama.cpp/.github/workflows/labeler.yml +1 -1
  26. package/src/llama.cpp/.github/workflows/server.yml +2 -0
  27. package/src/llama.cpp/common/CMakeLists.txt +23 -6
  28. package/src/llama.cpp/common/arg.cpp +292 -14
  29. package/src/llama.cpp/common/chat.cpp +1128 -315
  30. package/src/llama.cpp/common/chat.h +135 -0
  31. package/src/llama.cpp/common/common.cpp +27 -171
  32. package/src/llama.cpp/common/common.h +41 -73
  33. package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
  34. package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
  35. package/src/llama.cpp/common/llguidance.cpp +3 -3
  36. package/src/llama.cpp/common/log.cpp +1 -0
  37. package/src/llama.cpp/common/log.h +2 -1
  38. package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +21 -7
  39. package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +61 -14
  40. package/src/llama.cpp/common/ngram-cache.cpp +1 -0
  41. package/src/llama.cpp/common/sampling.cpp +93 -49
  42. package/src/llama.cpp/common/speculative.cpp +6 -5
  43. package/src/llama.cpp/common/speculative.h +1 -1
  44. package/src/llama.cpp/docs/build.md +47 -9
  45. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +3 -1
  46. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -0
  47. package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
  48. package/src/llama.cpp/examples/imatrix/imatrix.cpp +4 -4
  49. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +6 -5
  50. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/CMakeLists.txt +1 -1
  51. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +1 -1
  52. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  53. package/src/llama.cpp/examples/llava/clip.cpp +373 -107
  54. package/src/llama.cpp/examples/llava/clip.h +19 -3
  55. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
  56. package/src/llama.cpp/examples/llava/llava.cpp +4 -2
  57. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
  58. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -0
  59. package/src/llama.cpp/examples/main/main.cpp +73 -28
  60. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -0
  61. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -0
  62. package/src/llama.cpp/examples/perplexity/perplexity.cpp +1 -0
  63. package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
  64. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
  65. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
  66. package/src/llama.cpp/examples/run/run.cpp +115 -79
  67. package/src/llama.cpp/examples/server/CMakeLists.txt +1 -1
  68. package/src/llama.cpp/examples/server/httplib.h +381 -292
  69. package/src/llama.cpp/examples/server/server.cpp +134 -128
  70. package/src/llama.cpp/examples/server/utils.hpp +95 -106
  71. package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
  72. package/src/llama.cpp/examples/tts/tts.cpp +251 -142
  73. package/src/llama.cpp/ggml/CMakeLists.txt +13 -1
  74. package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
  75. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
  76. package/src/llama.cpp/ggml/include/ggml-cpu.h +4 -1
  77. package/src/llama.cpp/ggml/include/ggml-metal.h +1 -1
  78. package/src/llama.cpp/ggml/include/ggml-vulkan.h +0 -2
  79. package/src/llama.cpp/ggml/include/ggml.h +6 -2
  80. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
  81. package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
  82. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
  83. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
  84. package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
  85. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -2
  86. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
  87. package/src/llama.cpp/ggml/src/ggml-common.h +0 -2
  88. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +132 -17
  89. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
  90. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
  91. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
  92. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +156 -11
  93. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +2235 -641
  94. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1572 -198
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +24 -5
  96. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
  97. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
  98. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
  99. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +9 -8
  101. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +16 -3
  102. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
  103. package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
  104. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
  105. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +235 -0
  106. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -2
  107. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
  108. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +246 -120
  109. package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
  110. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
  111. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -0
  112. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  113. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
  114. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +51 -10
  115. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
  116. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
  117. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
  118. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
  119. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
  120. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +136 -4
  121. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +308 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
  123. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +174 -728
  124. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -77
  125. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -0
  126. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  127. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  128. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +949 -602
  129. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +37 -3
  130. package/src/llama.cpp/ggml/src/ggml.c +9 -4
  131. package/src/llama.cpp/include/llama.h +32 -14
  132. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
  133. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
  134. package/src/llama.cpp/requirements/requirements-all.txt +1 -0
  135. package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
  136. package/src/llama.cpp/requirements.txt +1 -0
  137. package/src/llama.cpp/src/llama-arch.cpp +21 -0
  138. package/src/llama.cpp/src/llama-arch.h +1 -0
  139. package/src/llama.cpp/src/llama-chat.cpp +1 -0
  140. package/src/llama.cpp/src/llama-grammar.cpp +183 -183
  141. package/src/llama.cpp/src/llama-grammar.h +13 -4
  142. package/src/llama.cpp/src/llama-impl.h +6 -6
  143. package/src/llama.cpp/src/llama-kv-cache.h +2 -1
  144. package/src/llama.cpp/src/llama-mmap.cpp +11 -1
  145. package/src/llama.cpp/src/llama-mmap.h +1 -0
  146. package/src/llama.cpp/src/llama-model.cpp +70 -6
  147. package/src/llama.cpp/src/llama-sampling.cpp +174 -67
  148. package/src/llama.cpp/src/llama-vocab.cpp +12 -0
  149. package/src/llama.cpp/src/llama.cpp +154 -5
  150. package/src/llama.cpp/src/unicode.cpp +9 -2
  151. package/src/llama.cpp/tests/test-backend-ops.cpp +171 -115
  152. package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
  153. package/src/llama.cpp/tests/test-chat.cpp +691 -325
  154. package/src/llama.cpp/tests/test-gguf.cpp +4 -4
  155. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
  156. package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
  157. package/src/llama.cpp/tests/test-sampling.cpp +15 -0
  158. package/src/llama.cpp/Sources/llama/llama.h +0 -4
  159. package/src/llama.cpp/common/chat.hpp +0 -52
package/src/llama.cpp/tests/test-chat.cpp
@@ -10,35 +10,12 @@
  #include <json.hpp>
  #include <string>

- #include "chat-template.hpp"
- #include "chat.hpp"
+ #include "chat.h"
  #include "llama-grammar.h"
  #include "unicode.h"

  using json = nlohmann::ordered_json;

- static common_chat_msg msg_from_json(const json & message) {
-     common_chat_msg ret;
-     ret.role = "assistant";
-     if (message.contains("content") && !message.at("content").is_null()) {
-         ret.content = message.at("content");
-     }
-     if (message.contains("tool_plan")) {
-         ret.tool_plan = message.at("tool_plan");
-     }
-     auto has_tool_calls = message.contains("tool_calls");
-     if (has_tool_calls) {
-         for (const auto & tc : message.at("tool_calls")) {
-             const auto & arguments = tc.at("function").at("arguments");
-             ret.tool_calls.push_back({
-                 tc.at("function").at("name").get<std::string>(),
-                 arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
-                 tc.contains("id") ? tc.at("id").get<std::string>() : "",
-             });
-         }
-     }
-     return ret;
- }

  template <class T> static void assert_equals(const T & expected, const T & actual) {
      if (expected != actual) {
@@ -50,7 +27,7 @@ template <class T> static void assert_equals(const T & expected, const T & actua
  }

  static std::string read_file(const std::string & path) {
-     std::cerr << "# Reading: " << path << std::endl << std::flush;
+     std::cerr << "# Reading: " << path << '\n' << std::flush;
      std::ifstream fs(path, std::ios_base::binary);
      if (!fs.is_open()) {
          fs = std::ifstream("../" + path, std::ios_base::binary);
@@ -63,10 +40,14 @@ static std::string read_file(const std::string & path) {
      fs.seekg(0);
      std::string out;
      out.resize(static_cast<size_t>(size));
-     fs.read(&out[0], static_cast<std::streamsize>(size));
+     fs.read(out.data(), static_cast<std::streamsize>(size));
      return out;
  }

+ static common_chat_templates_ptr read_templates(const std::string & path) {
+     return common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, read_file(path)));
+ }
+
  static std::unique_ptr<llama_grammar> build_grammar(const std::string & grammar_str) {
      return std::unique_ptr<llama_grammar>(
          llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0));
@@ -87,122 +68,124 @@ static bool match_string(const std::string & input, llama_grammar * grammar) {
          }
      }

-     for (const auto & stack : stacks_cur) {
-         if (stack.empty()) {
-             // An empty stack means that the grammar has been completed
-             return true;
-         }
+     if (std::any_of(stacks_cur.begin(), stacks_cur.end(), [](const auto & stack) { return stack.empty(); })) {
+         // An empty stack means that the grammar has been completed
+         return true;
      }

      return false;
  }

- // Dumps `{"a": 1}` as `"{\"a\": 1}"`, unlike nlohmann::json::dump which would dump it as `"{\"a\":1}"`.
- static std::string dump(const json & j) {
-     return minja::Value(j).dump(-1, /* to_json= */ true);
- }
-
  static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
      assert_equals(expected.role, actual.role);
      assert_equals(expected.content, actual.content);
+     assert_equals(expected.content_parts.size(), actual.content_parts.size());
+     for (size_t i = 0; i < expected.content_parts.size(); i++) {
+         const auto & expected_part = expected.content_parts[i];
+         const auto & actual_part = actual.content_parts[i];
+         assert_equals(expected_part.type, actual_part.type);
+         assert_equals(expected_part.text, actual_part.text);
+     }
+     assert_equals(expected.reasoning_content, actual.reasoning_content);
      assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
      for (size_t i = 0; i < expected.tool_calls.size(); i++) {
          const auto & expected_tool_call = expected.tool_calls[i];
          const auto & actual_tool_call = actual.tool_calls[i];
          assert_equals(expected_tool_call.name, actual_tool_call.name);
-         assert_equals(dump(json::parse(expected_tool_call.arguments)), dump(json::parse(actual_tool_call.arguments)));
+         assert_equals(json::parse(expected_tool_call.arguments).dump(), json::parse(actual_tool_call.arguments).dump());
          assert_equals(expected_tool_call.id, actual_tool_call.id);
      }
  }

- const auto special_function_tool = json::parse(R"({
-   "type": "function",
-   "function": {
-     "name": "special_function",
-     "description": "I'm special",
-     "parameters": {
-       "type": "object",
-       "properties": {
-         "arg1": {
-           "type": "integer",
-           "description": "The arg."
-         }
-       },
-       "required": ["arg1"]
-     }
-   }
- })");
- const auto python_tool = json::parse(R"({
-   "type": "function",
-   "function": {
-     "name": "python",
-     "description": "an ipython interpreter",
-     "parameters": {
-       "type": "object",
-       "properties": {
-         "code": {
-           "type": "string",
-           "description": "Python code to execute."
-         }
-       },
-       "required": ["code"]
-     }
-   }
- })");
- const auto code_interpreter_tool = json::parse(R"({
-   "type": "function",
-   "function": {
-     "name": "code_interpreter",
-     "description": "an ipython interpreter",
-     "parameters": {
-       "type": "object",
-       "properties": {
-         "code": {
-           "type": "string",
-           "description": "Python code to execute."
-         }
-       },
-       "required": ["code"]
-     }
-   }
- })");
- const json tools = { special_function_tool, python_tool };
- const json llama_3_1_tools = { special_function_tool, code_interpreter_tool };
+ common_chat_tool special_function_tool {
+     /* .name = */ "special_function",
+     /* .description = */ "I'm special",
+     /* .parameters = */ R"({
+         "type": "object",
+         "properties": {
+             "arg1": {
+                 "type": "integer",
+                 "description": "The arg."
+             }
+         },
+         "required": ["arg1"]
+     })",
+ };
+ common_chat_tool python_tool {
+     /* .name = */ "python",
+     /* .description = */ "an ipython interpreter",
+     /* .parameters = */ R"({
+         "type": "object",
+         "properties": {
+             "code": {
+                 "type": "string",
+                 "description": "Python code to execute."
+             }
+         },
+         "required": ["code"]
+     })",
+ };
+ common_chat_tool code_interpreter_tool {
+     /* .name = */ "code_interpreter",
+     /* .description = */ "an ipython interpreter",
+     /* .parameters = */ R"({
+         "type": "object",
+         "properties": {
+             "code": {
+                 "type": "string",
+                 "description": "Python code to execute."
+             }
+         },
+         "required": ["code"]
+     })",
+ };
+ std::vector<common_chat_tool> tools { special_function_tool, python_tool };
+ std::vector<common_chat_tool> llama_3_1_tools { special_function_tool, code_interpreter_tool };

  struct delta_data {
      std::string delta;
      common_chat_params params;
  };

- static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
-                              const json & user_message, const json & delta_message, const json & tools,
-                              const json & tool_choice) {
-     common_chat_inputs inputs;
+ static delta_data init_delta(const struct common_chat_templates * tmpls, const std::vector<std::string> & end_tokens,
+                              const common_chat_msg & user_message,
+                              const common_chat_msg & delta_message,
+                              const std::vector<common_chat_tool> & tools,
+                              const common_chat_tool_choice & tool_choice,
+                              bool think = false) {
+     common_chat_templates_inputs inputs;
      inputs.parallel_tool_calls = true;
-     inputs.messages = json::array();
      inputs.messages.push_back(user_message);
      inputs.tools = tools;
      inputs.tool_choice = tool_choice;
-     auto params_prefix = common_chat_params_init(tmpl, inputs);
+     inputs.extract_reasoning = think;
+     auto params_prefix = common_chat_templates_apply(tmpls, inputs);

      inputs.messages.push_back(delta_message);
      inputs.add_generation_prompt = false;
-     auto params_full = common_chat_params_init(tmpl, inputs);
+     auto params_full = common_chat_templates_apply(tmpls, inputs);

      std::string prefix = params_prefix.prompt;
      std::string full = params_full.prompt;

-     // Check full starts with prefix
-     if (full.find(prefix) != 0) {
-         fprintf(stderr, "Full:\n%s\n\nPrefix:\n%s\n\n", full.c_str(), prefix.c_str());
-         throw std::runtime_error("Full message does not start with prefix");
-     }
-
      if (full == prefix) {
          throw std::runtime_error("Full message is the same as the prefix");
      }

-     auto delta = full.substr(prefix.size());
+     size_t common_prefix_length = 0;
+     for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
+         if (prefix[i] != full[i]) {
+             break;
+         }
+         if (prefix[i] == '<') {
+             // DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
+             // but it removes thinking tags for past messages.
+             // The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
+             continue;
+         }
+         common_prefix_length = i + 1;
+     }
+     auto delta = full.substr(common_prefix_length);

      // Strip end tokens
      for (const auto & end_token : end_tokens) {
@@ -221,28 +204,29 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
    gets the diff, removes any end tokens and parses the result w/ the grammar, checking that
    the parsed message is the same as the test_message
  */
- static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
-                           const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
-                           bool expect_grammar_triggered = true) {
-     common_chat_msg expected_msg = msg_from_json(test_message);
-
-     auto user_message = json{
-         { "role", "user" },
-         { "content", "Hello, world!" }
-     };
-
-     for (const auto & tool_choice : json({ "auto", "required" })) {
-         auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice);
+ static void test_templates(const struct common_chat_templates * tmpls, const std::vector<std::string> & end_tokens,
+                            const common_chat_msg & test_message,
+                            const std::vector<common_chat_tool> & tools = {},
+                            const std::string & expected_delta = "",
+                            bool expect_grammar_triggered = true,
+                            bool test_grammar_if_triggered = true,
+                            bool think = false) {
+     common_chat_msg user_message;
+     user_message.role = "user";
+     user_message.content = "Hello, world!";
+
+     for (const auto & tool_choice : std::vector<common_chat_tool_choice> {COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED}) {
+         auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice, think);
          if (!expected_delta.empty()) {
              assert_equals(expected_delta, data.delta);
          }

          if (expect_grammar_triggered) {
              const auto msg = common_chat_parse(data.delta, data.params.format);
-             assert_msg_equals(expected_msg, msg);
+             assert_msg_equals(test_message, msg);
          }

-         if (!expected_msg.tool_calls.empty()) {
+         if (!test_message.tool_calls.empty()) {
              GGML_ASSERT(!data.params.grammar.empty());
          }
          if (!data.params.grammar.empty()) {
@@ -253,12 +237,35 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
              auto earliest_trigger_pos = std::string::npos;
              auto constrained = data.delta;
              for (const auto & trigger : data.params.grammar_triggers) {
-                 auto pos = constrained.find(trigger.word);
-                 if (pos == std::string::npos) {
-                     continue;
+                 size_t pos = std::string::npos;
+                 std::smatch match;
+                 switch (trigger.type) {
+                     case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
+                     {
+                         const auto & word = trigger.value;
+                         pos = constrained.find(word);
+                         break;
+                     }
+                     case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
+                     {
+                         const auto & pattern = trigger.value;
+                         if (std::regex_search(constrained, match, std::regex(pattern))) {
+                             pos = match.position();
+                         }
+                         break;
+                     }
+                     case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
+                     {
+                         const auto & pattern = trigger.value;
+                         if (std::regex_search(constrained, match, std::regex(pattern)) && match.position() == 0) {
+                             pos = 0;
+                         }
+                         break;
+                     }
+                     default:
+                         throw std::runtime_error("Unknown trigger type");
                  }
-                 if (pos > 0 && trigger.at_start) {
-                     fprintf(stderr, "Trigger %s not at start of message, skipping:\n\n%s\n\n", trigger.word.c_str(), constrained.c_str());
+                 if (pos == std::string::npos) {
                      continue;
                  }
                  if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
@@ -274,161 +281,363 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
                  assert_equals(expect_grammar_triggered, grammar_triggered);
              }

-             if (grammar_triggered && !match_string(constrained, grammar.get())) {
+             if (grammar_triggered && test_grammar_if_triggered && !match_string(constrained, grammar.get())) {
                  throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
-                     "\n\nGrammar: " + data.params.grammar);
+                     "\n\nConstrained: " + constrained +
+                     "\n\nGrammar: " + data.params.grammar);
              }
          }
      }
  }

- static void test_template_output_parsers() {
-     json text_message {
-         { "role", "assistant" },
-         { "content", "Hello, world!\nWhat's up?" },
-     };
-     json tool_calls = json::array({{
-         { "type", "function" },
-         { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
-     }});
-
-     json tool_call_message {
-         { "role", "assistant"},
-         { "content", {}},
-         { "tool_calls", {
-             {
-                 { "type", "function" },
-                 { "function", {
-                     { "name", "special_function" },
-                     { "arguments", "{\"arg1\": 1}" },
-                 }},
-             },
-         }},
-     };
-     json tool_call_message_with_id {
-         { "role", "assistant"},
-         { "content", {}},
-         { "tool_calls", {
-             {
-                 { "type", "function" },
-                 { "function", {
-                     { "name", "special_function" },
-                     { "arguments", "{\"arg1\": 1}" },
-                 }},
-                 {"id", "123456789"},
-             },
-         }},
-         { "role", "assistant" },
-         { "content", {} },
-         { "tool_calls", tool_calls }
-     };
-     json tool_call_plan_message_with_idx {
-         { "role", "assistant"},
-         { "content", {}},
-         { "tool_plan", "I'm not so sure"},
-         { "tool_calls", {
-             {
-                 { "type", "function" },
-                 { "function", {
-                     { "name", "special_function" },
-                     { "arguments", "{\"arg1\": 1}" },
-                 }},
-                 // Index of the tool call in the tool_calls array
-                 {"id", "0"},
-             },
-         }},
-         { "role", "assistant" },
-         { "content", {} },
-         { "tool_calls", tool_calls }
-     };
+ const common_chat_msg message_user {
+     "user",
+     "Hey there!",
+     /* .content_parts = */ {},
+     /* .tool_calls = */ {},
+     /* .reasoning_content = */ "",
+     /* .tool_name = */ "",
+     /* .tool_call_id = */ "",
+ };

-     auto python_tool_call_message = json{
-         { "role", "assistant" },
-         { "content", {} },
-         { "tool_calls", json{ {
-             { "type", "function" },
-             { "function",
-               {
-                   { "name", "python" },
-                   { "arguments",
-                     {
-                         { "code", "print('hey')" },
-                     } },
-               } },
-         } } }
-     };
-     auto code_interpreter_tool_call_message = json{
-         { "role", "assistant" },
-         { "content", {} },
-         { "tool_calls", json{ {
-             { "type", "function" },
-             { "function",
-               {
-                   { "name", "code_interpreter" },
-                   { "arguments",
-                     {
-                         { "code", "print('hey')" },
-                     } },
-               } },
-         } } }
+ const common_chat_msg message_user_parts {
+     "user",
+     /* .content = */ "",
+     /* .content_parts = */ {
+         { "text", "Hey" },
+         { "text", "there" },
+     },
+     /* .tool_calls = */ {},
+     /* .reasoning_content = */ "",
+     /* .tool_name = */ "",
+     /* .tool_call_id = */ "",
+ };
+ const common_chat_msg message_assist {
+     "assistant",
+     "Hello, world!\nWhat's up?",
+     /* .content_parts = */ {},
+     /* .tool_calls = */ {},
+     /* .reasoning_content = */ "",
+     /* .tool_name = */ "",
+     /* .tool_call_id = */ "",
+ };
+ const common_chat_msg message_assist_thoughts_unparsed_think {
+     "assistant",
+     "<think>I'm thinking</think>Hello, world!\nWhat's up?",
+     /* .content_parts = */ {},
+     /* .tool_calls = */ {},
+     /* .reasoning_content = */ "",
+     /* .tool_name = */ "",
+     /* .tool_call_id = */ "",
+ };
+ const common_chat_msg message_assist_thoughts_unparsed_r7b {
+     "assistant",
+     "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?",
+     /* .content_parts = */ {},
+     /* .tool_calls = */ {},
+     /* .reasoning_content = */ "",
+     /* .tool_name = */ "",
+     /* .tool_call_id = */ "",
+ };
+ const common_chat_msg message_assist_thoughts {
+     "assistant",
+     "Hello, world!\nWhat's up?",
+     /* .content_parts = */ {},
+     /* .tool_calls = */ {},
+     /* .reasoning_content = */ "I'm thinking",
+     /* .tool_name = */ "",
+     /* .tool_call_id = */ "",
+ };
+ const std::vector<common_chat_tool_call> tool_calls {
+     { "special_function", "{\"arg1\": 1}", /* .id = */ "" },
+ };
+ const std::vector<common_chat_tool_call> tool_calls_idx {
+     { "special_function", "{\"arg1\": 1}", /* .id = */ "0" },
+ };
+ const std::vector<common_chat_tool_call> tool_calls_id {
+     { "special_function", "{\"arg1\": 1}", /* .id = */ "123456789" },
+ };
+
+ const common_chat_msg message_assist_call {
+     "assistant",
+     "",
+     /* .content_parts = */ {},
+     tool_calls,
+     /* .reasoning_content = */ "",
+     /* .tool_name = */ "",
+     /* .tool_call_id = */ "",
+ };
+ const common_chat_msg message_assist_call_thoughts = {
+     "assistant",
+     /* .content = */ "",
+     /* .content_parts = */ {},
+     tool_calls,
+     /* .reasoning_content = */ "I'm\nthinking",
+     /* .tool_name = */ "",
+     /* .tool_call_id = */ "",
+ };
+ const common_chat_msg message_assist_call_thoughts_unparsed = {
+     "assistant",
+     /* .content = */ "<think>I'm\nthinking</think>",
+     /* .content_parts = */ {},
+     tool_calls,
+     /* .reasoning_content = */ "",
+     /* .tool_name = */ "",
+     /* .tool_call_id = */ "",
+ };
+ const common_chat_msg message_assist_call_id {
+     "assistant",
+     "",
+     /* .content_parts = */ {},
+     tool_calls_id,
+     /* .reasoning_content = */ "",
+     /* .tool_name = */ "",
+     /* .tool_call_id = */ "",
+ };
+ const common_chat_msg message_assist_call_idx {
+     "assistant",
+     "",
+     /* .content_parts = */ {},
+     tool_calls_idx,
+     /* .reasoning_content = */ "",
+     /* .tool_name = */ "",
+     /* .tool_call_id = */ "",
+ };
+ const common_chat_msg message_assist_call_python {
+     "assistant",
+     "",
+     /* .content_parts = */ {},
+     { { "python", "{\"code\": \"print('hey')\"}", /* .id = */ "" } },
+     /* .reasoning_content = */ "",
+     /* .tool_name = */ "",
+     /* .tool_call_id = */ "",
+ };
+ const common_chat_msg message_assist_call_code_interpreter {
+     "assistant",
+     "",
+     /* .content_parts = */ {},
+     { { "code_interpreter", "{\"code\": \"print('hey')\"}", /* .id = */ "" } },
+     /* .reasoning_content = */ "",
+     /* .tool_name = */ "",
+     /* .tool_call_id = */ "",
+ };
+
+ static void test_msgs_oaicompat_json_conversion() {
+     std::vector<common_chat_msg> msgs{
+         message_user,
+         message_user_parts,
+         message_assist_call,
+         message_assist_call_thoughts,
+         message_assist_call_thoughts_unparsed,
+         message_assist_call_id,
+         message_assist_call_idx,
+         message_assist_call_python,
+         message_assist_call_code_interpreter,
      };
+     for (const auto & msg : msgs) {
+         auto oai_json = common_chat_msgs_to_json_oaicompat<json>({msg});
+         auto msgs2 = common_chat_msgs_parse_oaicompat(oai_json);
+         assert_equals((size_t) 1, msgs2.size());
+         auto msg2 = msgs2[0];
+         assert_msg_equals(msg, msg2);
+     }
+     assert_equals(
+         std::string(
+             "[\n"
+             "  {\n"
+             "    \"role\": \"user\",\n"
+             "    \"content\": [\n"
+             "      {\n"
+             "        \"type\": \"text\",\n"
+             "        \"text\": \"Hey\"\n"
+             "      },\n"
+             "      {\n"
+             "        \"type\": \"text\",\n"
+             "        \"text\": \"there\"\n"
+             "      }\n"
+             "    ]\n"
+             "  }\n"
+             "]"
+         ),
+         common_chat_msgs_to_json_oaicompat<json>({message_user_parts}).dump(2));
+
+     assert_equals(
+         std::string(
+             "[\n"
+             "  {\n"
+             "    \"role\": \"assistant\",\n"
+             "    \"content\": null,\n"
+             "    \"tool_calls\": [\n"
+             "      {\n"
+             "        \"type\": \"function\",\n"
+             "        \"function\": {\n"
+             "          \"name\": \"python\",\n"
+             "          \"arguments\": \"{\\\"code\\\": \\\"print('hey')\\\"}\"\n"
+             "        }\n"
+             "      }\n"
+             "    ]\n"
+             "  }\n"
+             "]"
+         ),
+         common_chat_msgs_to_json_oaicompat<json>({message_assist_call_python}).dump(2));
+
+     auto res = common_chat_msgs_parse_oaicompat(json::parse("[{\"role\": \"assistant\", \"tool_calls\": []}]"));
+     assert_equals<size_t>(1, res.size());
+     assert_equals<std::string>(res[0].role, "assistant");
+     assert_equals(true, res[0].content.empty());
+     assert_equals(true, res[0].tool_calls.empty());
+
+     try {
+         common_chat_msgs_parse_oaicompat(json::parse("[{\"role\": \"assistant\"}]"));
+         throw std::runtime_error("Expected exception");
+     } catch (const std::exception & e) {
+         if (std::string(e.what()).find("'content'") == std::string::npos) {
+             throw std::runtime_error("Expected exception about missing 'content'");
+         }
+     }
+ }

-     common_chat_inputs inputs_no_tools;
-     inputs_no_tools.messages = {
-         { { "role", "user" }, { "content", "Hey\nThere" } }
+ static void test_tools_oaicompat_json_conversion() {
+     std::vector<common_chat_tool> tools{
+         special_function_tool,
+         python_tool,
+         code_interpreter_tool,
      };

-     common_chat_inputs inputs_tools = inputs_no_tools;
-     inputs_tools.tools = json::array();
-     inputs_tools.tools.push_back(special_function_tool);
+     for (const auto & tool : tools) {
+         auto oai_json = common_chat_tools_to_json_oaicompat<json>({tool});
+         auto tools2 = common_chat_tools_parse_oaicompat(oai_json);
+         assert_equals((size_t) 1, tools2.size());
+         auto tool2 = tools2[0];
+         assert_equals(tool.name, tool2.name);
+         assert_equals(tool.description, tool2.description);
+         assert_equals(json::parse(tool.parameters).dump(2), json::parse(tool2.parameters).dump(2));
+     }
+
+     assert_equals(
+         std::string(
+             "[\n"
+             "  {\n"
+             "    \"type\": \"function\",\n"
+             "    \"function\": {\n"
+             "      \"name\": \"special_function\",\n"
+             "      \"description\": \"I'm special\",\n"
+             "      \"parameters\": {\n"
+             "        \"type\": \"object\",\n"
+             "        \"properties\": {\n"
+             "          \"arg1\": {\n"
+             "            \"type\": \"integer\",\n"
+             "            \"description\": \"The arg.\"\n"
+             "          }\n"
+             "        },\n"
+             "        \"required\": [\n"
+             "          \"arg1\"\n"
+             "        ]\n"
+             "      }\n"
+             "    }\n"
+             "  }\n"
+             "]"
+         ),
+         common_chat_tools_to_json_oaicompat<json>({special_function_tool}).dump(2));
+ }
+
+ static void test_template_output_parsers() {
+
+     common_chat_templates_inputs inputs_no_tools;
+     inputs_no_tools.messages = {message_user};
+     inputs_no_tools.extract_reasoning = false;
+
+     common_chat_templates_inputs inputs_no_tools_think;
+     inputs_no_tools_think.messages = {message_user};
+     inputs_no_tools_think.extract_reasoning = true;
+
+     common_chat_templates_inputs inputs_tools;
+     inputs_tools.messages = {message_user};
+     inputs_tools.tools = {special_function_tool};
+     inputs_tools.extract_reasoning = false;

-     common_chat_inputs inputs_tools_builtin = inputs_no_tools;
-     inputs_tools_builtin.tools = json::array();
-     inputs_tools_builtin.tools.push_back(python_tool);
+     common_chat_templates_inputs inputs_tools_think;
+     inputs_tools_think.messages = {message_user};
+     inputs_tools_think.tools = {special_function_tool};
+     inputs_tools_think.extract_reasoning = true;
+
+     common_chat_templates_inputs inputs_tools_builtin;
+     inputs_tools_builtin.messages = {message_user};
+     inputs_tools_builtin.tools = {python_tool};
+     inputs_tools_builtin.extract_reasoning = false;

      {
          // Not supported yet
-         const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"), "<s>", "</s>");
-         assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
+         auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja");
+         assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
      }
      {
-         const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "<s>", "</s>");
+         auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja");
          std::vector<std::string> end_tokens{ "<|END_OF_TURN_TOKEN|>" };

-         assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
-         assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format);
-
-         test_template(tmpl, end_tokens, tool_call_plan_message_with_idx, tools,
-             "<|START_THINKING|>I'm not so sure<|END_THINKING|>"
+         assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+         assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+         assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
+
+         assert_msg_equals(message_assist,
+             common_chat_parse(
+                 "Hello, world!\nWhat's up?",
+                 COMMON_CHAT_FORMAT_COMMAND_R7B));
+         assert_msg_equals(message_assist,
+             common_chat_parse(
+                 "Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                 COMMON_CHAT_FORMAT_COMMAND_R7B));
+         assert_msg_equals(message_assist,
+             common_chat_parse(
+                 "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                 COMMON_CHAT_FORMAT_COMMAND_R7B));
+         assert_msg_equals(message_assist_thoughts_unparsed_r7b,
+             common_chat_parse(
+                 "<|START_THINKING|>I'm thinking<|END_THINKING|>"
+                 "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                 COMMON_CHAT_FORMAT_COMMAND_R7B));
+         assert_msg_equals(message_assist_thoughts_unparsed_r7b,
+             common_chat_parse(
+                 "<|START_THINKING|>I'm thinking<|END_THINKING|>"
+                 "Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                 COMMON_CHAT_FORMAT_COMMAND_R7B));
+
+         assert_msg_equals(message_assist_thoughts,
+             common_chat_parse(
+                 "<|START_THINKING|>I'm thinking<|END_THINKING|>"
+                 "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                 COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING));
+
+         test_templates(tmpls.get(), end_tokens, message_assist_call_idx, tools,
+             "<|START_THINKING|><|END_THINKING|>"
              "<|START_ACTION|>[\n"
              "    {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
              "]<|END_ACTION|>");
-         test_template(tmpl, end_tokens, text_message, tools,
+         test_templates(tmpls.get(), end_tokens, message_assist, tools,
              "<|START_RESPONSE|>Hello, world!\n"
              "What's up?<|END_RESPONSE|>",
              /* expect_grammar_triggered= */ false);
      }
      {
-         const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
+         auto tmpls = read_templates("models/templates/google-gemma-2-2b-it.jinja");
          std::vector<std::string> end_tokens{ "<end_of_turn>" };

-         assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
-         assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
+         assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+         assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
          assert_equals(COMMON_CHAT_FORMAT_GENERIC,
-             common_chat_params_init(
-                 common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"),
-                     "<s>", "</s>"),
+             common_chat_templates_apply(
+                 read_templates("models/templates/microsoft-Phi-3.5-mini-instruct.jinja").get(),
                  inputs_tools)
                  .format);

          // Generic tool calls doesn't generate / parse content-only messages symmetrically.

-         assert_msg_equals(msg_from_json(text_message),
+         assert_msg_equals(message_assist,
              common_chat_parse("{\n"
                  "  \"response\": \"Hello, world!\\nWhat's up?\"\n"
                  "}",
-                 common_chat_params_init(tmpl, inputs_tools).format));
-         test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
+                 common_chat_templates_apply(tmpls.get(), inputs_tools).format));
+         test_templates(tmpls.get(), end_tokens, message_assist_call_id, tools,
              "{\n"
              "  \"tool_calls\": [\n"
              "    {\n"
@@ -442,166 +651,323 @@ static void test_template_output_parsers() {
              "}");
      }
      {
-         const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>",
-             "</s>");
+         auto tmpls = read_templates("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja");
          std::vector<std::string> end_tokens{ "</s>" };

-         assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);
+         assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);

-         test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-         test_template(
-             tmpl, end_tokens, tool_call_message_with_id, tools,
+         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+         test_templates(
+             tmpls.get(), end_tokens, message_assist_call_id, tools,
              "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
      }
      {
-         const common_chat_template tmpl(
-             read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
+         auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja");
          std::vector<std::string> end_tokens{ "<|im_end|>" };

-         assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(tmpl, inputs_tools).format);
+         assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
          assert_equals(
              COMMON_CHAT_FORMAT_HERMES_2_PRO,
-             common_chat_params_init(
-                 common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"),
-                     "<s>", "</s>"),
+             common_chat_templates_apply(
+                 read_templates("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja").get(),
                  inputs_tools)
                  .format);
          assert_equals(
              COMMON_CHAT_FORMAT_HERMES_2_PRO,
-             common_chat_params_init(
-                 common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"),
+             common_chat_templates_apply(
+                 read_templates("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja").get(),
                  inputs_tools)
                  .format);

-         test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-         test_template(tmpl, end_tokens, tool_call_message, tools,
+         // Test parsing
+         assert_msg_equals(message_assist_call, common_chat_parse(
+             "<tool_call>\n"
+             "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+             "</tool_call>",
+             COMMON_CHAT_FORMAT_HERMES_2_PRO));
+         assert_msg_equals(message_assist_call, common_chat_parse(
+             "<function=special_function>{\"arg1\": 1}</function>",
+             COMMON_CHAT_FORMAT_HERMES_2_PRO));
+         assert_msg_equals(message_assist_call, common_chat_parse(
+             "<function name=\"special_function\">\n"
+             "{\"arg1\": 1}\n"
+             "</function>",
+             COMMON_CHAT_FORMAT_HERMES_2_PRO));
+         assert_msg_equals(message_assist_call, common_chat_parse(
+             "<tool>\n"
+             "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+             "</tool>",
+             COMMON_CHAT_FORMAT_HERMES_2_PRO));
+         assert_msg_equals(message_assist_call, common_chat_parse(
+             "<tools>\n"
+             "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+             "</tools>",
+             COMMON_CHAT_FORMAT_HERMES_2_PRO));
+         assert_msg_equals(message_assist_call, common_chat_parse(
+             "<response>\n"
+             "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+             "</response>",
+             COMMON_CHAT_FORMAT_HERMES_2_PRO));
+         assert_msg_equals(message_assist_call, common_chat_parse(
+             "```xml\n"
+             "<response>\n"
+             "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+             "</response>\n"
+             "```",
+             COMMON_CHAT_FORMAT_HERMES_2_PRO));
+         assert_msg_equals(message_assist_call, common_chat_parse(
+             "```xml\n"
+             "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+             "```",
+             COMMON_CHAT_FORMAT_HERMES_2_PRO));
+         assert_msg_equals(message_assist_call, common_chat_parse(
+             "```\n"
+             "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+             "```",
+             COMMON_CHAT_FORMAT_HERMES_2_PRO));
+         assert_msg_equals(message_assist_call, common_chat_parse(
+             "```\n"
+             "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+             "```",
+             COMMON_CHAT_FORMAT_HERMES_2_PRO));
+         assert_msg_equals(message_assist_call, common_chat_parse(
+             "```json\n"
+             "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+             "```",
+             COMMON_CHAT_FORMAT_HERMES_2_PRO));
+         assert_msg_equals(message_assist_call, common_chat_parse(
+             "```json\n"
+             "\n"
+             "  <function_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}} \n"
+             "  </function_call> \n"
+             "``` ",
+             COMMON_CHAT_FORMAT_HERMES_2_PRO));
+         assert_msg_equals(message_assist_call, common_chat_parse(
+             "<json>\n"
+             "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+             "</json>",
+             COMMON_CHAT_FORMAT_HERMES_2_PRO));
+         assert_msg_equals(message_assist_call, common_chat_parse(
+             "<xml>\n"
+             "  {\n"
+             "    \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}\n"
+             "  }\n"
+             "</xml>",
+             COMMON_CHAT_FORMAT_HERMES_2_PRO));
+         assert_msg_equals(message_assist_call, common_chat_parse(
+             "<JSON>\n"
+             "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+             "</JSON>",
+             COMMON_CHAT_FORMAT_HERMES_2_PRO));
+         assert_msg_equals(message_assist_call, common_chat_parse(
+             "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
+             COMMON_CHAT_FORMAT_HERMES_2_PRO));
+         assert_msg_equals(message_assist_call, common_chat_parse(
+             "{\n  \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
+             COMMON_CHAT_FORMAT_HERMES_2_PRO));
+
+         assert_msg_equals(message_assist_thoughts_unparsed_think,
+             common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
+                 COMMON_CHAT_FORMAT_HERMES_2_PRO));
+         assert_msg_equals(message_assist_thoughts_unparsed_think,
+             common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
+                 COMMON_CHAT_FORMAT_HERMES_2_PRO));
+         assert_msg_equals(message_assist_thoughts,
+             common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
+                 COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING));
+         assert_msg_equals(message_assist_thoughts,
+             common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
+                 COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING));
+
+         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
              "<tool_call>\n"
             "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
              "</tool_call>");
-         test_template(tmpl, end_tokens, python_tool_call_message, tools,
+         test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools,
              "<tool_call>\n"
              "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
              "</tool_call>");
      }
      {
-         const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>",
-             "</s>");
+         auto tmpls = read_templates("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja");
          std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

-         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
+         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
          assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
-             common_chat_params_init(tmpl, inputs_tools_builtin).format);
+             common_chat_templates_apply(tmpls.get(), inputs_tools_builtin).format);
          assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
-             common_chat_params_init(
-                 common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"),
-                     "<s>", "</s>"),
+             common_chat_templates_apply(
+                 read_templates("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja").get(),
                  inputs_tools_builtin)
                  .format);

-         // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* expect_grammar_triggered= */ false);
-         test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
+         // test_templates(tmpls.get(), end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false);
+         test_templates(tmpls.get(), end_tokens, message_assist_call_code_interpreter, llama_3_1_tools,
              "<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
-         test_template(tmpl, end_tokens, python_tool_call_message, tools,
+         test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools,
              "<|python_tag|>python.call(code=\"print('hey')\")");
-         test_template(tmpl, end_tokens, tool_call_message, tools,
+         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
              "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
      }
      {
-         const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>",
-             "</s>");
+         auto tmpls = read_templates("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja");
          std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

-         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
+         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);

-         test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-         test_template(tmpl, end_tokens, tool_call_message, tools,
+         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
              "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
      }
      {
-         const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>",
-             "</s>");
+         auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.1.jinja");
          std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

          assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
-             common_chat_params_init(tmpl, inputs_tools).format);
+             common_chat_templates_apply(tmpls.get(), inputs_tools).format);

-         test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-         test_template(tmpl, end_tokens, tool_call_message, tools,
+         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
              "<function=special_function>{\"arg1\": 1}</function>");
      }
      {
-         const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>",
-             "</s>");
+         auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.2.jinja");
          std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

-         assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format);
-         assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format);
+         assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+         assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);

-         test_template(tmpl, end_tokens, text_message, {},
+         test_templates(tmpls.get(), end_tokens, message_assist, {},
              "all\n"
              "Hello, world!\n"
              "What's up?",
              /* expect_grammar_triggered= */ false);
-         test_template(tmpl, end_tokens, tool_call_message, tools,
+         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
              "special_function\n"
              "{\"arg1\": 1}");
      }
      {
-         const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>",
-             "</s>");
+         auto tmpls = read_templates("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja");
          std::vector<std::string> end_tokens{ "<|eot_id|>" };

-         assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);
+         assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);

-         test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-         test_template(tmpl, end_tokens, tool_call_message, tools,
+         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
              " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
      }
      {
-         const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"),
-             "<s>", "</s>");
+         // Original DeepSeek R1 template. Leaves <|tool▁calls▁begin|> and others unclosed. Our logic fixes the prompt.
+         auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja");
          std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };

-         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
+         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
+
+         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+         test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+         assert_msg_equals(message_assist_thoughts_unparsed_think,
+             common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
+                 COMMON_CHAT_FORMAT_DEEPSEEK_R1));
+         assert_msg_equals(message_assist_thoughts,
+             common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
+                 COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
+         assert_msg_equals(message_assist_thoughts,
+             // Latest template update (as of 20250209) adds a trailing <think>\n if add_generation_prompt is true.
+             common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
+                 COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
+         // test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
+         //     "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
+         //     "```json\n"
+         //     "{\"arg1\": 1}\n"
+         //     // Look what's not here: <|tool▁calls▁end|> (also missing the <|end▁of▁sentence|>, but that is removed lazily by the test's delta logic)
+         //     "```<|tool▁call▁end|>",
+         //     /* expect_grammar_triggered= */ true,
+         //     /* test_grammar_if_triggered= */ false);
+     }
+     {
+         // Replacement DeepSeek R1 template. Makes the Distill Qwen 7B/32B models happy to call tools and all.
+         auto tmpls = read_templates("models/templates/llama-cpp-deepseek-r1.jinja");
+         std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };

-         test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-         test_template(tmpl, end_tokens, tool_call_message, tools,
-             "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
-             "```json\n"
-             "{\"arg1\": 1}\n"
-             "```<|tool▁call▁end|>");
+         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
+
+         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+         test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+         assert_msg_equals(message_assist_thoughts_unparsed_think,
+             common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
+                 COMMON_CHAT_FORMAT_DEEPSEEK_R1));
+         assert_msg_equals(message_assist_thoughts,
+             common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
+                 COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
+
+         assert_msg_equals(message_assist_call_thoughts_unparsed,
+             common_chat_parse(
+                 "<think>I'm\nthinking</think>\n\n"
+                 "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
+                 "```json\n"
+                 "{\"arg1\": 1}\n"
+                 "```<|tool▁call▁end|><|tool▁calls▁end|>",
+                 COMMON_CHAT_FORMAT_DEEPSEEK_R1));
+         assert_msg_equals(message_assist_call_thoughts,
+             common_chat_parse(
+                 "<think>I'm\nthinking</think>\n\n"
+                 "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
+                 "```json\n"
+                 "{\"arg1\": 1}\n"
+                 "```<|tool▁call▁end|><|tool▁calls▁end|>",
+                 COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
+         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
+             "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
+             "```json\n"
+             "{\"arg1\": 1}\n"
+             "```<|tool▁call▁end|><|tool▁calls▁end|>");
      }
  }

  int main(int argc, char ** argv) {
+     // try {
  #ifndef _WIN32
-     if (argc > 1) {
-         common_chat_inputs inputs;
-         inputs.messages = {
-             { { "role", "user" }, { "content", "Hey" } }
-         };
-         inputs.tools = json::array({ special_function_tool });
-
-         std::cout << "| Template | Format |\n";
-         std::cout << "|----------|--------|\n";
-
-         for (int i = 1; i < argc; i++) {
-             std::string path = argv[i];
-             if (path.rfind(".jinja") != path.size() - 6) {
-                 std::cerr << "Skipping non-jinja file: " << path << std::endl;
-                 continue;
+         if (argc > 1) {
+             common_chat_templates_inputs inputs;
+             common_chat_msg msg;
+             msg.role = "user";
+             msg.content = "Hey";
+             inputs.messages = {msg};
+             inputs.tools = { special_function_tool };
+
+             std::cout << "| Template | Format |\n";
+             std::cout << "|----------|--------|\n";
+
+             for (int i = 1; i < argc; i++) {
+                 try {
+                     std::string path = argv[i];
+                     if (path.rfind(".jinja") != path.size() - 6) {
+                         std::cerr << "Skipping non-jinja file: " << path << '\n';
+                         continue;
+                     }
+                     auto tmpls = read_templates(path);
+                     auto parts = string_split(path, "/");
+                     auto name = parts[parts.size() - 1];
+                     auto format = common_chat_format_name(common_chat_templates_apply(tmpls.get(), inputs).format);
+                     std::cout << "| " << name << " | " << format << " |\n";
+                 } catch (const std::exception & e) {
+                     std::cerr << "Failed to process " << argv[i] << ": " << e.what() << '\n';
+                 }
              }
-             common_chat_template tmpl(read_file(path), "", "");
-             auto parts = string_split(path, "/");
-             auto name = parts[parts.size() - 1];
-             std::cout << "| " << name << " | " << common_chat_format_name(common_chat_params_init(tmpl, inputs).format)
-                       << " |\n";
-         }
-     } else
+         } else
  #endif
-     {
-         test_template_output_parsers();
-         std::cout << "\n[chat] All tests passed!" << std::endl;
-     }
-     return 0;
+         {
+             test_msgs_oaicompat_json_conversion();
+             test_tools_oaicompat_json_conversion();
+             test_template_output_parsers();
+             std::cout << "\n[chat] All tests passed!" << '\n';
+         }
+         return 0;
+     // } catch (const std::exception & e) {
+     //     std::cerr << "Error: " << e.what() << '\n';
+     //     return 1;
+     // }
  }
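
Note: the test changes above track the chat-template refactor bundled in this release, which replaces common/chat.hpp and chat-template.hpp with the opaque common_chat_templates handle declared in chat.h. The following is a minimal consumer-side sketch assembled only from calls visible in the diff (common_chat_templates_init, common_chat_templates_apply, common_chat_parse, common_chat_templates_ptr); it is an illustrative assumption, not the package's documented API, and the template path is just one of the files the tests load.

    // Sketch only: mirrors the call sequence exercised by test-chat.cpp above.
    #include "chat.h"

    #include <fstream>
    #include <iostream>
    #include <sstream>
    #include <string>

    int main() {
        // Read a Jinja chat template from disk, as read_templates() does in the tests.
        std::ifstream fs("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja", std::ios_base::binary);
        std::stringstream ss;
        ss << fs.rdbuf();

        // model == nullptr: the template source is supplied explicitly.
        common_chat_templates_ptr tmpls(common_chat_templates_init(/* model= */ nullptr, ss.str()));

        common_chat_msg user;
        user.role    = "user";
        user.content = "Hello, world!";

        common_chat_templates_inputs inputs;
        inputs.messages = { user };

        // Render the prompt and learn which output format the template implies.
        auto params = common_chat_templates_apply(tmpls.get(), inputs);
        std::cout << params.prompt << "\n";

        // Later, parse the model's raw output back into a structured message
        // using the format detected above.
        common_chat_msg reply = common_chat_parse("Hi there!", params.format);
        std::cout << reply.content << "\n";
        return 0;
    }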