@fugood/llama.node 0.3.12 → 0.3.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159) hide show
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/lib/binding.ts +2 -1
  18. package/package.json +1 -1
  19. package/src/LlamaCompletionWorker.cpp +14 -0
  20. package/src/LlamaContext.cpp +110 -79
  21. package/src/LlamaContext.h +1 -1
  22. package/src/common.hpp +1 -2
  23. package/src/llama.cpp/.github/workflows/build.yml +95 -13
  24. package/src/llama.cpp/.github/workflows/docker.yml +2 -0
  25. package/src/llama.cpp/.github/workflows/labeler.yml +1 -1
  26. package/src/llama.cpp/.github/workflows/server.yml +2 -0
  27. package/src/llama.cpp/common/CMakeLists.txt +23 -6
  28. package/src/llama.cpp/common/arg.cpp +292 -14
  29. package/src/llama.cpp/common/chat.cpp +1128 -315
  30. package/src/llama.cpp/common/chat.h +135 -0
  31. package/src/llama.cpp/common/common.cpp +27 -171
  32. package/src/llama.cpp/common/common.h +41 -73
  33. package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
  34. package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
  35. package/src/llama.cpp/common/llguidance.cpp +3 -3
  36. package/src/llama.cpp/common/log.cpp +1 -0
  37. package/src/llama.cpp/common/log.h +2 -1
  38. package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +21 -7
  39. package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +61 -14
  40. package/src/llama.cpp/common/ngram-cache.cpp +1 -0
  41. package/src/llama.cpp/common/sampling.cpp +93 -49
  42. package/src/llama.cpp/common/speculative.cpp +6 -5
  43. package/src/llama.cpp/common/speculative.h +1 -1
  44. package/src/llama.cpp/docs/build.md +47 -9
  45. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +3 -1
  46. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -0
  47. package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
  48. package/src/llama.cpp/examples/imatrix/imatrix.cpp +4 -4
  49. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +6 -5
  50. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/CMakeLists.txt +1 -1
  51. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +1 -1
  52. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  53. package/src/llama.cpp/examples/llava/clip.cpp +373 -107
  54. package/src/llama.cpp/examples/llava/clip.h +19 -3
  55. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
  56. package/src/llama.cpp/examples/llava/llava.cpp +4 -2
  57. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
  58. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -0
  59. package/src/llama.cpp/examples/main/main.cpp +73 -28
  60. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -0
  61. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -0
  62. package/src/llama.cpp/examples/perplexity/perplexity.cpp +1 -0
  63. package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
  64. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
  65. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
  66. package/src/llama.cpp/examples/run/run.cpp +115 -79
  67. package/src/llama.cpp/examples/server/CMakeLists.txt +1 -1
  68. package/src/llama.cpp/examples/server/httplib.h +381 -292
  69. package/src/llama.cpp/examples/server/server.cpp +134 -128
  70. package/src/llama.cpp/examples/server/utils.hpp +95 -106
  71. package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
  72. package/src/llama.cpp/examples/tts/tts.cpp +251 -142
  73. package/src/llama.cpp/ggml/CMakeLists.txt +13 -1
  74. package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
  75. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
  76. package/src/llama.cpp/ggml/include/ggml-cpu.h +4 -1
  77. package/src/llama.cpp/ggml/include/ggml-metal.h +1 -1
  78. package/src/llama.cpp/ggml/include/ggml-vulkan.h +0 -2
  79. package/src/llama.cpp/ggml/include/ggml.h +6 -2
  80. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
  81. package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
  82. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
  83. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
  84. package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
  85. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -2
  86. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
  87. package/src/llama.cpp/ggml/src/ggml-common.h +0 -2
  88. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +132 -17
  89. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
  90. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
  91. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
  92. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +156 -11
  93. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +2235 -641
  94. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1572 -198
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +24 -5
  96. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
  97. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
  98. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
  99. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +9 -8
  101. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +16 -3
  102. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
  103. package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
  104. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
  105. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +235 -0
  106. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -2
  107. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
  108. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +246 -120
  109. package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
  110. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
  111. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -0
  112. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  113. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
  114. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +51 -10
  115. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
  116. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
  117. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
  118. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
  119. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
  120. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +136 -4
  121. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +308 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
  123. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +174 -728
  124. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -77
  125. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -0
  126. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  127. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  128. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +949 -602
  129. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +37 -3
  130. package/src/llama.cpp/ggml/src/ggml.c +9 -4
  131. package/src/llama.cpp/include/llama.h +32 -14
  132. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
  133. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
  134. package/src/llama.cpp/requirements/requirements-all.txt +1 -0
  135. package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
  136. package/src/llama.cpp/requirements.txt +1 -0
  137. package/src/llama.cpp/src/llama-arch.cpp +21 -0
  138. package/src/llama.cpp/src/llama-arch.h +1 -0
  139. package/src/llama.cpp/src/llama-chat.cpp +1 -0
  140. package/src/llama.cpp/src/llama-grammar.cpp +183 -183
  141. package/src/llama.cpp/src/llama-grammar.h +13 -4
  142. package/src/llama.cpp/src/llama-impl.h +6 -6
  143. package/src/llama.cpp/src/llama-kv-cache.h +2 -1
  144. package/src/llama.cpp/src/llama-mmap.cpp +11 -1
  145. package/src/llama.cpp/src/llama-mmap.h +1 -0
  146. package/src/llama.cpp/src/llama-model.cpp +70 -6
  147. package/src/llama.cpp/src/llama-sampling.cpp +174 -67
  148. package/src/llama.cpp/src/llama-vocab.cpp +12 -0
  149. package/src/llama.cpp/src/llama.cpp +154 -5
  150. package/src/llama.cpp/src/unicode.cpp +9 -2
  151. package/src/llama.cpp/tests/test-backend-ops.cpp +171 -115
  152. package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
  153. package/src/llama.cpp/tests/test-chat.cpp +691 -325
  154. package/src/llama.cpp/tests/test-gguf.cpp +4 -4
  155. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
  156. package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
  157. package/src/llama.cpp/tests/test-sampling.cpp +15 -0
  158. package/src/llama.cpp/Sources/llama/llama.h +0 -4
  159. package/src/llama.cpp/common/chat.hpp +0 -52
@@ -1,8 +1,436 @@
1
- #include "chat.hpp"
2
- #include "chat-template.hpp"
1
+ #include "chat.h"
3
2
  #include "json-schema-to-grammar.h"
4
3
  #include "log.h"
5
- #include "minja.hpp"
4
+ #include "minja/chat-template.hpp"
5
+ #include "minja/minja.hpp"
6
+
7
+ #include <optional>
8
+
9
+ typedef minja::chat_template common_chat_template;
10
+
11
+ struct common_chat_templates {
12
+ bool has_explicit_template; // Model had builtin template or template overridde was specified.
13
+ std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
14
+ std::unique_ptr<common_chat_template> template_tool_use;
15
+ };
16
+
17
+ struct templates_params {
18
+ json messages;
19
+ json tools;
20
+ common_chat_tool_choice tool_choice;
21
+ json json_schema;
22
+ bool parallel_tool_calls;
23
+ bool stream;
24
+ std::string grammar;
25
+ bool add_generation_prompt = true;
26
+ bool extract_reasoning = true;
27
+ };
28
+
29
+ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
30
+ if (tool_choice == "auto") {
31
+ return COMMON_CHAT_TOOL_CHOICE_AUTO;
32
+ }
33
+ if (tool_choice == "none") {
34
+ return COMMON_CHAT_TOOL_CHOICE_NONE;
35
+ }
36
+ if (tool_choice == "required") {
37
+ return COMMON_CHAT_TOOL_CHOICE_REQUIRED;
38
+ }
39
+ throw std::runtime_error("Invalid tool_choice: " + tool_choice);
40
+ }
41
+
42
+ template <>
43
+ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
44
+ std::vector<common_chat_msg> msgs;
45
+
46
+ try {
47
+
48
+ if (!messages.is_array()) {
49
+ throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump());
50
+ }
51
+
52
+ for (const auto & message : messages) {
53
+ if (!message.is_object()) {
54
+ throw std::runtime_error("Expected 'message' to be an object, got " + message.dump());
55
+ }
56
+
57
+ common_chat_msg msg;
58
+ if (!message.contains("role")) {
59
+ throw std::runtime_error("Missing 'role' in message: " + message.dump());
60
+ }
61
+ msg.role = message.at("role");
62
+
63
+ auto has_content = message.contains("content");
64
+ auto has_tool_calls = message.contains("tool_calls");
65
+ if (has_content) {
66
+ const auto & content = message.at("content");
67
+ if (content.is_string()) {
68
+ msg.content = content;
69
+ } else if (content.is_array()) {
70
+ for (const auto & part : content) {
71
+ if (!part.contains("type")) {
72
+ throw std::runtime_error("Missing content part type: " + part.dump());
73
+ }
74
+ const auto & type = part.at("type");
75
+ if (type != "text") {
76
+ throw std::runtime_error("Unsupported content part type: " + type.dump());
77
+ }
78
+ common_chat_msg_content_part msg_part;
79
+ msg_part.type = type;
80
+ msg_part.text = part.at("text");
81
+ msg.content_parts.push_back(msg_part);
82
+ }
83
+ } else if (!content.is_null()) {
84
+ throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
85
+ }
86
+ }
87
+ if (has_tool_calls) {
88
+ for (const auto & tool_call : message.at("tool_calls")) {
89
+ common_chat_tool_call tc;
90
+ if (!tool_call.contains("type")) {
91
+ throw std::runtime_error("Missing tool call type: " + tool_call.dump());
92
+ }
93
+ const auto & type = tool_call.at("type");
94
+ if (type != "function") {
95
+ throw std::runtime_error("Unsupported tool call type: " + tool_call.dump());
96
+ }
97
+ if (!tool_call.contains("function")) {
98
+ throw std::runtime_error("Missing tool call function: " + tool_call.dump());
99
+ }
100
+ const auto & fc = tool_call.at("function");
101
+ if (!fc.contains("name")) {
102
+ throw std::runtime_error("Missing tool call name: " + tool_call.dump());
103
+ }
104
+ tc.name = fc.at("name");
105
+ tc.arguments = fc.at("arguments");
106
+ if (tool_call.contains("id")) {
107
+ tc.id = tool_call.at("id");
108
+ }
109
+ msg.tool_calls.push_back(tc);
110
+ }
111
+ }
112
+ if (!has_content && !has_tool_calls) {
113
+ throw std::runtime_error("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)");
114
+ }
115
+ if (message.contains("reasoning_content")) {
116
+ msg.reasoning_content = message.at("reasoning_content");
117
+ }
118
+ if (message.contains("name")) {
119
+ msg.tool_name = message.at("name");
120
+ }
121
+ if (message.contains("tool_call_id")) {
122
+ msg.tool_call_id = message.at("tool_call_id");
123
+ }
124
+
125
+ msgs.push_back(msg);
126
+ }
127
+ } catch (const std::exception & e) {
128
+ throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2));
129
+ }
130
+
131
+ return msgs;
132
+ }
133
+
134
+ template <>
135
+ json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
136
+ json messages = json::array();
137
+ for (const auto & msg : msgs) {
138
+ if (!msg.content.empty() && !msg.content_parts.empty()) {
139
+ throw std::runtime_error("Cannot specify both content and content_parts");
140
+ }
141
+ json jmsg {
142
+ {"role", msg.role},
143
+ };
144
+ if (!msg.content.empty()) {
145
+ jmsg["content"] = msg.content;
146
+ } else if (!msg.content_parts.empty()) {
147
+ if (concat_typed_text) {
148
+ std::string text;
149
+ for (const auto & part : msg.content_parts) {
150
+ if (part.type != "text") {
151
+ LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
152
+ continue;
153
+ }
154
+ if (!text.empty()) {
155
+ text += '\n';
156
+ }
157
+ text += part.text;
158
+ }
159
+ jmsg["content"] = text;
160
+ } else {
161
+ auto & parts = jmsg["content"] = json::array();
162
+ for (const auto & part : msg.content_parts) {
163
+ parts.push_back({
164
+ {"type", part.type},
165
+ {"text", part.text},
166
+ });
167
+ }
168
+ }
169
+ } else {
170
+ jmsg["content"] = json(); // null
171
+ }
172
+ if (!msg.reasoning_content.empty()) {
173
+ jmsg["reasoning_content"] = msg.reasoning_content;
174
+ }
175
+ if (!msg.tool_name.empty()) {
176
+ jmsg["name"] = msg.tool_name;
177
+ }
178
+ if (!msg.tool_call_id.empty()) {
179
+ jmsg["tool_call_id"] = msg.tool_call_id;
180
+ }
181
+ if (!msg.tool_calls.empty()) {
182
+ auto & tool_calls = jmsg["tool_calls"] = json::array();
183
+ for (const auto & tool_call : msg.tool_calls) {
184
+ json tc {
185
+ {"type", "function"},
186
+ {"function", {
187
+ {"name", tool_call.name},
188
+ {"arguments", tool_call.arguments},
189
+ }},
190
+ };
191
+ if (!tool_call.id.empty()) {
192
+ tc["id"] = tool_call.id;
193
+ }
194
+ tool_calls.push_back(tc);
195
+ }
196
+ }
197
+ messages.push_back(jmsg);
198
+ }
199
+ return messages;
200
+ }
201
+
202
+ template <>
203
+ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const std::string & messages) {
204
+ return common_chat_msgs_parse_oaicompat(json::parse(messages));
205
+ }
206
+
207
+ template <>
208
+ std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
209
+ std::vector<common_chat_tool> result;
210
+
211
+ try {
212
+ if (!tools.is_null()) {
213
+ if (!tools.is_array()) {
214
+ throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump());
215
+ }
216
+ for (const auto & tool : tools) {
217
+ if (!tool.contains("type")) {
218
+ throw std::runtime_error("Missing tool type: " + tool.dump());
219
+ }
220
+ const auto & type = tool.at("type");
221
+ if (!type.is_string() || type != "function") {
222
+ throw std::runtime_error("Unsupported tool type: " + tool.dump());
223
+ }
224
+ if (!tool.contains("function")) {
225
+ throw std::runtime_error("Missing tool function: " + tool.dump());
226
+ }
227
+
228
+ const auto & function = tool.at("function");
229
+ result.push_back({
230
+ /* .name = */ function.at("name"),
231
+ /* .description = */ function.at("description"),
232
+ /* .parameters = */ function.at("parameters").dump(),
233
+ });
234
+ }
235
+ }
236
+ } catch (const std::exception & e) {
237
+ throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2));
238
+ }
239
+
240
+ return result;
241
+ }
242
+
243
+ template <>
244
+ std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const std::string & tools) {
245
+ return common_chat_tools_parse_oaicompat(json::parse(tools));
246
+ }
247
+
248
+ template <>
249
+ json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
250
+ if (tools.empty()) {
251
+ return json();
252
+ }
253
+
254
+ auto result = json::array();
255
+ for (const auto & tool : tools) {
256
+ result.push_back({
257
+ {"type", "function"},
258
+ {"function", {
259
+ {"name", tool.name},
260
+ {"description", tool.description},
261
+ {"parameters", json::parse(tool.parameters)},
262
+ }},
263
+ });
264
+ }
265
+ return result;
266
+ }
267
+
268
+ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
269
+ if (use_jinja) {
270
+ try {
271
+ common_chat_msg msg;
272
+ msg.role = "user";
273
+ msg.content = "test";
274
+
275
+ auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl);
276
+
277
+ common_chat_templates_inputs inputs;
278
+ inputs.messages = {msg};
279
+
280
+ common_chat_templates_apply(tmpls.get(), inputs);
281
+ return true;
282
+ } catch (const std::exception & e) {
283
+ LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
284
+ return false;
285
+ }
286
+ }
287
+ llama_chat_message chat[] = {{"user", "test"}};
288
+ const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
289
+ return res >= 0;
290
+ }
291
+
292
+ std::string common_chat_format_single(
293
+ const struct common_chat_templates * tmpls,
294
+ const std::vector<common_chat_msg> & past_msg,
295
+ const common_chat_msg & new_msg,
296
+ bool add_ass,
297
+ bool use_jinja) {
298
+
299
+ common_chat_templates_inputs inputs;
300
+ inputs.use_jinja = use_jinja;
301
+
302
+ std::string fmt_past_msg;
303
+ if (!past_msg.empty()) {
304
+ inputs.messages = past_msg;
305
+ inputs.add_generation_prompt = false;
306
+ fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt;
307
+ }
308
+ std::ostringstream ss;
309
+ // if the past_msg ends with a newline, we must preserve it in the formatted version
310
+ if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
311
+ ss << "\n";
312
+ };
313
+ // format chat with new_msg
314
+ inputs.messages.push_back(new_msg);
315
+ inputs.add_generation_prompt = add_ass;
316
+ auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt;
317
+ // get the diff part
318
+ ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
319
+ return ss.str();
320
+ }
321
+
322
+ std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
323
+ common_chat_templates_inputs inputs;
324
+ inputs.use_jinja = use_jinja;
325
+ auto add_simple_msg = [&](auto role, auto content) {
326
+ common_chat_msg msg;
327
+ msg.role = role;
328
+ msg.content = content;
329
+ inputs.messages.push_back(msg);
330
+ };
331
+ add_simple_msg("system", "You are a helpful assistant");
332
+ add_simple_msg("user", "Hello");
333
+ add_simple_msg("assistant", "Hi there");
334
+ add_simple_msg("user", "How are you?");
335
+ return common_chat_templates_apply(tmpls, inputs).prompt;
336
+ }
337
+
338
// Jinja source of the built-in ChatML template: each message is wrapped in
// <|im_start|>role ... <|im_end|>, and an assistant header is appended when
// add_generation_prompt is set. Used as the fallback template source in this file.
#define CHATML_TEMPLATE_SRC \
    "{%- for message in messages -%}\n" \
    " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
    "{%- endfor -%}\n" \
    "{%- if add_generation_prompt -%}\n" \
    " {{- '<|im_start|>assistant\n' -}}\n" \
    "{%- endif -%}"
345
+
346
// Destroys a templates bundle and the templates it owns. Passing nullptr is a no-op (plain delete).
void common_chat_templates_free(struct common_chat_templates * tmpls) {
    delete tmpls;
}
349
+
350
+ bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls) {
351
+ return tmpls->has_explicit_template;
352
+ }
353
+
354
+ const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) {
355
+ if (variant != nullptr) {
356
+ if (strcmp(variant, "tool_use") == 0) {
357
+ if (tmpls->template_tool_use) {
358
+ return tmpls->template_tool_use->source().c_str();
359
+ }
360
+ return nullptr;
361
+ } else {
362
+ LOG_DBG("%s: unknown template variant: %s\n", __func__, variant);
363
+ }
364
+ }
365
+ return tmpls->template_default->source().c_str();
366
+ }
367
+
368
+ common_chat_templates_ptr common_chat_templates_init(
369
+ const struct llama_model * model,
370
+ const std::string & chat_template_override,
371
+ const std::string & bos_token_override,
372
+ const std::string & eos_token_override)
373
+ {
374
+ std::string default_template_src;
375
+ std::string template_tool_use_src;
376
+
377
+ bool has_explicit_template = !chat_template_override.empty();
378
+ if (chat_template_override.empty()) {
379
+ GGML_ASSERT(model != nullptr);
380
+ const auto * str = llama_model_chat_template(model, /* name */ nullptr);
381
+ if (str) {
382
+ default_template_src = str;
383
+ has_explicit_template = true;
384
+ }
385
+ str = llama_model_chat_template(model, /* name */ "tool_use");
386
+ if (str) {
387
+ template_tool_use_src = str;
388
+ has_explicit_template = true;
389
+ }
390
+ } else {
391
+ default_template_src = chat_template_override;
392
+ }
393
+ if (default_template_src.empty() || default_template_src == "chatml") {
394
+ if (!template_tool_use_src.empty()) {
395
+ default_template_src = template_tool_use_src;
396
+ } else {
397
+ default_template_src = CHATML_TEMPLATE_SRC;
398
+ }
399
+ }
400
+ std::string token_bos = bos_token_override;
401
+ std::string token_eos = eos_token_override;
402
+ if (model) {
403
+ const auto * vocab = llama_model_get_vocab(model);
404
+ const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
405
+ if (token == LLAMA_TOKEN_NULL) {
406
+ if (default_template_src.find(jinja_variable_name) != std::string::npos
407
+ || template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
408
+ LOG_WRN("common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't work as intended.\n", name);
409
+ }
410
+ return std::string();
411
+ }
412
+ return common_token_to_piece(vocab, token, true);
413
+ };
414
+ token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
415
+ token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
416
+ }
417
+ common_chat_templates_ptr tmpls(new common_chat_templates());
418
+ tmpls->has_explicit_template = has_explicit_template;
419
+ try {
420
+ tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
421
+ } catch (const std::exception & e) {
422
+ LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what());
423
+ tmpls->template_default = std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
424
+ }
425
+ if (!template_tool_use_src.empty()) {
426
+ try {
427
+ tmpls->template_tool_use = std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos);
428
+ } catch (const std::exception & e) {
429
+ LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
430
+ }
431
+ }
432
+ return tmpls;
433
+ }
6
434
 
7
435
  std::string common_chat_format_name(common_chat_format format) {
8
436
  switch (format) {
@@ -12,22 +440,19 @@ std::string common_chat_format_name(common_chat_format format) {
12
440
  case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
13
441
  case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
14
442
  case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
443
+ case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING: return "DeepSeek R1 (extract reasoning)";
15
444
  case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
16
445
  case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
17
446
  case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
18
447
  case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
448
+ case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING: return "Hermes 2 Pro (extract reasoning)";
19
449
  case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
450
+ case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: return "Command R7B (extract reasoning)";
20
451
  default:
21
452
  throw std::runtime_error("Unknown chat format");
22
453
  }
23
454
  }
24
455
 
25
- const common_grammar_options grammar_options {
26
- /* .dotall = */ false,
27
- /* .compact_spaces = */ false,
28
- // /* .compact_spaces = */ true,
29
- };
30
-
31
456
  static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
32
457
  // // https://json.nlohmann.me/features/parsing/sax_interface/
33
458
  struct json_error_locator : public nlohmann::json_sax<json> {
@@ -36,22 +461,22 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
36
461
 
37
462
  json_error_locator() : position(0), found_error(false) {}
38
463
 
39
- bool parse_error(std::size_t position, const std::string &, const json::exception &) override {
464
+ bool parse_error(std::size_t position, const std::string &, const json::exception &) override { // NOLINT
40
465
  this->position = position - 1;
41
466
  this->found_error = true;
42
467
  return false;
43
468
  }
44
- bool null() override { return true; }
45
- bool boolean(bool) override { return true; }
46
- bool number_integer(number_integer_t) override { return true; }
47
- bool number_unsigned(number_unsigned_t) override { return true; }
48
- bool number_float(number_float_t, const string_t &) override { return true; }
49
- bool string(string_t &) override { return true; }
50
- bool binary(binary_t &) override { return true; }
51
- bool start_object(std::size_t) override { return true; }
52
- bool key(string_t &) override { return true; }
469
+ bool null() override { return true; } // NOLINT
470
+ bool boolean(bool) override { return true; } // NOLINT
471
+ bool number_integer(number_integer_t) override { return true; } // NOLINT
472
+ bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT
473
+ bool number_float(number_float_t, const string_t &) override { return true; } // NOLINT
474
+ bool string(string_t &) override { return true; } // NOLINT
475
+ bool binary(binary_t &) override { return true; } // NOLINT
476
+ bool start_object(std::size_t) override { return true; } // NOLINT
477
+ bool key(string_t &) override { return true; } // NOLINT
53
478
  bool end_object() override { return true; }
54
- bool start_array(std::size_t) override { return true; }
479
+ bool start_array(std::size_t) override { return true; } // NOLINT
55
480
  bool end_array() override { return true; }
56
481
  };
57
482
  json_error_locator err_loc;
@@ -73,6 +498,34 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
73
498
  }
74
499
  }
75
500
 
501
// Tries to consume `expected` at the current position.
// On success advances `it` past the literal and returns true; on failure leaves `it` untouched.
static bool parse_literal(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
    auto cursor = it;
    for (const char wanted : expected) {
        if (cursor == end || *cursor != wanted) {
            return false;
        }
        ++cursor;
    }
    it = cursor;
    return true;
}
514
+
515
// Matches `expected` against the whole remaining range [it, end) using std::regex_match.
// On success moves `it` to the start of the match suffix and returns the match;
// otherwise returns std::nullopt and leaves `it` untouched.
static std::optional<std::smatch> parse_pattern(std::string::const_iterator & it, const std::string::const_iterator & end, const std::regex & expected) {
    std::smatch found;
    if (!std::regex_match(it, end, found, expected)) {
        return std::nullopt;
    }
    it = found.suffix().first;
    return found;
}
523
+
524
// Advances `it` past any leading whitespace, stopping at the first non-space character or `end`.
static void consume_spaces(std::string::const_iterator & it, const std::string::const_iterator & end) {
    // std::isspace requires a value representable as unsigned char (or EOF); passing a plain
    // char with the high bit set (e.g. a UTF-8 byte) is undefined behaviour, hence the cast.
    while (it != end && std::isspace(static_cast<unsigned char>(*it))) {
        ++it;
    }
}
76
529
 
77
530
  /**
78
531
  * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
@@ -82,7 +535,8 @@ static common_chat_msg parse_json_tool_calls(
82
535
  const std::string& input,
83
536
  const std::optional<std::regex> & trigger_opt,
84
537
  const std::regex & function_regex,
85
- const std::regex & close_regex) {
538
+ const std::regex & close_regex,
539
+ bool allow_raw_python = false) {
86
540
  std::smatch match;
87
541
 
88
542
  common_chat_msg result;
@@ -105,7 +559,6 @@ static common_chat_msg parse_json_tool_calls(
105
559
  std::sregex_iterator rend;
106
560
  std::sregex_iterator rit(it, end, function_regex);
107
561
  if (rit == rend) {
108
- fprintf(stderr, "No more tool calls found\n");
109
562
  result.content += std::string(it, end);
110
563
  break;
111
564
  }
@@ -114,48 +567,60 @@ static common_chat_msg parse_json_tool_calls(
114
567
  it = rit->suffix().first;
115
568
 
116
569
  json arguments;
117
- if (!parse_json(it, end, arguments)) {
118
- throw std::runtime_error("Failed to parse json tool call arguments");
570
+ if (parse_json(it, end, arguments)) {
571
+ if (!std::regex_search(it, end, match, close_regex)) {
572
+ throw std::runtime_error("Malformed input, missing closing pattern: " + input);
573
+ }
574
+ it = match.suffix().first;
575
+ result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
576
+ } else {
577
+ if (allow_raw_python && name == "python") {
578
+ result.tool_calls.push_back({name, json({{"code", std::string(it, end)}}).dump(), /* id= */ ""});
579
+ break;
580
+ }
581
+ throw std::runtime_error("Failed to parse json tool call arguments: " + input);
119
582
  }
120
- if (!std::regex_search(it, end, match, close_regex)) {
121
- throw std::runtime_error("Malformed input, missing closing pattern");
583
+ }
584
+
585
+ if (!result.tool_calls.empty()) {
586
+ if (!string_strip(result.content).empty()) {
587
+ LOG_WRN("Content found with tool calls: %s\n", result.content.c_str());
122
588
  }
123
- it = match.suffix().first;
124
- result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
589
+ result.content = "";
125
590
  }
126
591
  return result;
127
592
  }
128
593
 
594
+ static common_chat_tool_call process_tool_call(const json & tool_call) {
595
+ const auto & arguments = tool_call.at("arguments");
596
+ return {
597
+ /* .name = */ tool_call.at("name"),
598
+ /* .arguments = */ arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
599
+ /* .id = */ tool_call.contains("id") ? tool_call.at("id") : "",
600
+ };
601
+ }
129
602
  static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
130
603
  auto content_end = input.find(prefix);
131
604
  size_t tc_start = std::string::npos;
132
605
 
133
606
  common_chat_msg result;
134
607
  result.role = "assistant";
135
- const auto process_tool_calls = [&](const json & tool_calls) {
136
- for (const auto & tool_call : tool_calls) {
137
- const auto & arguments = tool_call["arguments"];
138
- result.tool_calls.push_back({
139
- tool_call["name"],
140
- arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
141
- tool_call.contains("id") ? tool_call["id"] : "",
142
- });
143
- }
144
- };
145
608
  if (content_end == std::string::npos) {
146
609
  result.content = input;
147
610
  } else {
148
611
  tc_start = content_end + prefix.size() - rstrip_prefix;
149
612
  result.content = input.substr(0, content_end);
150
613
  auto tool_calls = json::parse(input.substr(tc_start));
151
- process_tool_calls(tool_calls);
614
+ for (const auto & tool_call : tool_calls) {
615
+ result.tool_calls.emplace_back(process_tool_call(tool_call));
616
+ }
152
617
  }
153
618
  return result;
154
619
  }
155
620
 
156
621
  static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
157
622
  for (const auto & tool : tools) {
158
- if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) {
623
+ if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
159
624
  LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str());
160
625
  continue;
161
626
  }
@@ -179,38 +644,45 @@ static std::string apply(
179
644
  // tmpl_inputs.now = std::chrono::system_clock::now();
180
645
 
181
646
  minja::chat_template_options tmpl_opts;
182
- tmpl_opts.use_bos_token = false;
183
- tmpl_opts.use_eos_token = false;
184
-
185
- return tmpl.apply(tmpl_inputs, tmpl_opts);
647
+ // To avoid double BOS / EOS tokens, we're manually removing beginning / trailing tokens
648
+ // instead of using `chat_template_options.use_bos_token = false`, since these tokens
649
+ // may be needed inside the template / between messages too.
650
+ auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
651
+ if (string_starts_with(result, tmpl.bos_token())) {
652
+ result = result.substr(tmpl.bos_token().size());
653
+ }
654
+ if (string_ends_with(result, tmpl.eos_token())) {
655
+ result = result.substr(0, result.size() - tmpl.eos_token().size());
656
+ }
657
+ return result;
186
658
  }
187
659
 
188
- static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
660
+ static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {
189
661
  common_chat_params data;
190
662
 
191
663
  auto tool_call_schemas = json::array();
192
664
  foreach_function(inputs.tools, [&](const json & tool) {
193
- const auto & function = tool["function"];
665
+ const auto & function = tool.at("function");
194
666
  auto tool_schema = json {
195
667
  {"type", "object"},
196
668
  {"properties", {
197
669
  {"name", {
198
670
  {"type", "string"},
199
- {"const", function["name"]},
671
+ {"const", function.at("name")},
200
672
  }},
201
- {"arguments", function["parameters"]},
673
+ {"arguments", function.at("parameters")},
202
674
  }},
203
675
  {"required", json::array({"name", "arguments"})},
204
676
  };
205
677
  if (function.contains("description")) {
206
- tool_schema["description"] = function["description"];
678
+ tool_schema["description"] = function.at("description");
207
679
  }
208
680
  if (inputs.parallel_tool_calls) {
209
- tool_schema["properties"]["id"] = {
681
+ tool_schema.at("properties")["id"] = {
210
682
  {"type", "string"},
211
683
  {"minLength", 4},
212
684
  };
213
- tool_schema["required"].push_back("id");
685
+ tool_schema.at("required").push_back("id");
214
686
  }
215
687
  tool_call_schemas.emplace_back(tool_schema);
216
688
  });
@@ -239,7 +711,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
239
711
  {"required", json::array({"tool_call"})},
240
712
  };
241
713
  const auto schema =
242
- inputs.tool_choice != "required"
714
+ inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED
243
715
  ? json {
244
716
  {"anyOf", json::array({
245
717
  tool_call,
@@ -260,7 +732,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
260
732
  data.grammar_lazy = false;
261
733
  data.grammar = build_grammar([&](const common_grammar_builder & builder) {
262
734
  builder.add_schema("root", schema);
263
- }, grammar_options);
735
+ });
264
736
 
265
737
  auto tweaked_messages = common_chat_template::add_system(
266
738
  inputs.messages,
@@ -275,33 +747,33 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) {
275
747
  common_chat_msg result;
276
748
  result.role = "assistant";
277
749
  if (data.contains("tool_calls")) {
278
- for (const auto & tool_call : data["tool_calls"]) {
750
+ for (const auto & tool_call : data.at("tool_calls")) {
279
751
  result.tool_calls.push_back({
280
- tool_call["name"],
281
- tool_call["arguments"].dump(),
282
- tool_call.contains("id") ? tool_call["id"] : "",
752
+ tool_call.at("name"),
753
+ tool_call.at("arguments").dump(),
754
+ tool_call.contains("id") ? tool_call.at("id") : "",
283
755
  });
284
756
  }
285
757
  } else if (data.contains("tool_call")) {
286
758
  result.tool_calls.push_back({
287
- data["tool_call"]["name"],
288
- data["tool_call"]["arguments"].dump(),
759
+ data.at("tool_call").at("name"),
760
+ data.at("tool_call").at("arguments").dump(),
289
761
  /* id= */ "",
290
762
  });
291
763
  } else if (data.contains("response")) {
292
- const auto & response = data["response"];
764
+ const auto & response = data.at("response");
293
765
  result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
294
766
  }
295
767
  return result;
296
768
  }
297
769
 
298
- static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
770
+ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) {
299
771
  common_chat_params data;
300
- data.grammar_lazy = inputs.tool_choice != "required";
772
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
301
773
  data.grammar = build_grammar([&](const common_grammar_builder & builder) {
302
774
  auto schemas = json::array();
303
775
  foreach_function(inputs.tools, [&](const json & tool) {
304
- const auto & function = tool["function"];
776
+ const auto & function = tool.at("function");
305
777
  schemas.push_back({
306
778
  {"type", "object"},
307
779
  {"properties", {
@@ -309,9 +781,9 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
309
781
  // It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
310
782
  {"name", {
311
783
  {"type", "string"},
312
- {"const", function["name"]},
784
+ {"const", function.at("name")},
313
785
  }},
314
- {"arguments", function["parameters"]},
786
+ {"arguments", function.at("parameters")},
315
787
  {"id", {
316
788
  {"type", "string"},
317
789
  // Nemo's template expects a 9-character alphanumeric ID.
@@ -330,8 +802,11 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
330
802
  schema["maxItems"] = 1;
331
803
  }
332
804
  builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
333
- }, grammar_options);
334
- data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
805
+ });
806
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"});
807
+ data.preserved_tokens = {
808
+ "[TOOL_CALLS]",
809
+ };
335
810
  data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
336
811
  data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
337
812
  return data;
@@ -340,13 +815,13 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input)
340
815
  return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
341
816
  }
342
817
 
343
- static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
818
+ static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
344
819
  common_chat_params data;
345
- data.grammar_lazy = inputs.tool_choice != "required";
820
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
346
821
  data.grammar = build_grammar([&](const common_grammar_builder & builder) {
347
822
  auto schemas = json::array();
348
823
  foreach_function(inputs.tools, [&](const json & tool) {
349
- const auto & function = tool["function"];
824
+ const auto & function = tool.at("function");
350
825
  schemas.push_back({
351
826
  {"type", "object"},
352
827
  {"properties", {
@@ -357,9 +832,9 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
357
832
  }},
358
833
  {"tool_name", {
359
834
  {"type", "string"},
360
- {"const", function["name"]},
835
+ {"const", function.at("name")},
361
836
  }},
362
- {"parameters", function["parameters"]},
837
+ {"parameters", function.at("parameters")},
363
838
  }},
364
839
  {"required", json::array({"tool_call_id", "tool_name", "parameters"})},
365
840
  });
@@ -373,58 +848,88 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
373
848
  schema["maxItems"] = 1;
374
849
  }
375
850
  builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\"");
376
- }, grammar_options);
377
- data.grammar_triggers.push_back({"<|START_ACTION|>", /* .at_start = */ false});
851
+ });
852
+ data.grammar_triggers.push_back({
853
+ COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
854
+ "<|START_ACTION|>",
855
+ });
378
856
  data.preserved_tokens = {
857
+ "<|START_ACTION|>",
858
+ "<|END_ACTION|>",
379
859
  "<|START_RESPONSE|>",
380
860
  "<|END_RESPONSE|>",
381
861
  "<|START_THINKING|>",
382
862
  "<|END_THINKING|>",
383
- "<|END_ACTION|>",
384
863
  };
385
- data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
386
- data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
864
+ auto adjusted_messages = json::array();
865
+ for (const auto & msg : inputs.messages) {
866
+ auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string();
867
+ auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array();
868
+ if (has_reasoning_content && has_tool_calls) {
869
+ auto adjusted_message = msg;
870
+ adjusted_message["tool_plan"] = msg.at("reasoning_content");
871
+ adjusted_message.erase("reasoning_content");
872
+ adjusted_messages.push_back(adjusted_message);
873
+ } else {
874
+ adjusted_messages.push_back(msg);
875
+ }
876
+ }
877
+ data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {});
878
+ data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING : COMMON_CHAT_FORMAT_COMMAND_R7B;
387
879
  return data;
388
880
  }
389
- static common_chat_msg common_chat_parse_command_r7b(const std::string & input) {
390
- static std::regex response_regex("<\\|START_RESPONSE\\|>([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>");
391
- static std::regex thought_action_regex("<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|><\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
881
+ static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) {
882
+ static const std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S]*?)<\\|END_THINKING\\|>)([\\s\\S]*)");
883
+ static const std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S]*?)<\\|END_ACTION\\|>");
884
+ static const std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S]*?)<\\|END_RESPONSE\\|>");
885
+
392
886
  std::smatch match;
393
887
 
394
888
  common_chat_msg result;
395
889
  result.role = "assistant";
396
- if (std::regex_match(input, match, response_regex)) {
397
- result.content = match[1].str();
398
- } else if (std::regex_match(input, match, thought_action_regex)) {
399
- result.tool_plan = match[1].str();
400
- auto actions_str = match[2].str();
890
+
891
+ std::string rest = input;
892
+
893
+ if (std::regex_match(rest, match, thought_regex)) {
894
+ if (extract_reasoning) {
895
+ result.reasoning_content = match[2].str();
896
+ } else if (!match[2].str().empty()) {
897
+ // Let the unparsed thinking tags through in content only if their insides aren't empty.
898
+ result.content = match[1].str();
899
+ }
900
+ rest = match[3].str();
901
+ }
902
+ if (std::regex_match(rest, match, action_regex)) {
903
+ auto actions_str = match[1].str();
401
904
  auto actions = json::parse(actions_str);
402
905
  for (const auto & action : actions) {
403
906
  result.tool_calls.push_back({
404
- /* .name = */ action["tool_name"],
405
- /* .arguments = */ action["parameters"].dump(),
406
- /* .id = */ action["tool_call_id"],
907
+ /* .name = */ action.at("tool_name"),
908
+ /* .arguments = */ action.at("parameters").dump(),
909
+ /* .id = */ action.at("tool_call_id"),
407
910
  });
408
911
  }
912
+ } else if (std::regex_match(rest, match, response_regex)) {
913
+ auto response = match[1].str();
914
+ result.content += response;
409
915
  } else {
410
- LOG_ERR("Failed to parse command_r output");
411
- result.content = input;
916
+ result.content += rest;
412
917
  }
413
918
  return result;
414
919
  }
415
920
 
416
921
  static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
417
- if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
922
+ if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
418
923
  throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
419
924
  }
420
925
  const auto & parameters_properties = parameters.at("properties");
421
926
  const auto & parameters_required = parameters.at("required");
422
927
  for (const auto & prop : expected_properties) {
423
928
  if (!parameters_properties.contains(prop)) {
424
- throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop);
929
+ throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); // NOLINT
425
930
  }
426
931
  if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
427
- throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop);
932
+ throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); // NOLINT
428
933
  }
429
934
  }
430
935
  if (parameters_properties.size() != expected_properties.size()) {
@@ -432,18 +937,16 @@ static void expect_tool_parameters(const std::string & name, const json & parame
432
937
  }
433
938
  }
434
939
 
435
- static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) {
940
+ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
436
941
  auto builtin_tools = json::array();
437
942
  common_chat_params data;
438
- data.grammar_lazy = inputs.tool_choice != "required";
943
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
439
944
  data.grammar = build_grammar([&](const common_grammar_builder & builder) {
440
945
  std::vector<std::string> tool_rules;
441
946
 
442
947
  auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
443
- if (name == "wolfram_alpha") {
948
+ if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
444
949
  // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
445
- expect_tool_parameters(name, parameters, {"query"});
446
- } else if (name == "web_search" || name == "brave_search") {
447
950
  // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
448
951
  expect_tool_parameters(name, parameters, {"query"});
449
952
  } else if (name == "python" || name == "code_interpreter") {
@@ -455,7 +958,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
455
958
 
456
959
  std::vector<std::string> kvs;
457
960
  for (const auto & [key, value] : parameters.at("properties").items()) {
458
- kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value));
961
+ kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
459
962
  }
460
963
 
461
964
  tool_rules.push_back(
@@ -468,9 +971,9 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
468
971
  };
469
972
 
470
973
  foreach_function(inputs.tools, [&](const json & tool) {
471
- const auto & function = tool["function"];
472
- std::string name = function["name"];
473
- auto parameters = function["parameters"];
974
+ const auto & function = tool.at("function");
975
+ std::string name = function.at("name");
976
+ auto parameters = function.at("parameters");
474
977
  builder.resolve_refs(parameters);
475
978
 
476
979
  // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
@@ -481,23 +984,23 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
481
984
  builder.add_rule(
482
985
  name + "-call",
483
986
  "\"{\" space "
484
- "( \"\\\"type\\\":\" space \"\\\"function\\\",\" space )? "
485
- "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
486
- builder.add_schema(name + "-args", parameters) +
487
- " \"}\""));
488
- data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
987
+ "( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
988
+ " \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
989
+ " \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
990
+ "\"}\" space"));
991
+ });
992
+ // Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
993
+ data.grammar_triggers.push_back({
994
+ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
995
+ "\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
489
996
  });
490
- data.grammar_triggers.push_back({"{\"name\":", /* .at_start = */ true});
491
- data.grammar_triggers.push_back({"{\n \"name\":", /* .at_start = */ true});
492
- data.grammar_triggers.push_back({"{\n \"name\":", /* .at_start = */ true});
493
- data.grammar_triggers.push_back({"{\"type\": \"function\"", /* .at_start = */ true});
494
- data.grammar_triggers.push_back({"{\n \"type\": \"function\"", /* .at_start = */ true});
495
- data.grammar_triggers.push_back({"{\n \"type\": \"function\"", /* .at_start = */ true});
496
997
  if (!builtin_tools.empty()) {
497
- data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
998
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
999
+ data.preserved_tokens.push_back("<|python_tag|>");
498
1000
  }
1001
+ // Allow a few empty lines on top of the usual constrained json schema space rule.
499
1002
  builder.add_rule("root", string_join(tool_rules, " | "));
500
- }, grammar_options);
1003
+ });
501
1004
  data.additional_stops.push_back("<|eom_id|>");
502
1005
  data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
503
1006
  {"tools_in_user_message", false},
@@ -510,93 +1013,158 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
510
1013
  }
511
1014
  static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
512
1015
  // TODO: tighten & simplify the parser, don't accept leading text context.
513
- static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
514
- static std::regex close_regex("\\}");
515
- static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
1016
+ static const std::regex function_regex(
1017
+ "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
1018
+ static const std::regex close_regex("\\}\\s*");
1019
+ static const std::regex builtin_call_regex("<\\|python_tag\\|>\\s*([^.(]+)\\s*\\.\\s*call\\s*\\(\\s*([\\w]+)\\s*=\\s*([\\s\\S]*?)\\)");
516
1020
 
517
1021
  if (with_builtin_tools) {
518
1022
  std::smatch match;
519
1023
  if (std::regex_match(input, match, builtin_call_regex)) {
520
- auto name = match[1].str();
521
- auto raw_args = match[2].str();
522
-
523
- // TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing.
524
- auto it_eq = raw_args.find('=');
525
- auto arg_name = raw_args.substr(0, it_eq);
526
- auto arg_value_str = raw_args.substr(it_eq + 1);
527
- auto arg_value = json::parse(arg_value_str);
528
-
529
- return {
530
- /* .role = */ "assistant",
531
- /* .content = */ match.prefix().str(),
532
- /* .tool_calls = */ {
533
- {
534
- /* .name = */ match[1],
535
- /* .arguments = */ (json {
536
- {arg_name, arg_value},
537
- }).dump(),
538
- /* .id = */ "",
539
- },
540
- },
541
- };
1024
+ try {
1025
+ auto name = match[1].str();
1026
+ auto arg_name = match[2].str();
1027
+ auto arg_value_str = match[3].str();
1028
+ auto arg_value = json::parse(arg_value_str);
1029
+
1030
+ common_chat_msg msg;
1031
+ msg.role = "assistant";
1032
+ msg.tool_calls.push_back({
1033
+ /* .name = */ name,
1034
+ /* .arguments = */ (json {
1035
+ {arg_name, arg_value},
1036
+ }).dump(),
1037
+ /* .id = */ "",
1038
+ });
1039
+ return msg;
1040
+ } catch (const std::exception & e) {
1041
+ LOG_WRN("Failed to parse builtin tool call arguments (%s): %s", e.what(), input.c_str());
1042
+ }
542
1043
  }
543
1044
  }
544
1045
  return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
545
1046
  }
546
1047
 
547
- static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
1048
+ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
548
1049
  common_chat_params data;
549
- data.grammar_lazy = inputs.tool_choice != "required";
550
- data.grammar = build_grammar([&](const common_grammar_builder & builder) {
551
- std::vector<std::string> tool_rules;
552
- foreach_function(inputs.tools, [&](const json & tool) {
553
- const auto & function = tool["function"];
554
- std::string name = function["name"];
555
- auto parameters = function["parameters"];
556
- auto args_rule = builder.add_schema(name + "-args", parameters);
557
- tool_rules.push_back(builder.add_rule(name + "-call",
558
- "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
1050
+ if (inputs.tools.is_array() && !inputs.tools.empty()) {
1051
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
1052
+ data.grammar = build_grammar([&](const common_grammar_builder & builder) {
1053
+ std::vector<std::string> tool_rules;
1054
+ foreach_function(inputs.tools, [&](const json & tool) {
1055
+ const auto & function = tool.at("function");
1056
+ std::string name = function.at("name");
1057
+ auto parameters = function.at("parameters");
1058
+ builder.resolve_refs(parameters);
1059
+ tool_rules.push_back(builder.add_rule(name + "-call",
1060
+ "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
1061
+ "```json\\n\" " + builder.add_schema(name + "-args", parameters) + " "
1062
+ "\"```<|tool▁call▁end|>\""));
1063
+ });
1064
+ // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
1065
+ // so we accept common variants (then it's all constrained)
1066
+ builder.add_rule("root",
1067
+ "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" ) "
1068
+ "(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
1069
+ "\"<|tool▁calls▁end|>\""
1070
+ " space");
1071
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool▁calls▁begin|>"});
1072
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_calls_begin|>"});
1073
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool calls begin|>"});
1074
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool\\_calls\\_begin|>"});
1075
+ data.preserved_tokens = {
1076
+ "<think>",
1077
+ "</think>",
1078
+ "<|tool▁calls▁begin|>",
1079
+ "<|tool▁call▁begin|>",
1080
+ "<|tool▁sep|>",
1081
+ "<|tool▁call▁end|>",
1082
+ "<|tool▁calls▁end|",
1083
+ };
559
1084
  });
560
- data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
561
- data.preserved_tokens = {
562
- "<|tool▁sep|>",
563
- "<|tool▁call▁end|>",
564
- };
565
- builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space");
566
- }, grammar_options);
1085
+ }
567
1086
  auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
1087
+
1088
+ // Hacks to fix the official (broken) prompt.
1089
+ // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead,
1090
+ // until the official template is fixed.
1091
+ if (tmpl.source().find("{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}") != std::string::npos) {
1092
+ // Don't leave the chat dangling after tool results
1093
+ if (string_ends_with(prompt, "<|tool▁outputs▁end|>")) {
1094
+ prompt += "<|end▁of▁sentence|>";
1095
+ if (inputs.add_generation_prompt) {
1096
+ prompt += "<|Assistant|>";
1097
+ }
1098
+ }
1099
+ // Fix up tool call delta example added by Minja
1100
+ prompt = std::regex_replace(
1101
+ prompt,
1102
+ std::regex("(<|tool▁call▁end|>)[\\s\\r\\n]*(<|tool▁outputs▁begin|>|<|User|>)"),
1103
+ "$1<|tool▁calls▁end|><|end▁of▁sentence|>$2");
1104
+ }
568
1105
  data.prompt = prompt;
569
- data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
1106
+ data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING : COMMON_CHAT_FORMAT_DEEPSEEK_R1;
570
1107
  return data;
571
1108
  }
572
- static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) {
573
- static std::regex trigger_regex("<|tool▁calls▁begin|>");
574
- static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
575
- static std::regex close_regex("```<|tool▁call▁end|>");
576
- return parse_json_tool_calls(input, trigger_regex, function_regex, close_regex);
1109
+ static common_chat_msg handle_think_tag_prelude(const std::string & input, bool extract_reasoning, const std::function<common_chat_msg(const std::string &)> & rest_parser) {
1110
+ std::smatch match;
1111
+ static const std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)");
1112
+ if (std::regex_match(input, match, reasoning_content_regex)) {
1113
+ auto rest = match[3].str();
1114
+ auto msg = rest_parser(rest);
1115
+ auto reasoning_content = string_strip(match[2].str());
1116
+ if (extract_reasoning) {
1117
+ msg.reasoning_content = reasoning_content;
1118
+ } else if (!reasoning_content.empty()) {
1119
+ std::ostringstream content;
1120
+ content << "<think>" << reasoning_content << "</think>" << msg.content;
1121
+ msg.content = content.str();
1122
+ }
1123
+ return msg;
1124
+ }
1125
+ return rest_parser(input);
577
1126
  }
1127
+ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) {
1128
+ return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) {
1129
+ static const std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
1130
+ static const std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
1131
+ static const std::regex tool_calls_regex("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)([\\s\\S\\r\\n]*?)<|tool▁calls▁end|>");
578
1132
 
579
- static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
580
- fprintf(stderr, "%s\n", __func__);
1133
+ common_chat_msg msg;
1134
+ msg.role = "assistant";
1135
+ std::smatch match;
1136
+ if (std::regex_search(input, match, tool_calls_regex)) {
1137
+ auto tool_calls = match[1].str();
1138
+ auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex);
1139
+ msg.tool_calls = std::move(msg2.tool_calls);
1140
+ } else {
1141
+ msg.content = input;
1142
+ }
1143
+ return msg;
1144
+ });
1145
+ }
1146
+
1147
+ static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
1148
+ LOG_DBG("%s\n", __func__);
581
1149
  common_chat_params data;
582
1150
  data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
583
1151
  {"datetime", "Jan 29 2025 13:00:00 GMT"},
584
1152
  {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
585
1153
  });
586
- if (!inputs.tools.is_null() && !inputs.tools.empty()) {
587
- data.grammar_lazy = inputs.tool_choice != "required";
1154
+ if (inputs.tools.is_array() && !inputs.tools.empty()) {
1155
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
588
1156
  data.grammar = build_grammar([&](const common_grammar_builder & builder) {
589
1157
  auto schemas = json::array();
590
1158
  foreach_function(inputs.tools, [&](const json & tool) {
591
- const auto & function = tool["function"];
1159
+ const auto & function = tool.at("function");
592
1160
  schemas.push_back({
593
1161
  {"type", "object"},
594
1162
  {"properties", {
595
1163
  {"name", {
596
1164
  {"type", "string"},
597
- {"const", function["name"]},
1165
+ {"const", function.at("name")},
598
1166
  }},
599
- {"arguments", function["parameters"]},
1167
+ {"arguments", function.at("parameters")},
600
1168
  }},
601
1169
  {"required", json::array({"name", "arguments", "id"})},
602
1170
  });
@@ -610,8 +1178,11 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
610
1178
  schema["maxItems"] = 1;
611
1179
  }
612
1180
  builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
613
- }, grammar_options);
614
- data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
1181
+ });
1182
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, " functools["});
1183
+ data.preserved_tokens = {
1184
+ " functools[",
1185
+ };
615
1186
  data.format = COMMON_CHAT_FORMAT_FIREFUNCTION_V2;
616
1187
  } else {
617
1188
  data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
@@ -622,27 +1193,45 @@ static common_chat_msg common_chat_parse_firefunction_v2(const std::string & inp
622
1193
  return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
623
1194
  }
624
1195
 
625
- static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
1196
+ static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) {
626
1197
  // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
627
1198
  // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
628
1199
  common_chat_params data;
629
1200
  data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
630
1201
  data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
631
- if (!inputs.tools.is_null() && !inputs.tools.empty()) {
632
- data.grammar_lazy = inputs.tool_choice != "required";
1202
+ if (inputs.tools.is_array() && !inputs.tools.empty()) {
1203
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
633
1204
  data.grammar = build_grammar([&](const common_grammar_builder & builder) {
634
1205
  std::vector<std::string> first_tool_rules;
635
1206
  std::vector<std::string> subsequent_tool_rules;
636
1207
  foreach_function(inputs.tools, [&](const json & tool) {
637
- const auto & function = tool["function"];
638
- std::string name = function["name"];
639
- auto parameters = function["parameters"];
1208
+ const auto & function = tool.at("function");
1209
+ std::string name = function.at("name");
1210
+ auto parameters = function.at("parameters");
1211
+ builder.resolve_refs(parameters);
640
1212
  auto args_rule = builder.add_schema(name + "-args", parameters);
641
- first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
1213
+ first_tool_rules.push_back(builder.add_rule(name + "-call", "( \"assistant<|end_header_id|>\\n\" )? \"" + name + "\\n\" " + args_rule));
642
1214
  subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
643
- data.grammar_triggers.push_back({name, /* .at_start = */ true});
644
- data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
1215
+ data.grammar_triggers.push_back({
1216
+ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
1217
+ regex_escape(name + "\n"),
1218
+ });
1219
+ data.grammar_triggers.push_back({
1220
+ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
1221
+ regex_escape("assistant<|end_header_id|>\n" + name + "\n"),
1222
+ });
1223
+ data.grammar_triggers.push_back({
1224
+ COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
1225
+ regex_escape(">>>" + name + "\n"),
1226
+ });
1227
+ data.grammar_triggers.push_back({
1228
+ COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
1229
+ ">>>assistant<|end_header_id|>\n" + name,
1230
+ });
645
1231
  });
1232
+ data.preserved_tokens = {
1233
+ "<|end_header_id|>",
1234
+ };
646
1235
  auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
647
1236
  if (inputs.parallel_tool_calls) {
648
1237
  auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
@@ -651,34 +1240,20 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
651
1240
  builder.add_rule("root", first_rule);
652
1241
  }
653
1242
 
654
- }, grammar_options);
1243
+ });
655
1244
  }
656
1245
  return data;
657
1246
  }
658
1247
 
659
- static bool consume(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
660
- auto expected_it = expected.begin();
661
- auto tmp_it = it;
662
- while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) {
663
- ++tmp_it;
664
- ++expected_it;
665
- }
666
- if (expected_it == expected.end()) {
667
- it = tmp_it;
668
- return true;
669
- }
670
- return false;
671
- }
672
-
673
1248
  static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) {
674
- static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
675
- static std::regex close_regex(R"($|(?=>>>))");
1249
+ static const std::regex function_regex(R"((?:>>>)?(?:assistant<|end_header_id|>\n)?(\w+)\n)");
1250
+ static const std::regex close_regex(R"($|(?=>>>))");
676
1251
 
677
1252
  std::string content;
678
1253
  auto it = input.begin();
679
1254
  const auto end = input.end();
680
1255
 
681
- if (consume(it, end, "all\n")) {
1256
+ if (parse_literal(it, end, "all\n")) {
682
1257
  std::smatch match;
683
1258
  if (std::regex_search(it, end, match, function_regex)) {
684
1259
  auto fun_it = match.prefix().second;
@@ -693,7 +1268,7 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
693
1268
  }
694
1269
  // TODO: tighten & simplify.
695
1270
  try {
696
- auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex);
1271
+ auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex, /* allow_raw_python= */ true);
697
1272
  res.content = content + res.content;
698
1273
  return res;
699
1274
  } catch (const std::exception & e) {
@@ -705,26 +1280,26 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
705
1280
  }
706
1281
  }
707
1282
 
708
- static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
1283
+ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
709
1284
  // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
710
1285
  common_chat_params data;
711
1286
  json tools = inputs.tools.is_null() ? inputs.tools : json::array();
712
1287
  std::string python_code_argument_name;
713
1288
  auto has_raw_python = false;
714
1289
 
715
- data.grammar_lazy = inputs.tool_choice != "required";
1290
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
716
1291
  data.grammar = build_grammar([&](const common_grammar_builder & builder) {
717
1292
  std::vector<std::string> tool_rules;
718
1293
  foreach_function(inputs.tools, [&](const json & tool) {
719
- const auto & function = tool["function"];
720
- const auto & parameters = function["parameters"];
721
- std::string name = function["name"];
1294
+ const auto & function = tool.at("function");
1295
+ const auto & parameters = function.at("parameters");
1296
+ std::string name = function.at("name");
722
1297
  if (name == "python" || name == "ipython") {
723
1298
  if (!parameters.contains("type")) {
724
1299
  throw std::runtime_error("Missing type in python tool");
725
1300
  }
726
1301
  has_raw_python = true;
727
- auto type = parameters.at("type");
1302
+ const auto & type = parameters.at("type");
728
1303
  if (type == "object") {
729
1304
  auto properties = parameters.at("properties");
730
1305
  for (auto it = properties.begin(); it != properties.end(); ++it) {
@@ -746,12 +1321,13 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
746
1321
  });
747
1322
  if (has_raw_python) {
748
1323
  tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
749
- data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
1324
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
1325
+ data.preserved_tokens.push_back("<|python_tag|>");
750
1326
  }
751
1327
  auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
752
1328
  builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
753
- data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
754
- }, grammar_options);
1329
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
1330
+ });
755
1331
 
756
1332
  data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
757
1333
  // TODO: if (has_raw_python)
@@ -760,38 +1336,37 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
760
1336
  }
761
1337
  static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
762
1338
  // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
763
- static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
1339
+ static const std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
764
1340
  std::smatch match;
765
1341
  if (std::regex_search(input, match, python_tag_regex)) {
766
1342
  auto code = match[1].str();
767
- return {
768
- /* .role = */ "assistant",
769
- /* .content = */ match.prefix().str(),
770
- /* .tool_calls = */ {
771
- {
772
- /* .name = */ "python",
773
- /* .arguments = */ (json {{"code", code}}).dump(),
774
- /* .id = */ "",
775
- },
776
- }
777
- };
1343
+ common_chat_msg msg;
1344
+ msg.role = "assistant";
1345
+ msg.content = match.prefix().str();
1346
+ msg.tool_calls.push_back({
1347
+ /* .name = */ "python",
1348
+ /* .arguments = */ (json {{"code", code}}).dump(),
1349
+ /* .id = */ "",
1350
+ });
1351
+ return msg;
778
1352
  }
779
- static std::regex function_regex(R"(<function=(\w+)>)");
780
- static std::regex close_regex(R"(</function>)");
1353
+ static const std::regex function_regex(R"(<function=(\w+)>)");
1354
+ static const std::regex close_regex(R"(</function>)");
781
1355
  // TODO: tighten & simplify.
782
1356
  return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
783
1357
  }
784
1358
 
785
- static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
1359
+ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
786
1360
  common_chat_params data;
787
1361
  // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
788
- data.grammar_lazy = inputs.tool_choice != "required";
1362
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
789
1363
  data.grammar = build_grammar([&](const common_grammar_builder & builder) {
790
1364
  std::vector<std::string> tool_rules;
1365
+ std::vector<std::string> tool_call_alts;
791
1366
  foreach_function(inputs.tools, [&](const json & tool) {
792
- const auto & function = tool["function"];
793
- std::string name = function["name"];
794
- auto parameters = function["parameters"];
1367
+ const auto & function = tool.at("function");
1368
+ std::string name = function.at("name");
1369
+ auto parameters = function.at("parameters");
795
1370
  builder.resolve_refs(parameters);
796
1371
  tool_rules.push_back(builder.add_schema(name + "-call", {
797
1372
  {"type", "object"},
@@ -801,73 +1376,190 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
801
1376
  }},
802
1377
  {"required", json::array({"name", "arguments"})},
803
1378
  }));
1379
+ tool_call_alts.push_back(builder.add_rule(
1380
+ name + "-function-tag",
1381
+ "\"<function\" ( \"=" + name + "\" | \" name=\\\"" + name + "\\\"\" ) \">\" space " +
1382
+ builder.add_schema(name + "-args", parameters) + " "
1383
+ "\"</function>\" space"));
1384
+
1385
+ data.grammar_triggers.push_back({
1386
+ COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
1387
+ "<function=" + name + ">",
1388
+ });
1389
+ auto escaped_name = regex_escape(name);
1390
+ data.grammar_triggers.push_back({
1391
+ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
1392
+ "<function\\s+name\\s*=\\s*\"" + escaped_name + "\"",
1393
+ });
804
1394
  });
805
- auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
1395
+ auto any_tool_call = builder.add_rule("any_tool_call", "( " + string_join(tool_rules, " | ") + " ) space");
1396
+ std::vector<std::string> alt_tags {
1397
+ any_tool_call,
1398
+ "\"<tool_call>\" space " + any_tool_call + " \"</tool_call>\"",
1399
+ // The rest is just to accommodate common "good bad" outputs.
1400
+ "\"<function_call>\" space " + any_tool_call + " \"</function_call>\"",
1401
+ "\"<response>\" space " + any_tool_call + " \"</response>\"",
1402
+ "\"<tools>\" space " + any_tool_call + " \"</tools>\"",
1403
+ "\"<json>\" space " + any_tool_call + " \"</json>\"",
1404
+ "\"<xml>\" space " + any_tool_call + " \"</xml>\"",
1405
+ "\"<JSON>\" space " + any_tool_call + " \"</JSON>\"",
1406
+ };
1407
+ auto wrappable_tool_call = builder.add_rule("wrappable_tool_call", "( " + string_join(alt_tags, " | ") + " ) space");
1408
+ tool_call_alts.push_back(wrappable_tool_call);
1409
+ tool_call_alts.push_back(
1410
+ "( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space ");
1411
+ auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | "));
806
1412
  builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
807
- data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
808
- data.preserved_tokens = { "</tool_call>" };
809
- }, grammar_options);
1413
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<tool_call>"});
1414
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function"});
1415
+ // Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives)
1416
+ data.grammar_triggers.push_back({
1417
+ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
1418
+ "(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?\\s*\\{\\s*\"", //name\"\\s*:\\s*\"" + escaped_name + "\"",
1419
+ });
1420
+ data.preserved_tokens = {
1421
+ "<think>",
1422
+ "</think>",
1423
+ "<tool_call>",
1424
+ "</tool_call>",
1425
+ "<function",
1426
+ "<tools>",
1427
+ "</tools>",
1428
+ "<response>",
1429
+ "</response>",
1430
+ "<function_call>",
1431
+ "</function_call>",
1432
+ "<json>",
1433
+ "</json>",
1434
+ "<JSON>",
1435
+ "</JSON>",
1436
+ "```",
1437
+ "```json",
1438
+ "```xml",
1439
+ };
1440
+ });
810
1441
 
811
1442
  data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
812
- data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
1443
+ data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING : COMMON_CHAT_FORMAT_HERMES_2_PRO;
813
1444
  return data;
814
1445
  }
815
- static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) {
816
- try {
817
- std::regex start_pattern(R"([\n\s]*<tool_call>)");
818
- std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
819
- std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
1446
+ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input, bool extract_reasoning) {
1447
+ return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) {
1448
+ static const std::regex open_regex(
1449
+ "(?:"
1450
+ "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
1451
+ "(<tool_call>" // match 2 (open_tag)
1452
+ "|<function_call>"
1453
+ "|<tool>"
1454
+ "|<tools>"
1455
+ "|<response>"
1456
+ "|<json>"
1457
+ "|<xml>"
1458
+ "|<JSON>"
1459
+ ")?"
1460
+ "(\\s*\\{\\s*\"name\"\\s*:[\\s\\S]*)" // match 3 (named tool call + rest)
1461
+ ")"
1462
+ "|"
1463
+ "(?:<function=([^>]+)>" // match 4 (function name)
1464
+ "|<function name=\"([^\"]+)\">)" // match 5 (function name again)
1465
+ "([\\s\\S]*)" // match 6 (function arguments + rest)})"
1466
+ );
820
1467
 
821
- auto end = input.end();
822
- std::sregex_iterator rend;
823
- std::sregex_iterator rit(input.begin(), end, start_pattern);
824
- if (rit == rend) {
825
- return {
826
- /* .role = */ "assistant",
827
- /* .content = */ input,
828
- /* .tool_calls = */ {},
829
- };
830
- }
1468
+ try {
1469
+ common_chat_msg msg;
1470
+ msg.role = "assistant";
831
1471
 
832
- common_chat_msg result;
833
- result.role = "assistant";
834
- result.content = rit->prefix();
1472
+ std::string::const_iterator it = input.begin();
1473
+ const std::string::const_iterator end = input.end();
1474
+ std::smatch match;
835
1475
 
836
- auto it = rit->suffix().first;
837
- while (it != end) {
838
- json call;
839
- if (!parse_json(it, end, call)) {
840
- throw std::runtime_error("Failed to parse json tool call");
841
- }
842
- const auto & arguments = call["arguments"];
843
- result.tool_calls.push_back({
844
- call["name"],
845
- arguments.dump(),
846
- // arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
847
- /* id= */ "",
848
- });
849
- rit = {it, end, middle_pattern};
850
- if (rit != rend) {
851
- it = rit->suffix().first;
852
- } else {
853
- rit = {it, end, end_pattern};
854
- if (rit == rend) {
855
- throw std::runtime_error("Malformed input, missing </tool_call>");
1476
+ while (it != end) {
1477
+ if (std::regex_search(it, end, match, open_regex)) {
1478
+ // Add content before the match
1479
+ msg.content += std::string(it, match[0].first);
1480
+
1481
+ auto block_start = match[1].str();
1482
+ std::string block_end = block_start.empty() ? "" : "```";
1483
+
1484
+ auto open_tag = match[2].str();
1485
+ std::string close_tag;
1486
+
1487
+ if (match[3].matched) {
1488
+ close_tag = open_tag.empty() ? "" : "</" + open_tag.substr(1);
1489
+ auto json_it = match[3].first;
1490
+ json tool_call;
1491
+ if (parse_json(json_it, end, tool_call) && tool_call.contains("name") && tool_call.contains("arguments")) {
1492
+
1493
+ msg.tool_calls.emplace_back(process_tool_call(tool_call));
1494
+ it = json_it; // Move iterator past parsed JSON
1495
+
1496
+ // Handle close tags
1497
+ consume_spaces(it, end);
1498
+ if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
1499
+ throw std::runtime_error("Failed to parse closing tag");
1500
+ }
1501
+ consume_spaces(it, end);
1502
+ if (!block_end.empty() && !parse_literal(it, end, block_end)) {
1503
+ throw std::runtime_error("Failed to parse block end");
1504
+ }
1505
+ consume_spaces(it, end);
1506
+ } else {
1507
+ // Not a valid tool call, treat as content
1508
+ msg.content += std::string(match[0].first, match[0].second);
1509
+ it = match[0].second;
1510
+ }
1511
+ } else {
1512
+ auto function_name = match[4].str();
1513
+ if (function_name.empty()) {
1514
+ function_name = match[5].str();
1515
+ }
1516
+ GGML_ASSERT(!function_name.empty());
1517
+
1518
+ close_tag = "</function>";
1519
+ // Start parsing from after the opening tags
1520
+ auto json_it = match[6].first;
1521
+ json arguments;
1522
+ if (parse_json(json_it, end, arguments)) {
1523
+ msg.tool_calls.emplace_back(process_tool_call({
1524
+ {"name", function_name},
1525
+ {"arguments", arguments},
1526
+ }));
1527
+ it = json_it; // Move iterator past parsed JSON
1528
+
1529
+ // Handle close tags
1530
+ consume_spaces(it, end);
1531
+ if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
1532
+ throw std::runtime_error("Failed to parse closing tag");
1533
+ }
1534
+ consume_spaces(it, end);
1535
+ if (!block_end.empty() && !parse_literal(it, end, block_end)) {
1536
+ throw std::runtime_error("Failed to parse block end");
1537
+ }
1538
+ consume_spaces(it, end);
1539
+ } else {
1540
+ // Not a valid tool call, treat as content
1541
+ msg.content += std::string(match[0].first, match[0].second);
1542
+ it = match[0].second;
1543
+ }
1544
+ }
1545
+ } else {
1546
+ // Add remaining content
1547
+ msg.content += std::string(it, end);
1548
+ break;
856
1549
  }
857
- break;
858
1550
  }
1551
+ return msg;
1552
+ } catch (const std::exception & e) {
1553
+ LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
1554
+ common_chat_msg msg;
1555
+ msg.role = "assistant";
1556
+ msg.content = input;
1557
+ return msg;
859
1558
  }
860
- return result;
861
- } catch (const std::exception & e) {
862
- return {
863
- /* .role = */ "assistant",
864
- /* .content = */ input,
865
- /* .tool_calls = */ {},
866
- };
867
- }
1559
+ });
868
1560
  }
869
1561
 
870
- static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
1562
+ static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
871
1563
  common_chat_params data;
872
1564
  data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
873
1565
  data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
@@ -878,62 +1570,177 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
878
1570
  }
879
1571
  data.grammar = json_schema_to_grammar(inputs.json_schema);
880
1572
  } else {
881
- data.grammar = inputs.grammar.empty();
1573
+ data.grammar = inputs.grammar;
882
1574
  }
883
1575
  return data;
884
1576
  }
885
1577
 
886
- common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
887
- auto has_tools = !inputs.tools.is_null() && inputs.tool_choice != "none";
888
- LOG_DBG("[%s] has_tools=%s\n", __func__, has_tools ? "true" : "false");
1578
+ static common_chat_params common_chat_templates_apply_jinja(
1579
+ const struct common_chat_templates * tmpls,
1580
+ const struct common_chat_templates_inputs & inputs)
1581
+ {
1582
+ templates_params params;
1583
+ params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
1584
+ const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
1585
+ ? *tmpls->template_tool_use
1586
+ : *tmpls->template_default;
1587
+ const auto & src = tmpl.source();
1588
+ const auto & caps = tmpl.original_caps();
1589
+ params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
1590
+ params.add_generation_prompt = inputs.add_generation_prompt;
1591
+ params.extract_reasoning = inputs.extract_reasoning;
1592
+ params.tool_choice = inputs.tool_choice;
1593
+ params.grammar = inputs.grammar;
1594
+ if (!inputs.json_schema.empty()) {
1595
+ params.json_schema = json::parse(inputs.json_schema);
1596
+ }
889
1597
 
890
- if (has_tools && !inputs.grammar.empty()) {
891
- throw std::runtime_error("Cannot specify grammar with tools");
1598
+ if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
1599
+ LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
1600
+ params.parallel_tool_calls = false;
1601
+ } else {
1602
+ params.parallel_tool_calls = inputs.parallel_tool_calls;
892
1603
  }
893
1604
 
894
- const auto & src = tmpl.source();
1605
+ if (params.tools.is_array()) {
1606
+ if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
1607
+ throw std::runtime_error("Cannot specify grammar with tools");
1608
+ }
1609
+ if (caps.supports_tool_calls && !caps.supports_tools) {
1610
+ LOG_WRN("Template supports tool calls but does not natively describe tools. The fallback behaviour used may produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
1611
+ }
1612
+ }
1613
+
1614
+ // DeepSeek R1: use handler in all cases except json schema (thinking / tools).
1615
+ if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) {
1616
+ return common_chat_params_init_deepseek_r1(tmpl, params);
1617
+ }
1618
+
1619
+ // Command R7B: : use handler in all cases except json schema (thinking / tools).
1620
+ if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) {
1621
+ return common_chat_params_init_command_r7b(tmpl, params);
1622
+ }
1623
+
1624
+ // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
1625
+ if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
1626
+ return common_chat_params_init_hermes_2_pro(tmpl, params);
1627
+ }
1628
+
1629
+ // Use generic handler when mixing tools + JSON schema.
1630
+ // TODO: support that mix in handlers below.
1631
+ if ((params.tools.is_array() && params.json_schema.is_object())) {
1632
+ return common_chat_params_init_generic(tmpl, params);
1633
+ }
1634
+
1635
+ // Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases.
895
1636
  if (src.find(">>>all") != std::string::npos) {
896
- // Functionary prepends "all\n" to plain content outputs, so we use the parser no matter when
897
- return common_chat_params_init_functionary_v3_2(tmpl, inputs);
1637
+ return common_chat_params_init_functionary_v3_2(tmpl, params);
898
1638
  }
1639
+
1640
+ // Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases.
899
1641
  if (src.find(" functools[") != std::string::npos) {
900
- // Firefunction v2 requires datetime and functions in the context, even w/o tools.
901
- return common_chat_params_init_firefunction_v2(tmpl, inputs);
1642
+ return common_chat_params_init_firefunction_v2(tmpl, params);
902
1643
  }
903
1644
 
904
- if (!has_tools) {
905
- return common_chat_params_init_without_tools(tmpl, inputs);
1645
+ // Plain handler (no tools)
1646
+ if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
1647
+ return common_chat_params_init_without_tools(tmpl, params);
906
1648
  }
907
1649
 
908
- if (src.find("<tool_call>") != std::string::npos) {
909
- return common_chat_params_init_hermes_2_pro(tmpl, inputs);
910
- }
1650
+ // Functionary v3.1 (w/ tools)
911
1651
  if (src.find("<|start_header_id|>") != std::string::npos
912
1652
  && src.find("<function=") != std::string::npos) {
913
- return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, inputs);
1653
+ return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
914
1654
  }
1655
+
1656
+ // Llama 3.1, 3.2, 3.3 (w/ tools)
915
1657
  if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
916
1658
  auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
917
- return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
918
- }
919
- if (src.find("<|tool▁calls▁begin|>") != std::string::npos) {
920
- return common_chat_params_init_deepseek_r1(tmpl, inputs);
1659
+ return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
921
1660
  }
1661
+
1662
+ // Mistral Nemo (w/ tools)
922
1663
  if (src.find("[TOOL_CALLS]") != std::string::npos) {
923
- return common_chat_params_init_mistral_nemo(tmpl, inputs);
1664
+ return common_chat_params_init_mistral_nemo(tmpl, params);
1665
+ }
1666
+
1667
+ // Generic fallback
1668
+ return common_chat_params_init_generic(tmpl, params);
1669
+ }
1670
+
1671
+ // Legacy template route (adhoc C++ implementation of known templates), forward to llama_chat_apply_template.
1672
+ static common_chat_params common_chat_templates_apply_legacy(
1673
+ const struct common_chat_templates * tmpls,
1674
+ const struct common_chat_templates_inputs & inputs)
1675
+ {
1676
+ int alloc_size = 0;
1677
+ std::vector<llama_chat_message> chat;
1678
+ std::vector<std::string> contents;
1679
+ for (const auto & msg : inputs.messages) {
1680
+ auto content = msg.content;
1681
+ for (const auto & part : msg.content_parts) {
1682
+ if (part.type != "text") {
1683
+ LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
1684
+ continue;
1685
+ }
1686
+ if (!content.empty()) {
1687
+ content += "\n";;
1688
+ }
1689
+ content += part.text;
1690
+ }
1691
+ contents.emplace_back(std::move(content));
1692
+ }
1693
+ for (size_t i = 0; i < contents.size(); ++i) {
1694
+ const auto & msg = inputs.messages[i];
1695
+ const auto & content = contents[i];
1696
+ chat.push_back({msg.role.c_str(), content.c_str()});
1697
+ alloc_size += (msg.role.size() + content.size()) * 1.25;
924
1698
  }
925
- if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos) {
926
- return common_chat_params_init_command_r7b(tmpl, inputs);
1699
+
1700
+ std::vector<char> buf(alloc_size);
1701
+
1702
+ // run the first time to get the total output length
1703
+ const auto & src = tmpls->template_default->source();
1704
+ int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
1705
+
1706
+ // error: chat template is not supported
1707
+ if (res < 0) {
1708
+ // if the custom "tmpl" is not supported, we throw an error
1709
+ // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
1710
+ throw std::runtime_error("this custom template is not supported");
927
1711
  }
928
- return common_chat_params_init_generic(tmpl, inputs);
1712
+
1713
+ // if it turns out that our buffer is too small, we resize it
1714
+ if ((size_t) res > buf.size()) {
1715
+ buf.resize(res);
1716
+ res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
1717
+ }
1718
+
1719
+ common_chat_params params;
1720
+ params.prompt = std::string(buf.data(), res);
1721
+ if (!inputs.json_schema.empty()) {
1722
+ params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema));
1723
+ } else {
1724
+ params.grammar = inputs.grammar;
1725
+ }
1726
+ return params;
1727
+ }
1728
+
1729
+ common_chat_params common_chat_templates_apply(
1730
+ const struct common_chat_templates * tmpls,
1731
+ const struct common_chat_templates_inputs & inputs)
1732
+ {
1733
+ GGML_ASSERT(tmpls != nullptr);
1734
+ return inputs.use_jinja
1735
+ ? common_chat_templates_apply_jinja(tmpls, inputs)
1736
+ : common_chat_templates_apply_legacy(tmpls, inputs);
929
1737
  }
930
1738
 
931
1739
  static common_chat_msg common_chat_parse_content_only(const std::string & input) {
932
- return {
933
- /* .role = */ "assistant",
934
- /* .content = */ input,
935
- /* .tool_calls = */ {},
936
- };
1740
+ common_chat_msg msg;
1741
+ msg.role = "assistant";
1742
+ msg.content = input;
1743
+ return msg;
937
1744
  }
938
1745
 
939
1746
  common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {
@@ -949,17 +1756,23 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
949
1756
  case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
950
1757
  return common_chat_parse_llama_3_1(input, /* with_builtin_tools= */ true);
951
1758
  case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
952
- return common_chat_parse_deepseek_r1(input);
1759
+ return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ false);
1760
+ case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING:
1761
+ return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ true);
953
1762
  case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
954
1763
  return common_chat_parse_functionary_v3_2(input);
955
1764
  case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
956
1765
  return common_chat_parse_functionary_v3_1_llama_3_1(input);
957
1766
  case COMMON_CHAT_FORMAT_HERMES_2_PRO:
958
- return common_chat_parse_hermes_2_pro(input);
1767
+ return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ false);
1768
+ case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING:
1769
+ return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ true);
959
1770
  case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
960
1771
  return common_chat_parse_firefunction_v2(input);
961
1772
  case COMMON_CHAT_FORMAT_COMMAND_R7B:
962
- return common_chat_parse_command_r7b(input);
1773
+ return common_chat_parse_command_r7b(input, /* extract_reasoning= */ false);
1774
+ case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING:
1775
+ return common_chat_parse_command_r7b(input, /* extract_reasoning= */ true);
963
1776
  default:
964
1777
  throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
965
1778
  }