@fugood/llama.node 0.3.12 → 0.3.14
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +2 -1
- package/package.json +1 -1
- package/src/LlamaCompletionWorker.cpp +14 -0
- package/src/LlamaContext.cpp +110 -79
- package/src/LlamaContext.h +1 -1
- package/src/common.hpp +1 -2
- package/src/llama.cpp/.github/workflows/build.yml +95 -13
- package/src/llama.cpp/.github/workflows/docker.yml +2 -0
- package/src/llama.cpp/.github/workflows/labeler.yml +1 -1
- package/src/llama.cpp/.github/workflows/server.yml +2 -0
- package/src/llama.cpp/common/CMakeLists.txt +23 -6
- package/src/llama.cpp/common/arg.cpp +292 -14
- package/src/llama.cpp/common/chat.cpp +1128 -315
- package/src/llama.cpp/common/chat.h +135 -0
- package/src/llama.cpp/common/common.cpp +27 -171
- package/src/llama.cpp/common/common.h +41 -73
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
- package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
- package/src/llama.cpp/common/llguidance.cpp +3 -3
- package/src/llama.cpp/common/log.cpp +1 -0
- package/src/llama.cpp/common/log.h +2 -1
- package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +21 -7
- package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +61 -14
- package/src/llama.cpp/common/ngram-cache.cpp +1 -0
- package/src/llama.cpp/common/sampling.cpp +93 -49
- package/src/llama.cpp/common/speculative.cpp +6 -5
- package/src/llama.cpp/common/speculative.h +1 -1
- package/src/llama.cpp/docs/build.md +47 -9
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +3 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +1 -0
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +4 -4
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +6 -5
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +1 -1
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +373 -107
- package/src/llama.cpp/examples/llava/clip.h +19 -3
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
- package/src/llama.cpp/examples/llava/llava.cpp +4 -2
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -0
- package/src/llama.cpp/examples/main/main.cpp +73 -28
- package/src/llama.cpp/examples/parallel/parallel.cpp +1 -0
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -0
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +1 -0
- package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
- package/src/llama.cpp/examples/run/run.cpp +115 -79
- package/src/llama.cpp/examples/server/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/server/httplib.h +381 -292
- package/src/llama.cpp/examples/server/server.cpp +134 -128
- package/src/llama.cpp/examples/server/utils.hpp +95 -106
- package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
- package/src/llama.cpp/examples/tts/tts.cpp +251 -142
- package/src/llama.cpp/ggml/CMakeLists.txt +13 -1
- package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-cpu.h +4 -1
- package/src/llama.cpp/ggml/include/ggml-metal.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +0 -2
- package/src/llama.cpp/ggml/include/ggml.h +6 -2
- package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
- package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
- package/src/llama.cpp/ggml/src/ggml-common.h +0 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +132 -17
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +156 -11
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +2235 -641
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1572 -198
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +24 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +9 -8
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +16 -3
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +235 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -2
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +246 -120
- package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +51 -10
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +136 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +308 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +174 -728
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -77
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +949 -602
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +37 -3
- package/src/llama.cpp/ggml/src/ggml.c +9 -4
- package/src/llama.cpp/include/llama.h +32 -14
- package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +1 -0
- package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
- package/src/llama.cpp/requirements.txt +1 -0
- package/src/llama.cpp/src/llama-arch.cpp +21 -0
- package/src/llama.cpp/src/llama-arch.h +1 -0
- package/src/llama.cpp/src/llama-chat.cpp +1 -0
- package/src/llama.cpp/src/llama-grammar.cpp +183 -183
- package/src/llama.cpp/src/llama-grammar.h +13 -4
- package/src/llama.cpp/src/llama-impl.h +6 -6
- package/src/llama.cpp/src/llama-kv-cache.h +2 -1
- package/src/llama.cpp/src/llama-mmap.cpp +11 -1
- package/src/llama.cpp/src/llama-mmap.h +1 -0
- package/src/llama.cpp/src/llama-model.cpp +70 -6
- package/src/llama.cpp/src/llama-sampling.cpp +174 -67
- package/src/llama.cpp/src/llama-vocab.cpp +12 -0
- package/src/llama.cpp/src/llama.cpp +154 -5
- package/src/llama.cpp/src/unicode.cpp +9 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +171 -115
- package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
- package/src/llama.cpp/tests/test-chat.cpp +691 -325
- package/src/llama.cpp/tests/test-gguf.cpp +4 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
- package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
- package/src/llama.cpp/tests/test-sampling.cpp +15 -0
- package/src/llama.cpp/Sources/llama/llama.h +0 -4
- package/src/llama.cpp/common/chat.hpp +0 -52
|
@@ -1,8 +1,436 @@
|
|
|
1
|
-
#include "chat.hpp"
|
|
2
|
-
#include "chat-template.hpp"
|
|
1
|
+
#include "chat.h"
|
|
3
2
|
#include "json-schema-to-grammar.h"
|
|
4
3
|
#include "log.h"
|
|
5
|
-
#include "minja.hpp"
|
|
4
|
+
#include "minja/chat-template.hpp"
|
|
5
|
+
#include "minja/minja.hpp"
|
|
6
|
+
|
|
7
|
+
#include <optional>
|
|
8
|
+
|
|
9
|
+
typedef minja::chat_template common_chat_template;
|
|
10
|
+
|
|
11
|
+
struct common_chat_templates {
|
|
12
|
+
bool has_explicit_template; // Model had builtin template or template override was specified.
|
|
13
|
+
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
|
|
14
|
+
std::unique_ptr<common_chat_template> template_tool_use;
|
|
15
|
+
};
|
|
16
|
+
|
|
17
|
+
struct templates_params {
|
|
18
|
+
json messages;
|
|
19
|
+
json tools;
|
|
20
|
+
common_chat_tool_choice tool_choice;
|
|
21
|
+
json json_schema;
|
|
22
|
+
bool parallel_tool_calls;
|
|
23
|
+
bool stream;
|
|
24
|
+
std::string grammar;
|
|
25
|
+
bool add_generation_prompt = true;
|
|
26
|
+
bool extract_reasoning = true;
|
|
27
|
+
};
|
|
28
|
+
|
|
29
|
+
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
|
|
30
|
+
if (tool_choice == "auto") {
|
|
31
|
+
return COMMON_CHAT_TOOL_CHOICE_AUTO;
|
|
32
|
+
}
|
|
33
|
+
if (tool_choice == "none") {
|
|
34
|
+
return COMMON_CHAT_TOOL_CHOICE_NONE;
|
|
35
|
+
}
|
|
36
|
+
if (tool_choice == "required") {
|
|
37
|
+
return COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
|
38
|
+
}
|
|
39
|
+
throw std::runtime_error("Invalid tool_choice: " + tool_choice);
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
template <>
|
|
43
|
+
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
|
|
44
|
+
std::vector<common_chat_msg> msgs;
|
|
45
|
+
|
|
46
|
+
try {
|
|
47
|
+
|
|
48
|
+
if (!messages.is_array()) {
|
|
49
|
+
throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump());
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
for (const auto & message : messages) {
|
|
53
|
+
if (!message.is_object()) {
|
|
54
|
+
throw std::runtime_error("Expected 'message' to be an object, got " + message.dump());
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
common_chat_msg msg;
|
|
58
|
+
if (!message.contains("role")) {
|
|
59
|
+
throw std::runtime_error("Missing 'role' in message: " + message.dump());
|
|
60
|
+
}
|
|
61
|
+
msg.role = message.at("role");
|
|
62
|
+
|
|
63
|
+
auto has_content = message.contains("content");
|
|
64
|
+
auto has_tool_calls = message.contains("tool_calls");
|
|
65
|
+
if (has_content) {
|
|
66
|
+
const auto & content = message.at("content");
|
|
67
|
+
if (content.is_string()) {
|
|
68
|
+
msg.content = content;
|
|
69
|
+
} else if (content.is_array()) {
|
|
70
|
+
for (const auto & part : content) {
|
|
71
|
+
if (!part.contains("type")) {
|
|
72
|
+
throw std::runtime_error("Missing content part type: " + part.dump());
|
|
73
|
+
}
|
|
74
|
+
const auto & type = part.at("type");
|
|
75
|
+
if (type != "text") {
|
|
76
|
+
throw std::runtime_error("Unsupported content part type: " + type.dump());
|
|
77
|
+
}
|
|
78
|
+
common_chat_msg_content_part msg_part;
|
|
79
|
+
msg_part.type = type;
|
|
80
|
+
msg_part.text = part.at("text");
|
|
81
|
+
msg.content_parts.push_back(msg_part);
|
|
82
|
+
}
|
|
83
|
+
} else if (!content.is_null()) {
|
|
84
|
+
throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
|
|
85
|
+
}
|
|
86
|
+
}
|
|
87
|
+
if (has_tool_calls) {
|
|
88
|
+
for (const auto & tool_call : message.at("tool_calls")) {
|
|
89
|
+
common_chat_tool_call tc;
|
|
90
|
+
if (!tool_call.contains("type")) {
|
|
91
|
+
throw std::runtime_error("Missing tool call type: " + tool_call.dump());
|
|
92
|
+
}
|
|
93
|
+
const auto & type = tool_call.at("type");
|
|
94
|
+
if (type != "function") {
|
|
95
|
+
throw std::runtime_error("Unsupported tool call type: " + tool_call.dump());
|
|
96
|
+
}
|
|
97
|
+
if (!tool_call.contains("function")) {
|
|
98
|
+
throw std::runtime_error("Missing tool call function: " + tool_call.dump());
|
|
99
|
+
}
|
|
100
|
+
const auto & fc = tool_call.at("function");
|
|
101
|
+
if (!fc.contains("name")) {
|
|
102
|
+
throw std::runtime_error("Missing tool call name: " + tool_call.dump());
|
|
103
|
+
}
|
|
104
|
+
tc.name = fc.at("name");
|
|
105
|
+
tc.arguments = fc.at("arguments");
|
|
106
|
+
if (tool_call.contains("id")) {
|
|
107
|
+
tc.id = tool_call.at("id");
|
|
108
|
+
}
|
|
109
|
+
msg.tool_calls.push_back(tc);
|
|
110
|
+
}
|
|
111
|
+
}
|
|
112
|
+
if (!has_content && !has_tool_calls) {
|
|
113
|
+
throw std::runtime_error("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)");
|
|
114
|
+
}
|
|
115
|
+
if (message.contains("reasoning_content")) {
|
|
116
|
+
msg.reasoning_content = message.at("reasoning_content");
|
|
117
|
+
}
|
|
118
|
+
if (message.contains("name")) {
|
|
119
|
+
msg.tool_name = message.at("name");
|
|
120
|
+
}
|
|
121
|
+
if (message.contains("tool_call_id")) {
|
|
122
|
+
msg.tool_call_id = message.at("tool_call_id");
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
msgs.push_back(msg);
|
|
126
|
+
}
|
|
127
|
+
} catch (const std::exception & e) {
|
|
128
|
+
throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2));
|
|
129
|
+
}
|
|
130
|
+
|
|
131
|
+
return msgs;
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
template <>
|
|
135
|
+
json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
|
|
136
|
+
json messages = json::array();
|
|
137
|
+
for (const auto & msg : msgs) {
|
|
138
|
+
if (!msg.content.empty() && !msg.content_parts.empty()) {
|
|
139
|
+
throw std::runtime_error("Cannot specify both content and content_parts");
|
|
140
|
+
}
|
|
141
|
+
json jmsg {
|
|
142
|
+
{"role", msg.role},
|
|
143
|
+
};
|
|
144
|
+
if (!msg.content.empty()) {
|
|
145
|
+
jmsg["content"] = msg.content;
|
|
146
|
+
} else if (!msg.content_parts.empty()) {
|
|
147
|
+
if (concat_typed_text) {
|
|
148
|
+
std::string text;
|
|
149
|
+
for (const auto & part : msg.content_parts) {
|
|
150
|
+
if (part.type != "text") {
|
|
151
|
+
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
|
|
152
|
+
continue;
|
|
153
|
+
}
|
|
154
|
+
if (!text.empty()) {
|
|
155
|
+
text += '\n';
|
|
156
|
+
}
|
|
157
|
+
text += part.text;
|
|
158
|
+
}
|
|
159
|
+
jmsg["content"] = text;
|
|
160
|
+
} else {
|
|
161
|
+
auto & parts = jmsg["content"] = json::array();
|
|
162
|
+
for (const auto & part : msg.content_parts) {
|
|
163
|
+
parts.push_back({
|
|
164
|
+
{"type", part.type},
|
|
165
|
+
{"text", part.text},
|
|
166
|
+
});
|
|
167
|
+
}
|
|
168
|
+
}
|
|
169
|
+
} else {
|
|
170
|
+
jmsg["content"] = json(); // null
|
|
171
|
+
}
|
|
172
|
+
if (!msg.reasoning_content.empty()) {
|
|
173
|
+
jmsg["reasoning_content"] = msg.reasoning_content;
|
|
174
|
+
}
|
|
175
|
+
if (!msg.tool_name.empty()) {
|
|
176
|
+
jmsg["name"] = msg.tool_name;
|
|
177
|
+
}
|
|
178
|
+
if (!msg.tool_call_id.empty()) {
|
|
179
|
+
jmsg["tool_call_id"] = msg.tool_call_id;
|
|
180
|
+
}
|
|
181
|
+
if (!msg.tool_calls.empty()) {
|
|
182
|
+
auto & tool_calls = jmsg["tool_calls"] = json::array();
|
|
183
|
+
for (const auto & tool_call : msg.tool_calls) {
|
|
184
|
+
json tc {
|
|
185
|
+
{"type", "function"},
|
|
186
|
+
{"function", {
|
|
187
|
+
{"name", tool_call.name},
|
|
188
|
+
{"arguments", tool_call.arguments},
|
|
189
|
+
}},
|
|
190
|
+
};
|
|
191
|
+
if (!tool_call.id.empty()) {
|
|
192
|
+
tc["id"] = tool_call.id;
|
|
193
|
+
}
|
|
194
|
+
tool_calls.push_back(tc);
|
|
195
|
+
}
|
|
196
|
+
}
|
|
197
|
+
messages.push_back(jmsg);
|
|
198
|
+
}
|
|
199
|
+
return messages;
|
|
200
|
+
}
|
|
201
|
+
|
|
202
|
+
template <>
|
|
203
|
+
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const std::string & messages) {
|
|
204
|
+
return common_chat_msgs_parse_oaicompat(json::parse(messages));
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
template <>
|
|
208
|
+
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
|
|
209
|
+
std::vector<common_chat_tool> result;
|
|
210
|
+
|
|
211
|
+
try {
|
|
212
|
+
if (!tools.is_null()) {
|
|
213
|
+
if (!tools.is_array()) {
|
|
214
|
+
throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump());
|
|
215
|
+
}
|
|
216
|
+
for (const auto & tool : tools) {
|
|
217
|
+
if (!tool.contains("type")) {
|
|
218
|
+
throw std::runtime_error("Missing tool type: " + tool.dump());
|
|
219
|
+
}
|
|
220
|
+
const auto & type = tool.at("type");
|
|
221
|
+
if (!type.is_string() || type != "function") {
|
|
222
|
+
throw std::runtime_error("Unsupported tool type: " + tool.dump());
|
|
223
|
+
}
|
|
224
|
+
if (!tool.contains("function")) {
|
|
225
|
+
throw std::runtime_error("Missing tool function: " + tool.dump());
|
|
226
|
+
}
|
|
227
|
+
|
|
228
|
+
const auto & function = tool.at("function");
|
|
229
|
+
result.push_back({
|
|
230
|
+
/* .name = */ function.at("name"),
|
|
231
|
+
/* .description = */ function.at("description"),
|
|
232
|
+
/* .parameters = */ function.at("parameters").dump(),
|
|
233
|
+
});
|
|
234
|
+
}
|
|
235
|
+
}
|
|
236
|
+
} catch (const std::exception & e) {
|
|
237
|
+
throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2));
|
|
238
|
+
}
|
|
239
|
+
|
|
240
|
+
return result;
|
|
241
|
+
}
|
|
242
|
+
|
|
243
|
+
template <>
|
|
244
|
+
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const std::string & tools) {
|
|
245
|
+
return common_chat_tools_parse_oaicompat(json::parse(tools));
|
|
246
|
+
}
|
|
247
|
+
|
|
248
|
+
template <>
|
|
249
|
+
json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
|
|
250
|
+
if (tools.empty()) {
|
|
251
|
+
return json();
|
|
252
|
+
}
|
|
253
|
+
|
|
254
|
+
auto result = json::array();
|
|
255
|
+
for (const auto & tool : tools) {
|
|
256
|
+
result.push_back({
|
|
257
|
+
{"type", "function"},
|
|
258
|
+
{"function", {
|
|
259
|
+
{"name", tool.name},
|
|
260
|
+
{"description", tool.description},
|
|
261
|
+
{"parameters", json::parse(tool.parameters)},
|
|
262
|
+
}},
|
|
263
|
+
});
|
|
264
|
+
}
|
|
265
|
+
return result;
|
|
266
|
+
}
|
|
267
|
+
|
|
268
|
+
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
|
|
269
|
+
if (use_jinja) {
|
|
270
|
+
try {
|
|
271
|
+
common_chat_msg msg;
|
|
272
|
+
msg.role = "user";
|
|
273
|
+
msg.content = "test";
|
|
274
|
+
|
|
275
|
+
auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl);
|
|
276
|
+
|
|
277
|
+
common_chat_templates_inputs inputs;
|
|
278
|
+
inputs.messages = {msg};
|
|
279
|
+
|
|
280
|
+
common_chat_templates_apply(tmpls.get(), inputs);
|
|
281
|
+
return true;
|
|
282
|
+
} catch (const std::exception & e) {
|
|
283
|
+
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
|
|
284
|
+
return false;
|
|
285
|
+
}
|
|
286
|
+
}
|
|
287
|
+
llama_chat_message chat[] = {{"user", "test"}};
|
|
288
|
+
const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
|
|
289
|
+
return res >= 0;
|
|
290
|
+
}
|
|
291
|
+
|
|
292
|
+
std::string common_chat_format_single(
|
|
293
|
+
const struct common_chat_templates * tmpls,
|
|
294
|
+
const std::vector<common_chat_msg> & past_msg,
|
|
295
|
+
const common_chat_msg & new_msg,
|
|
296
|
+
bool add_ass,
|
|
297
|
+
bool use_jinja) {
|
|
298
|
+
|
|
299
|
+
common_chat_templates_inputs inputs;
|
|
300
|
+
inputs.use_jinja = use_jinja;
|
|
301
|
+
|
|
302
|
+
std::string fmt_past_msg;
|
|
303
|
+
if (!past_msg.empty()) {
|
|
304
|
+
inputs.messages = past_msg;
|
|
305
|
+
inputs.add_generation_prompt = false;
|
|
306
|
+
fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt;
|
|
307
|
+
}
|
|
308
|
+
std::ostringstream ss;
|
|
309
|
+
// if the past_msg ends with a newline, we must preserve it in the formatted version
|
|
310
|
+
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
|
|
311
|
+
ss << "\n";
|
|
312
|
+
};
|
|
313
|
+
// format chat with new_msg
|
|
314
|
+
inputs.messages.push_back(new_msg);
|
|
315
|
+
inputs.add_generation_prompt = add_ass;
|
|
316
|
+
auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt;
|
|
317
|
+
// get the diff part
|
|
318
|
+
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
|
|
319
|
+
return ss.str();
|
|
320
|
+
}
|
|
321
|
+
|
|
322
|
+
std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
|
|
323
|
+
common_chat_templates_inputs inputs;
|
|
324
|
+
inputs.use_jinja = use_jinja;
|
|
325
|
+
auto add_simple_msg = [&](auto role, auto content) {
|
|
326
|
+
common_chat_msg msg;
|
|
327
|
+
msg.role = role;
|
|
328
|
+
msg.content = content;
|
|
329
|
+
inputs.messages.push_back(msg);
|
|
330
|
+
};
|
|
331
|
+
add_simple_msg("system", "You are a helpful assistant");
|
|
332
|
+
add_simple_msg("user", "Hello");
|
|
333
|
+
add_simple_msg("assistant", "Hi there");
|
|
334
|
+
add_simple_msg("user", "How are you?");
|
|
335
|
+
return common_chat_templates_apply(tmpls, inputs).prompt;
|
|
336
|
+
}
|
|
337
|
+
|
|
338
|
+
#define CHATML_TEMPLATE_SRC \
|
|
339
|
+
"{%- for message in messages -%}\n" \
|
|
340
|
+
" {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
|
|
341
|
+
"{%- endfor -%}\n" \
|
|
342
|
+
"{%- if add_generation_prompt -%}\n" \
|
|
343
|
+
" {{- '<|im_start|>assistant\n' -}}\n" \
|
|
344
|
+
"{%- endif -%}"
|
|
345
|
+
|
|
346
|
+
void common_chat_templates_free(struct common_chat_templates * tmpls) {
|
|
347
|
+
delete tmpls;
|
|
348
|
+
}
|
|
349
|
+
|
|
350
|
+
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls) {
|
|
351
|
+
return tmpls->has_explicit_template;
|
|
352
|
+
}
|
|
353
|
+
|
|
354
|
+
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) {
|
|
355
|
+
if (variant != nullptr) {
|
|
356
|
+
if (strcmp(variant, "tool_use") == 0) {
|
|
357
|
+
if (tmpls->template_tool_use) {
|
|
358
|
+
return tmpls->template_tool_use->source().c_str();
|
|
359
|
+
}
|
|
360
|
+
return nullptr;
|
|
361
|
+
} else {
|
|
362
|
+
LOG_DBG("%s: unknown template variant: %s\n", __func__, variant);
|
|
363
|
+
}
|
|
364
|
+
}
|
|
365
|
+
return tmpls->template_default->source().c_str();
|
|
366
|
+
}
|
|
367
|
+
|
|
368
|
+
common_chat_templates_ptr common_chat_templates_init(
|
|
369
|
+
const struct llama_model * model,
|
|
370
|
+
const std::string & chat_template_override,
|
|
371
|
+
const std::string & bos_token_override,
|
|
372
|
+
const std::string & eos_token_override)
|
|
373
|
+
{
|
|
374
|
+
std::string default_template_src;
|
|
375
|
+
std::string template_tool_use_src;
|
|
376
|
+
|
|
377
|
+
bool has_explicit_template = !chat_template_override.empty();
|
|
378
|
+
if (chat_template_override.empty()) {
|
|
379
|
+
GGML_ASSERT(model != nullptr);
|
|
380
|
+
const auto * str = llama_model_chat_template(model, /* name */ nullptr);
|
|
381
|
+
if (str) {
|
|
382
|
+
default_template_src = str;
|
|
383
|
+
has_explicit_template = true;
|
|
384
|
+
}
|
|
385
|
+
str = llama_model_chat_template(model, /* name */ "tool_use");
|
|
386
|
+
if (str) {
|
|
387
|
+
template_tool_use_src = str;
|
|
388
|
+
has_explicit_template = true;
|
|
389
|
+
}
|
|
390
|
+
} else {
|
|
391
|
+
default_template_src = chat_template_override;
|
|
392
|
+
}
|
|
393
|
+
if (default_template_src.empty() || default_template_src == "chatml") {
|
|
394
|
+
if (!template_tool_use_src.empty()) {
|
|
395
|
+
default_template_src = template_tool_use_src;
|
|
396
|
+
} else {
|
|
397
|
+
default_template_src = CHATML_TEMPLATE_SRC;
|
|
398
|
+
}
|
|
399
|
+
}
|
|
400
|
+
std::string token_bos = bos_token_override;
|
|
401
|
+
std::string token_eos = eos_token_override;
|
|
402
|
+
if (model) {
|
|
403
|
+
const auto * vocab = llama_model_get_vocab(model);
|
|
404
|
+
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
|
|
405
|
+
if (token == LLAMA_TOKEN_NULL) {
|
|
406
|
+
if (default_template_src.find(jinja_variable_name) != std::string::npos
|
|
407
|
+
|| template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
|
|
408
|
+
LOG_WRN("common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't work as intended.\n", name);
|
|
409
|
+
}
|
|
410
|
+
return std::string();
|
|
411
|
+
}
|
|
412
|
+
return common_token_to_piece(vocab, token, true);
|
|
413
|
+
};
|
|
414
|
+
token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
|
|
415
|
+
token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
|
|
416
|
+
}
|
|
417
|
+
common_chat_templates_ptr tmpls(new common_chat_templates());
|
|
418
|
+
tmpls->has_explicit_template = has_explicit_template;
|
|
419
|
+
try {
|
|
420
|
+
tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
|
|
421
|
+
} catch (const std::exception & e) {
|
|
422
|
+
LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what());
|
|
423
|
+
tmpls->template_default = std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
|
|
424
|
+
}
|
|
425
|
+
if (!template_tool_use_src.empty()) {
|
|
426
|
+
try {
|
|
427
|
+
tmpls->template_tool_use = std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos);
|
|
428
|
+
} catch (const std::exception & e) {
|
|
429
|
+
LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
|
|
430
|
+
}
|
|
431
|
+
}
|
|
432
|
+
return tmpls;
|
|
433
|
+
}
|
|
6
434
|
|
|
7
435
|
std::string common_chat_format_name(common_chat_format format) {
|
|
8
436
|
switch (format) {
|
|
@@ -12,22 +440,19 @@ std::string common_chat_format_name(common_chat_format format) {
|
|
|
12
440
|
case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
|
|
13
441
|
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
|
|
14
442
|
case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
|
|
443
|
+
case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING: return "DeepSeek R1 (extract reasoning)";
|
|
15
444
|
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
|
|
16
445
|
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
|
|
17
446
|
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
|
|
18
447
|
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
|
|
448
|
+
case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING: return "Hermes 2 Pro (extract reasoning)";
|
|
19
449
|
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
|
|
450
|
+
case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: return "Command R7B (extract reasoning)";
|
|
20
451
|
default:
|
|
21
452
|
throw std::runtime_error("Unknown chat format");
|
|
22
453
|
}
|
|
23
454
|
}
|
|
24
455
|
|
|
25
|
-
const common_grammar_options grammar_options {
|
|
26
|
-
/* .dotall = */ false,
|
|
27
|
-
/* .compact_spaces = */ false,
|
|
28
|
-
// /* .compact_spaces = */ true,
|
|
29
|
-
};
|
|
30
|
-
|
|
31
456
|
static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
|
|
32
457
|
// // https://json.nlohmann.me/features/parsing/sax_interface/
|
|
33
458
|
struct json_error_locator : public nlohmann::json_sax<json> {
|
|
@@ -36,22 +461,22 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
|
|
|
36
461
|
|
|
37
462
|
json_error_locator() : position(0), found_error(false) {}
|
|
38
463
|
|
|
39
|
-
bool parse_error(std::size_t position, const std::string &, const json::exception &) override {
|
|
464
|
+
bool parse_error(std::size_t position, const std::string &, const json::exception &) override { // NOLINT
|
|
40
465
|
this->position = position - 1;
|
|
41
466
|
this->found_error = true;
|
|
42
467
|
return false;
|
|
43
468
|
}
|
|
44
|
-
bool null() override { return true; }
|
|
45
|
-
bool boolean(bool) override { return true; }
|
|
46
|
-
bool number_integer(number_integer_t) override { return true; }
|
|
47
|
-
bool number_unsigned(number_unsigned_t) override { return true; }
|
|
48
|
-
bool number_float(number_float_t, const string_t &) override { return true; }
|
|
49
|
-
bool string(string_t &) override { return true; }
|
|
50
|
-
bool binary(binary_t &) override { return true; }
|
|
51
|
-
bool start_object(std::size_t) override { return true; }
|
|
52
|
-
bool key(string_t &) override { return true; }
|
|
469
|
+
bool null() override { return true; } // NOLINT
|
|
470
|
+
bool boolean(bool) override { return true; } // NOLINT
|
|
471
|
+
bool number_integer(number_integer_t) override { return true; } // NOLINT
|
|
472
|
+
bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT
|
|
473
|
+
bool number_float(number_float_t, const string_t &) override { return true; } // NOLINT
|
|
474
|
+
bool string(string_t &) override { return true; } // NOLINT
|
|
475
|
+
bool binary(binary_t &) override { return true; } // NOLINT
|
|
476
|
+
bool start_object(std::size_t) override { return true; } // NOLINT
|
|
477
|
+
bool key(string_t &) override { return true; } // NOLINT
|
|
53
478
|
bool end_object() override { return true; }
|
|
54
|
-
bool start_array(std::size_t) override { return true; }
|
|
479
|
+
bool start_array(std::size_t) override { return true; } // NOLINT
|
|
55
480
|
bool end_array() override { return true; }
|
|
56
481
|
};
|
|
57
482
|
json_error_locator err_loc;
|
|
@@ -73,6 +498,34 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
|
|
|
73
498
|
}
|
|
74
499
|
}
|
|
75
500
|
|
|
501
|
+
static bool parse_literal(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
|
|
502
|
+
auto expected_it = expected.begin();
|
|
503
|
+
auto tmp_it = it;
|
|
504
|
+
while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) {
|
|
505
|
+
++tmp_it;
|
|
506
|
+
++expected_it;
|
|
507
|
+
}
|
|
508
|
+
if (expected_it == expected.end()) {
|
|
509
|
+
it = tmp_it;
|
|
510
|
+
return true;
|
|
511
|
+
}
|
|
512
|
+
return false;
|
|
513
|
+
}
|
|
514
|
+
|
|
515
|
+
static std::optional<std::smatch> parse_pattern(std::string::const_iterator & it, const std::string::const_iterator & end, const std::regex & expected) {
|
|
516
|
+
std::smatch match;
|
|
517
|
+
if (std::regex_match(it, end, match, expected)) {
|
|
518
|
+
it = match.suffix().first;
|
|
519
|
+
return match;
|
|
520
|
+
}
|
|
521
|
+
return std::nullopt;
|
|
522
|
+
}
|
|
523
|
+
|
|
524
|
+
static void consume_spaces(std::string::const_iterator & it, const std::string::const_iterator & end) {
|
|
525
|
+
while (it != end && std::isspace(*it)) {
|
|
526
|
+
++it;
|
|
527
|
+
}
|
|
528
|
+
}
|
|
76
529
|
|
|
77
530
|
/**
|
|
78
531
|
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
|
|
@@ -82,7 +535,8 @@ static common_chat_msg parse_json_tool_calls(
|
|
|
82
535
|
const std::string& input,
|
|
83
536
|
const std::optional<std::regex> & trigger_opt,
|
|
84
537
|
const std::regex & function_regex,
|
|
85
|
-
const std::regex & close_regex
|
|
538
|
+
const std::regex & close_regex,
|
|
539
|
+
bool allow_raw_python = false) {
|
|
86
540
|
std::smatch match;
|
|
87
541
|
|
|
88
542
|
common_chat_msg result;
|
|
@@ -105,7 +559,6 @@ static common_chat_msg parse_json_tool_calls(
|
|
|
105
559
|
std::sregex_iterator rend;
|
|
106
560
|
std::sregex_iterator rit(it, end, function_regex);
|
|
107
561
|
if (rit == rend) {
|
|
108
|
-
fprintf(stderr, "No more tool calls found\n");
|
|
109
562
|
result.content += std::string(it, end);
|
|
110
563
|
break;
|
|
111
564
|
}
|
|
@@ -114,48 +567,60 @@ static common_chat_msg parse_json_tool_calls(
|
|
|
114
567
|
it = rit->suffix().first;
|
|
115
568
|
|
|
116
569
|
json arguments;
|
|
117
|
-
if (
|
|
118
|
-
|
|
570
|
+
if (parse_json(it, end, arguments)) {
|
|
571
|
+
if (!std::regex_search(it, end, match, close_regex)) {
|
|
572
|
+
throw std::runtime_error("Malformed input, missing closing pattern: " + input);
|
|
573
|
+
}
|
|
574
|
+
it = match.suffix().first;
|
|
575
|
+
result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
|
|
576
|
+
} else {
|
|
577
|
+
if (allow_raw_python && name == "python") {
|
|
578
|
+
result.tool_calls.push_back({name, json({{"code", std::string(it, end)}}).dump(), /* id= */ ""});
|
|
579
|
+
break;
|
|
580
|
+
}
|
|
581
|
+
throw std::runtime_error("Failed to parse json tool call arguments: " + input);
|
|
119
582
|
}
|
|
120
|
-
|
|
121
|
-
|
|
583
|
+
}
|
|
584
|
+
|
|
585
|
+
if (!result.tool_calls.empty()) {
|
|
586
|
+
if (!string_strip(result.content).empty()) {
|
|
587
|
+
LOG_WRN("Content found with tool calls: %s\n", result.content.c_str());
|
|
122
588
|
}
|
|
123
|
-
|
|
124
|
-
result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
|
|
589
|
+
result.content = "";
|
|
125
590
|
}
|
|
126
591
|
return result;
|
|
127
592
|
}
|
|
128
593
|
|
|
594
|
+
static common_chat_tool_call process_tool_call(const json & tool_call) {
|
|
595
|
+
const auto & arguments = tool_call.at("arguments");
|
|
596
|
+
return {
|
|
597
|
+
/* .name = */ tool_call.at("name"),
|
|
598
|
+
/* .arguments = */ arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
|
599
|
+
/* .id = */ tool_call.contains("id") ? tool_call.at("id") : "",
|
|
600
|
+
};
|
|
601
|
+
}
|
|
129
602
|
static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
|
|
130
603
|
auto content_end = input.find(prefix);
|
|
131
604
|
size_t tc_start = std::string::npos;
|
|
132
605
|
|
|
133
606
|
common_chat_msg result;
|
|
134
607
|
result.role = "assistant";
|
|
135
|
-
const auto process_tool_calls = [&](const json & tool_calls) {
|
|
136
|
-
for (const auto & tool_call : tool_calls) {
|
|
137
|
-
const auto & arguments = tool_call["arguments"];
|
|
138
|
-
result.tool_calls.push_back({
|
|
139
|
-
tool_call["name"],
|
|
140
|
-
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
|
141
|
-
tool_call.contains("id") ? tool_call["id"] : "",
|
|
142
|
-
});
|
|
143
|
-
}
|
|
144
|
-
};
|
|
145
608
|
if (content_end == std::string::npos) {
|
|
146
609
|
result.content = input;
|
|
147
610
|
} else {
|
|
148
611
|
tc_start = content_end + prefix.size() - rstrip_prefix;
|
|
149
612
|
result.content = input.substr(0, content_end);
|
|
150
613
|
auto tool_calls = json::parse(input.substr(tc_start));
|
|
151
|
-
|
|
614
|
+
for (const auto & tool_call : tool_calls) {
|
|
615
|
+
result.tool_calls.emplace_back(process_tool_call(tool_call));
|
|
616
|
+
}
|
|
152
617
|
}
|
|
153
618
|
return result;
|
|
154
619
|
}
|
|
155
620
|
|
|
156
621
|
static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
|
|
157
622
|
for (const auto & tool : tools) {
|
|
158
|
-
if (!tool.contains("type") || tool
|
|
623
|
+
if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
|
|
159
624
|
LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str());
|
|
160
625
|
continue;
|
|
161
626
|
}
|
|
@@ -179,38 +644,45 @@ static std::string apply(
|
|
|
179
644
|
// tmpl_inputs.now = std::chrono::system_clock::now();
|
|
180
645
|
|
|
181
646
|
minja::chat_template_options tmpl_opts;
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
647
|
+
// To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens
|
|
648
|
+
// instead of using `chat_template_options.use_bos_token = false`, since these tokens
|
|
649
|
+
// may be needed inside the template / between messages too.
|
|
650
|
+
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
|
|
651
|
+
if (string_starts_with(result, tmpl.bos_token())) {
|
|
652
|
+
result = result.substr(tmpl.bos_token().size());
|
|
653
|
+
}
|
|
654
|
+
if (string_ends_with(result, tmpl.eos_token())) {
|
|
655
|
+
result = result.substr(0, result.size() - tmpl.eos_token().size());
|
|
656
|
+
}
|
|
657
|
+
return result;
|
|
186
658
|
}
|
|
187
659
|
|
|
188
|
-
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct
|
|
660
|
+
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
|
189
661
|
common_chat_params data;
|
|
190
662
|
|
|
191
663
|
auto tool_call_schemas = json::array();
|
|
192
664
|
foreach_function(inputs.tools, [&](const json & tool) {
|
|
193
|
-
const auto & function = tool
|
|
665
|
+
const auto & function = tool.at("function");
|
|
194
666
|
auto tool_schema = json {
|
|
195
667
|
{"type", "object"},
|
|
196
668
|
{"properties", {
|
|
197
669
|
{"name", {
|
|
198
670
|
{"type", "string"},
|
|
199
|
-
{"const", function
|
|
671
|
+
{"const", function.at("name")},
|
|
200
672
|
}},
|
|
201
|
-
{"arguments", function
|
|
673
|
+
{"arguments", function.at("parameters")},
|
|
202
674
|
}},
|
|
203
675
|
{"required", json::array({"name", "arguments"})},
|
|
204
676
|
};
|
|
205
677
|
if (function.contains("description")) {
|
|
206
|
-
tool_schema["description"] = function
|
|
678
|
+
tool_schema["description"] = function.at("description");
|
|
207
679
|
}
|
|
208
680
|
if (inputs.parallel_tool_calls) {
|
|
209
|
-
tool_schema
|
|
681
|
+
tool_schema.at("properties")["id"] = {
|
|
210
682
|
{"type", "string"},
|
|
211
683
|
{"minLength", 4},
|
|
212
684
|
};
|
|
213
|
-
tool_schema
|
|
685
|
+
tool_schema.at("required").push_back("id");
|
|
214
686
|
}
|
|
215
687
|
tool_call_schemas.emplace_back(tool_schema);
|
|
216
688
|
});
|
|
@@ -239,7 +711,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
|
|
|
239
711
|
{"required", json::array({"tool_call"})},
|
|
240
712
|
};
|
|
241
713
|
const auto schema =
|
|
242
|
-
inputs.tool_choice !=
|
|
714
|
+
inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED
|
|
243
715
|
? json {
|
|
244
716
|
{"anyOf", json::array({
|
|
245
717
|
tool_call,
|
|
@@ -260,7 +732,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
|
|
|
260
732
|
data.grammar_lazy = false;
|
|
261
733
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
|
262
734
|
builder.add_schema("root", schema);
|
|
263
|
-
}
|
|
735
|
+
});
|
|
264
736
|
|
|
265
737
|
auto tweaked_messages = common_chat_template::add_system(
|
|
266
738
|
inputs.messages,
|
|
@@ -275,33 +747,33 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) {
|
|
|
275
747
|
common_chat_msg result;
|
|
276
748
|
result.role = "assistant";
|
|
277
749
|
if (data.contains("tool_calls")) {
|
|
278
|
-
for (const auto & tool_call : data
|
|
750
|
+
for (const auto & tool_call : data.at("tool_calls")) {
|
|
279
751
|
result.tool_calls.push_back({
|
|
280
|
-
tool_call
|
|
281
|
-
tool_call
|
|
282
|
-
tool_call.contains("id") ? tool_call
|
|
752
|
+
tool_call.at("name"),
|
|
753
|
+
tool_call.at("arguments").dump(),
|
|
754
|
+
tool_call.contains("id") ? tool_call.at("id") : "",
|
|
283
755
|
});
|
|
284
756
|
}
|
|
285
757
|
} else if (data.contains("tool_call")) {
|
|
286
758
|
result.tool_calls.push_back({
|
|
287
|
-
data
|
|
288
|
-
data
|
|
759
|
+
data.at("tool_call").at("name"),
|
|
760
|
+
data.at("tool_call").at("arguments").dump(),
|
|
289
761
|
/* id= */ "",
|
|
290
762
|
});
|
|
291
763
|
} else if (data.contains("response")) {
|
|
292
|
-
const auto & response = data
|
|
764
|
+
const auto & response = data.at("response");
|
|
293
765
|
result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
|
|
294
766
|
}
|
|
295
767
|
return result;
|
|
296
768
|
}
|
|
297
769
|
|
|
298
|
-
static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct
|
|
770
|
+
static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
|
299
771
|
common_chat_params data;
|
|
300
|
-
data.grammar_lazy = inputs.tool_choice !=
|
|
772
|
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
|
301
773
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
|
302
774
|
auto schemas = json::array();
|
|
303
775
|
foreach_function(inputs.tools, [&](const json & tool) {
|
|
304
|
-
const auto & function = tool
|
|
776
|
+
const auto & function = tool.at("function");
|
|
305
777
|
schemas.push_back({
|
|
306
778
|
{"type", "object"},
|
|
307
779
|
{"properties", {
|
|
@@ -309,9 +781,9 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
|
|
|
309
781
|
// It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
|
|
310
782
|
{"name", {
|
|
311
783
|
{"type", "string"},
|
|
312
|
-
{"const", function
|
|
784
|
+
{"const", function.at("name")},
|
|
313
785
|
}},
|
|
314
|
-
{"arguments", function
|
|
786
|
+
{"arguments", function.at("parameters")},
|
|
315
787
|
{"id", {
|
|
316
788
|
{"type", "string"},
|
|
317
789
|
// Nemo's template expects a 9-character alphanumeric ID.
|
|
@@ -330,8 +802,11 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
|
|
|
330
802
|
schema["maxItems"] = 1;
|
|
331
803
|
}
|
|
332
804
|
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
|
|
333
|
-
}
|
|
334
|
-
data.grammar_triggers.push_back({"[TOOL_CALLS]"
|
|
805
|
+
});
|
|
806
|
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"});
|
|
807
|
+
data.preserved_tokens = {
|
|
808
|
+
"[TOOL_CALLS]",
|
|
809
|
+
};
|
|
335
810
|
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
|
336
811
|
data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
|
|
337
812
|
return data;
|
|
@@ -340,13 +815,13 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input)
|
|
|
340
815
|
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
|
|
341
816
|
}
|
|
342
817
|
|
|
343
|
-
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct
|
|
818
|
+
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
|
344
819
|
common_chat_params data;
|
|
345
|
-
data.grammar_lazy = inputs.tool_choice !=
|
|
820
|
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
|
346
821
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
|
347
822
|
auto schemas = json::array();
|
|
348
823
|
foreach_function(inputs.tools, [&](const json & tool) {
|
|
349
|
-
const auto & function = tool
|
|
824
|
+
const auto & function = tool.at("function");
|
|
350
825
|
schemas.push_back({
|
|
351
826
|
{"type", "object"},
|
|
352
827
|
{"properties", {
|
|
@@ -357,9 +832,9 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
|
|
|
357
832
|
}},
|
|
358
833
|
{"tool_name", {
|
|
359
834
|
{"type", "string"},
|
|
360
|
-
{"const", function
|
|
835
|
+
{"const", function.at("name")},
|
|
361
836
|
}},
|
|
362
|
-
{"parameters", function
|
|
837
|
+
{"parameters", function.at("parameters")},
|
|
363
838
|
}},
|
|
364
839
|
{"required", json::array({"tool_call_id", "tool_name", "parameters"})},
|
|
365
840
|
});
|
|
@@ -373,58 +848,88 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
|
|
|
373
848
|
schema["maxItems"] = 1;
|
|
374
849
|
}
|
|
375
850
|
builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\"");
|
|
376
|
-
}
|
|
377
|
-
data.grammar_triggers.push_back({
|
|
851
|
+
});
|
|
852
|
+
data.grammar_triggers.push_back({
|
|
853
|
+
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
|
854
|
+
"<|START_ACTION|>",
|
|
855
|
+
});
|
|
378
856
|
data.preserved_tokens = {
|
|
857
|
+
"<|START_ACTION|>",
|
|
858
|
+
"<|END_ACTION|>",
|
|
379
859
|
"<|START_RESPONSE|>",
|
|
380
860
|
"<|END_RESPONSE|>",
|
|
381
861
|
"<|START_THINKING|>",
|
|
382
862
|
"<|END_THINKING|>",
|
|
383
|
-
"<|END_ACTION|>",
|
|
384
863
|
};
|
|
385
|
-
|
|
386
|
-
|
|
864
|
+
auto adjusted_messages = json::array();
|
|
865
|
+
for (const auto & msg : inputs.messages) {
|
|
866
|
+
auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string();
|
|
867
|
+
auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array();
|
|
868
|
+
if (has_reasoning_content && has_tool_calls) {
|
|
869
|
+
auto adjusted_message = msg;
|
|
870
|
+
adjusted_message["tool_plan"] = msg.at("reasoning_content");
|
|
871
|
+
adjusted_message.erase("reasoning_content");
|
|
872
|
+
adjusted_messages.push_back(adjusted_message);
|
|
873
|
+
} else {
|
|
874
|
+
adjusted_messages.push_back(msg);
|
|
875
|
+
}
|
|
876
|
+
}
|
|
877
|
+
data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {});
|
|
878
|
+
data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING : COMMON_CHAT_FORMAT_COMMAND_R7B;
|
|
387
879
|
return data;
|
|
388
880
|
}
|
|
389
|
-
static common_chat_msg common_chat_parse_command_r7b(const std::string & input) {
|
|
390
|
-
static std::regex
|
|
391
|
-
static std::regex
|
|
881
|
+
static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) {
|
|
882
|
+
static const std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S]*?)<\\|END_THINKING\\|>)([\\s\\S]*)");
|
|
883
|
+
static const std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S]*?)<\\|END_ACTION\\|>");
|
|
884
|
+
static const std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S]*?)<\\|END_RESPONSE\\|>");
|
|
885
|
+
|
|
392
886
|
std::smatch match;
|
|
393
887
|
|
|
394
888
|
common_chat_msg result;
|
|
395
889
|
result.role = "assistant";
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
890
|
+
|
|
891
|
+
std::string rest = input;
|
|
892
|
+
|
|
893
|
+
if (std::regex_match(rest, match, thought_regex)) {
|
|
894
|
+
if (extract_reasoning) {
|
|
895
|
+
result.reasoning_content = match[2].str();
|
|
896
|
+
} else if (!match[2].str().empty()) {
|
|
897
|
+
// Let the unparsed thinking tags through in content only if their insides aren't empty.
|
|
898
|
+
result.content = match[1].str();
|
|
899
|
+
}
|
|
900
|
+
rest = match[3].str();
|
|
901
|
+
}
|
|
902
|
+
if (std::regex_match(rest, match, action_regex)) {
|
|
903
|
+
auto actions_str = match[1].str();
|
|
401
904
|
auto actions = json::parse(actions_str);
|
|
402
905
|
for (const auto & action : actions) {
|
|
403
906
|
result.tool_calls.push_back({
|
|
404
|
-
/* .name = */ action
|
|
405
|
-
/* .arguments = */ action
|
|
406
|
-
/* .id = */ action
|
|
907
|
+
/* .name = */ action.at("tool_name"),
|
|
908
|
+
/* .arguments = */ action.at("parameters").dump(),
|
|
909
|
+
/* .id = */ action.at("tool_call_id"),
|
|
407
910
|
});
|
|
408
911
|
}
|
|
912
|
+
} else if (std::regex_match(rest, match, response_regex)) {
|
|
913
|
+
auto response = match[1].str();
|
|
914
|
+
result.content += response;
|
|
409
915
|
} else {
|
|
410
|
-
|
|
411
|
-
result.content = input;
|
|
916
|
+
result.content += rest;
|
|
412
917
|
}
|
|
413
918
|
return result;
|
|
414
919
|
}
|
|
415
920
|
|
|
416
921
|
static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
|
|
417
|
-
if (!parameters.is_object() || !parameters.contains("type") || parameters
|
|
922
|
+
if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
|
|
418
923
|
throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
|
|
419
924
|
}
|
|
420
925
|
const auto & parameters_properties = parameters.at("properties");
|
|
421
926
|
const auto & parameters_required = parameters.at("required");
|
|
422
927
|
for (const auto & prop : expected_properties) {
|
|
423
928
|
if (!parameters_properties.contains(prop)) {
|
|
424
|
-
throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop);
|
|
929
|
+
throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); // NOLINT
|
|
425
930
|
}
|
|
426
931
|
if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
|
|
427
|
-
throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop);
|
|
932
|
+
throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); // NOLINT
|
|
428
933
|
}
|
|
429
934
|
}
|
|
430
935
|
if (parameters_properties.size() != expected_properties.size()) {
|
|
@@ -432,18 +937,16 @@ static void expect_tool_parameters(const std::string & name, const json & parame
|
|
|
432
937
|
}
|
|
433
938
|
}
|
|
434
939
|
|
|
435
|
-
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct
|
|
940
|
+
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
|
|
436
941
|
auto builtin_tools = json::array();
|
|
437
942
|
common_chat_params data;
|
|
438
|
-
data.grammar_lazy = inputs.tool_choice !=
|
|
943
|
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
|
439
944
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
|
440
945
|
std::vector<std::string> tool_rules;
|
|
441
946
|
|
|
442
947
|
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
|
|
443
|
-
if (name == "wolfram_alpha") {
|
|
948
|
+
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
|
|
444
949
|
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
|
|
445
|
-
expect_tool_parameters(name, parameters, {"query"});
|
|
446
|
-
} else if (name == "web_search" || name == "brave_search") {
|
|
447
950
|
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
|
|
448
951
|
expect_tool_parameters(name, parameters, {"query"});
|
|
449
952
|
} else if (name == "python" || name == "code_interpreter") {
|
|
@@ -455,7 +958,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
|
|
|
455
958
|
|
|
456
959
|
std::vector<std::string> kvs;
|
|
457
960
|
for (const auto & [key, value] : parameters.at("properties").items()) {
|
|
458
|
-
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value));
|
|
961
|
+
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
|
|
459
962
|
}
|
|
460
963
|
|
|
461
964
|
tool_rules.push_back(
|
|
@@ -468,9 +971,9 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
|
|
|
468
971
|
};
|
|
469
972
|
|
|
470
973
|
foreach_function(inputs.tools, [&](const json & tool) {
|
|
471
|
-
const auto & function = tool
|
|
472
|
-
std::string name = function
|
|
473
|
-
auto parameters = function
|
|
974
|
+
const auto & function = tool.at("function");
|
|
975
|
+
std::string name = function.at("name");
|
|
976
|
+
auto parameters = function.at("parameters");
|
|
474
977
|
builder.resolve_refs(parameters);
|
|
475
978
|
|
|
476
979
|
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
|
|
@@ -481,23 +984,23 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
|
|
|
481
984
|
builder.add_rule(
|
|
482
985
|
name + "-call",
|
|
483
986
|
"\"{\" space "
|
|
484
|
-
"( \"\\\"type\\\":\" space \"\\\"function\\\",\" space )? "
|
|
485
|
-
"\"\\\"name\\\"
|
|
486
|
-
|
|
487
|
-
"
|
|
488
|
-
|
|
987
|
+
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
|
|
988
|
+
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
|
|
989
|
+
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
|
|
990
|
+
"\"}\" space"));
|
|
991
|
+
});
|
|
992
|
+
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
|
|
993
|
+
data.grammar_triggers.push_back({
|
|
994
|
+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
|
995
|
+
"\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
|
|
489
996
|
});
|
|
490
|
-
data.grammar_triggers.push_back({"{\"name\":", /* .at_start = */ true});
|
|
491
|
-
data.grammar_triggers.push_back({"{\n \"name\":", /* .at_start = */ true});
|
|
492
|
-
data.grammar_triggers.push_back({"{\n \"name\":", /* .at_start = */ true});
|
|
493
|
-
data.grammar_triggers.push_back({"{\"type\": \"function\"", /* .at_start = */ true});
|
|
494
|
-
data.grammar_triggers.push_back({"{\n \"type\": \"function\"", /* .at_start = */ true});
|
|
495
|
-
data.grammar_triggers.push_back({"{\n \"type\": \"function\"", /* .at_start = */ true});
|
|
496
997
|
if (!builtin_tools.empty()) {
|
|
497
|
-
data.grammar_triggers.push_back({"<|python_tag|>"
|
|
998
|
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
|
999
|
+
data.preserved_tokens.push_back("<|python_tag|>");
|
|
498
1000
|
}
|
|
1001
|
+
// Allow a few empty lines on top of the usual constrained json schema space rule.
|
|
499
1002
|
builder.add_rule("root", string_join(tool_rules, " | "));
|
|
500
|
-
}
|
|
1003
|
+
});
|
|
501
1004
|
data.additional_stops.push_back("<|eom_id|>");
|
|
502
1005
|
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
|
|
503
1006
|
{"tools_in_user_message", false},
|
|
@@ -510,93 +1013,158 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
|
|
|
510
1013
|
}
|
|
511
1014
|
static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
|
|
512
1015
|
// TODO: tighten & simplify the parser, don't accept leading text context.
|
|
513
|
-
static std::regex function_regex(
|
|
514
|
-
|
|
515
|
-
static std::regex
|
|
1016
|
+
static const std::regex function_regex(
|
|
1017
|
+
"\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
|
|
1018
|
+
static const std::regex close_regex("\\}\\s*");
|
|
1019
|
+
static const std::regex builtin_call_regex("<\\|python_tag\\|>\\s*([^.(]+)\\s*\\.\\s*call\\s*\\(\\s*([\\w]+)\\s*=\\s*([\\s\\S]*?)\\)");
|
|
516
1020
|
|
|
517
1021
|
if (with_builtin_tools) {
|
|
518
1022
|
std::smatch match;
|
|
519
1023
|
if (std::regex_match(input, match, builtin_call_regex)) {
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
},
|
|
540
|
-
},
|
|
541
|
-
};
|
|
1024
|
+
try {
|
|
1025
|
+
auto name = match[1].str();
|
|
1026
|
+
auto arg_name = match[2].str();
|
|
1027
|
+
auto arg_value_str = match[3].str();
|
|
1028
|
+
auto arg_value = json::parse(arg_value_str);
|
|
1029
|
+
|
|
1030
|
+
common_chat_msg msg;
|
|
1031
|
+
msg.role = "assistant";
|
|
1032
|
+
msg.tool_calls.push_back({
|
|
1033
|
+
/* .name = */ name,
|
|
1034
|
+
/* .arguments = */ (json {
|
|
1035
|
+
{arg_name, arg_value},
|
|
1036
|
+
}).dump(),
|
|
1037
|
+
/* .id = */ "",
|
|
1038
|
+
});
|
|
1039
|
+
return msg;
|
|
1040
|
+
} catch (const std::exception & e) {
|
|
1041
|
+
LOG_WRN("Failed to parse builtin tool call arguments (%s): %s", e.what(), input.c_str());
|
|
1042
|
+
}
|
|
542
1043
|
}
|
|
543
1044
|
}
|
|
544
1045
|
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
|
|
545
1046
|
}
|
|
546
1047
|
|
|
547
|
-
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct
|
|
1048
|
+
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
|
548
1049
|
common_chat_params data;
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
const
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
1050
|
+
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
|
1051
|
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
|
|
1052
|
+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
|
1053
|
+
std::vector<std::string> tool_rules;
|
|
1054
|
+
foreach_function(inputs.tools, [&](const json & tool) {
|
|
1055
|
+
const auto & function = tool.at("function");
|
|
1056
|
+
std::string name = function.at("name");
|
|
1057
|
+
auto parameters = function.at("parameters");
|
|
1058
|
+
builder.resolve_refs(parameters);
|
|
1059
|
+
tool_rules.push_back(builder.add_rule(name + "-call",
|
|
1060
|
+
"\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
|
|
1061
|
+
"```json\\n\" " + builder.add_schema(name + "-args", parameters) + " "
|
|
1062
|
+
"\"```<|tool▁call▁end|>\""));
|
|
1063
|
+
});
|
|
1064
|
+
// Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
|
|
1065
|
+
// so we accept common variants (then it's all constrained)
|
|
1066
|
+
builder.add_rule("root",
|
|
1067
|
+
"( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" ) "
|
|
1068
|
+
"(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
|
|
1069
|
+
"\"<|tool▁calls▁end|>\""
|
|
1070
|
+
" space");
|
|
1071
|
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool▁calls▁begin|>"});
|
|
1072
|
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_calls_begin|>"});
|
|
1073
|
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool calls begin|>"});
|
|
1074
|
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool\\_calls\\_begin|>"});
|
|
1075
|
+
data.preserved_tokens = {
|
|
1076
|
+
"<think>",
|
|
1077
|
+
"</think>",
|
|
1078
|
+
"<|tool▁calls▁begin|>",
|
|
1079
|
+
"<|tool▁call▁begin|>",
|
|
1080
|
+
"<|tool▁sep|>",
|
|
1081
|
+
"<|tool▁call▁end|>",
|
|
1082
|
+
"<|tool▁calls▁end|",
|
|
1083
|
+
};
|
|
559
1084
|
});
|
|
560
|
-
|
|
561
|
-
data.preserved_tokens = {
|
|
562
|
-
"<|tool▁sep|>",
|
|
563
|
-
"<|tool▁call▁end|>",
|
|
564
|
-
};
|
|
565
|
-
builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space");
|
|
566
|
-
}, grammar_options);
|
|
1085
|
+
}
|
|
567
1086
|
auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
|
1087
|
+
|
|
1088
|
+
// Hacks to fix the official (broken) prompt.
|
|
1089
|
+
// It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead,
|
|
1090
|
+
// until the official template is fixed.
|
|
1091
|
+
if (tmpl.source().find("{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}") != std::string::npos) {
|
|
1092
|
+
// Don't leave the chat dangling after tool results
|
|
1093
|
+
if (string_ends_with(prompt, "<|tool▁outputs▁end|>")) {
|
|
1094
|
+
prompt += "<|end▁of▁sentence|>";
|
|
1095
|
+
if (inputs.add_generation_prompt) {
|
|
1096
|
+
prompt += "<|Assistant|>";
|
|
1097
|
+
}
|
|
1098
|
+
}
|
|
1099
|
+
// Fix up tool call delta example added by Minja
|
|
1100
|
+
prompt = std::regex_replace(
|
|
1101
|
+
prompt,
|
|
1102
|
+
std::regex("(<|tool▁call▁end|>)[\\s\\r\\n]*(<|tool▁outputs▁begin|>|<|User|>)"),
|
|
1103
|
+
"$1<|tool▁calls▁end|><|end▁of▁sentence|>$2");
|
|
1104
|
+
}
|
|
568
1105
|
data.prompt = prompt;
|
|
569
|
-
data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1;
|
|
1106
|
+
data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING : COMMON_CHAT_FORMAT_DEEPSEEK_R1;
|
|
570
1107
|
return data;
|
|
571
1108
|
}
|
|
572
|
-
static common_chat_msg
|
|
573
|
-
|
|
574
|
-
static std::regex
|
|
575
|
-
|
|
576
|
-
|
|
1109
|
+
static common_chat_msg handle_think_tag_prelude(const std::string & input, bool extract_reasoning, const std::function<common_chat_msg(const std::string &)> & rest_parser) {
|
|
1110
|
+
std::smatch match;
|
|
1111
|
+
static const std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)");
|
|
1112
|
+
if (std::regex_match(input, match, reasoning_content_regex)) {
|
|
1113
|
+
auto rest = match[3].str();
|
|
1114
|
+
auto msg = rest_parser(rest);
|
|
1115
|
+
auto reasoning_content = string_strip(match[2].str());
|
|
1116
|
+
if (extract_reasoning) {
|
|
1117
|
+
msg.reasoning_content = reasoning_content;
|
|
1118
|
+
} else if (!reasoning_content.empty()) {
|
|
1119
|
+
std::ostringstream content;
|
|
1120
|
+
content << "<think>" << reasoning_content << "</think>" << msg.content;
|
|
1121
|
+
msg.content = content.str();
|
|
1122
|
+
}
|
|
1123
|
+
return msg;
|
|
1124
|
+
}
|
|
1125
|
+
return rest_parser(input);
|
|
577
1126
|
}
|
|
1127
|
+
static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) {
|
|
1128
|
+
return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) {
|
|
1129
|
+
static const std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
|
|
1130
|
+
static const std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
|
|
1131
|
+
static const std::regex tool_calls_regex("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)([\\s\\S\\r\\n]*?)<|tool▁calls▁end|>");
|
|
578
1132
|
|
|
579
|
-
|
|
580
|
-
|
|
1133
|
+
common_chat_msg msg;
|
|
1134
|
+
msg.role = "assistant";
|
|
1135
|
+
std::smatch match;
|
|
1136
|
+
if (std::regex_search(input, match, tool_calls_regex)) {
|
|
1137
|
+
auto tool_calls = match[1].str();
|
|
1138
|
+
auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex);
|
|
1139
|
+
msg.tool_calls = std::move(msg2.tool_calls);
|
|
1140
|
+
} else {
|
|
1141
|
+
msg.content = input;
|
|
1142
|
+
}
|
|
1143
|
+
return msg;
|
|
1144
|
+
});
|
|
1145
|
+
}
|
|
1146
|
+
|
|
1147
|
+
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
|
1148
|
+
LOG_DBG("%s\n", __func__);
|
|
581
1149
|
common_chat_params data;
|
|
582
1150
|
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
|
|
583
1151
|
{"datetime", "Jan 29 2025 13:00:00 GMT"},
|
|
584
1152
|
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
|
585
1153
|
});
|
|
586
|
-
if (
|
|
587
|
-
data.grammar_lazy = inputs.tool_choice !=
|
|
1154
|
+
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
|
1155
|
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
|
588
1156
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
|
589
1157
|
auto schemas = json::array();
|
|
590
1158
|
foreach_function(inputs.tools, [&](const json & tool) {
|
|
591
|
-
const auto & function = tool
|
|
1159
|
+
const auto & function = tool.at("function");
|
|
592
1160
|
schemas.push_back({
|
|
593
1161
|
{"type", "object"},
|
|
594
1162
|
{"properties", {
|
|
595
1163
|
{"name", {
|
|
596
1164
|
{"type", "string"},
|
|
597
|
-
{"const", function
|
|
1165
|
+
{"const", function.at("name")},
|
|
598
1166
|
}},
|
|
599
|
-
{"arguments", function
|
|
1167
|
+
{"arguments", function.at("parameters")},
|
|
600
1168
|
}},
|
|
601
1169
|
{"required", json::array({"name", "arguments", "id"})},
|
|
602
1170
|
});
|
|
@@ -610,8 +1178,11 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
|
|
|
610
1178
|
schema["maxItems"] = 1;
|
|
611
1179
|
}
|
|
612
1180
|
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
|
|
613
|
-
}
|
|
614
|
-
data.grammar_triggers.push_back({" functools["
|
|
1181
|
+
});
|
|
1182
|
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, " functools["});
|
|
1183
|
+
data.preserved_tokens = {
|
|
1184
|
+
" functools[",
|
|
1185
|
+
};
|
|
615
1186
|
data.format = COMMON_CHAT_FORMAT_FIREFUNCTION_V2;
|
|
616
1187
|
} else {
|
|
617
1188
|
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
|
@@ -622,27 +1193,45 @@ static common_chat_msg common_chat_parse_firefunction_v2(const std::string & inp
|
|
|
622
1193
|
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
|
|
623
1194
|
}
|
|
624
1195
|
|
|
625
|
-
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct
|
|
1196
|
+
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
|
626
1197
|
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
|
627
1198
|
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
|
628
1199
|
common_chat_params data;
|
|
629
1200
|
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
|
630
1201
|
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
|
|
631
|
-
if (
|
|
632
|
-
data.grammar_lazy = inputs.tool_choice !=
|
|
1202
|
+
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
|
1203
|
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
|
633
1204
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
|
634
1205
|
std::vector<std::string> first_tool_rules;
|
|
635
1206
|
std::vector<std::string> subsequent_tool_rules;
|
|
636
1207
|
foreach_function(inputs.tools, [&](const json & tool) {
|
|
637
|
-
const auto & function = tool
|
|
638
|
-
std::string name = function
|
|
639
|
-
auto parameters = function
|
|
1208
|
+
const auto & function = tool.at("function");
|
|
1209
|
+
std::string name = function.at("name");
|
|
1210
|
+
auto parameters = function.at("parameters");
|
|
1211
|
+
builder.resolve_refs(parameters);
|
|
640
1212
|
auto args_rule = builder.add_schema(name + "-args", parameters);
|
|
641
|
-
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
|
|
1213
|
+
first_tool_rules.push_back(builder.add_rule(name + "-call", "( \"assistant<|end_header_id|>\\n\" )? \"" + name + "\\n\" " + args_rule));
|
|
642
1214
|
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
|
|
643
|
-
data.grammar_triggers.push_back({
|
|
644
|
-
|
|
1215
|
+
data.grammar_triggers.push_back({
|
|
1216
|
+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
|
1217
|
+
regex_escape(name + "\n"),
|
|
1218
|
+
});
|
|
1219
|
+
data.grammar_triggers.push_back({
|
|
1220
|
+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
|
1221
|
+
regex_escape("assistant<|end_header_id|>\n" + name + "\n"),
|
|
1222
|
+
});
|
|
1223
|
+
data.grammar_triggers.push_back({
|
|
1224
|
+
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
|
1225
|
+
regex_escape(">>>" + name + "\n"),
|
|
1226
|
+
});
|
|
1227
|
+
data.grammar_triggers.push_back({
|
|
1228
|
+
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
|
1229
|
+
">>>assistant<|end_header_id|>\n" + name,
|
|
1230
|
+
});
|
|
645
1231
|
});
|
|
1232
|
+
data.preserved_tokens = {
|
|
1233
|
+
"<|end_header_id|>",
|
|
1234
|
+
};
|
|
646
1235
|
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
|
|
647
1236
|
if (inputs.parallel_tool_calls) {
|
|
648
1237
|
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
|
|
@@ -651,34 +1240,20 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
|
|
651
1240
|
builder.add_rule("root", first_rule);
|
|
652
1241
|
}
|
|
653
1242
|
|
|
654
|
-
}
|
|
1243
|
+
});
|
|
655
1244
|
}
|
|
656
1245
|
return data;
|
|
657
1246
|
}
|
|
658
1247
|
|
|
659
|
-
static bool consume(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
|
|
660
|
-
auto expected_it = expected.begin();
|
|
661
|
-
auto tmp_it = it;
|
|
662
|
-
while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) {
|
|
663
|
-
++tmp_it;
|
|
664
|
-
++expected_it;
|
|
665
|
-
}
|
|
666
|
-
if (expected_it == expected.end()) {
|
|
667
|
-
it = tmp_it;
|
|
668
|
-
return true;
|
|
669
|
-
}
|
|
670
|
-
return false;
|
|
671
|
-
}
|
|
672
|
-
|
|
673
1248
|
static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) {
|
|
674
|
-
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
|
|
675
|
-
static std::regex close_regex(R"($|(?=>>>))");
|
|
1249
|
+
static const std::regex function_regex(R"((?:>>>)?(?:assistant<|end_header_id|>\n)?(\w+)\n)");
|
|
1250
|
+
static const std::regex close_regex(R"($|(?=>>>))");
|
|
676
1251
|
|
|
677
1252
|
std::string content;
|
|
678
1253
|
auto it = input.begin();
|
|
679
1254
|
const auto end = input.end();
|
|
680
1255
|
|
|
681
|
-
if (
|
|
1256
|
+
if (parse_literal(it, end, "all\n")) {
|
|
682
1257
|
std::smatch match;
|
|
683
1258
|
if (std::regex_search(it, end, match, function_regex)) {
|
|
684
1259
|
auto fun_it = match.prefix().second;
|
|
@@ -693,7 +1268,7 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
|
|
|
693
1268
|
}
|
|
694
1269
|
// TODO: tighten & simplify.
|
|
695
1270
|
try {
|
|
696
|
-
auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex);
|
|
1271
|
+
auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex, /* allow_raw_python= */ true);
|
|
697
1272
|
res.content = content + res.content;
|
|
698
1273
|
return res;
|
|
699
1274
|
} catch (const std::exception & e) {
|
|
@@ -705,26 +1280,26 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
|
|
|
705
1280
|
}
|
|
706
1281
|
}
|
|
707
1282
|
|
|
708
|
-
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct
|
|
1283
|
+
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
|
709
1284
|
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
|
710
1285
|
common_chat_params data;
|
|
711
1286
|
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
|
|
712
1287
|
std::string python_code_argument_name;
|
|
713
1288
|
auto has_raw_python = false;
|
|
714
1289
|
|
|
715
|
-
data.grammar_lazy = inputs.tool_choice !=
|
|
1290
|
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
|
716
1291
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
|
717
1292
|
std::vector<std::string> tool_rules;
|
|
718
1293
|
foreach_function(inputs.tools, [&](const json & tool) {
|
|
719
|
-
const auto & function = tool
|
|
720
|
-
const auto & parameters = function
|
|
721
|
-
std::string name = function
|
|
1294
|
+
const auto & function = tool.at("function");
|
|
1295
|
+
const auto & parameters = function.at("parameters");
|
|
1296
|
+
std::string name = function.at("name");
|
|
722
1297
|
if (name == "python" || name == "ipython") {
|
|
723
1298
|
if (!parameters.contains("type")) {
|
|
724
1299
|
throw std::runtime_error("Missing type in python tool");
|
|
725
1300
|
}
|
|
726
1301
|
has_raw_python = true;
|
|
727
|
-
auto type = parameters.at("type");
|
|
1302
|
+
const auto & type = parameters.at("type");
|
|
728
1303
|
if (type == "object") {
|
|
729
1304
|
auto properties = parameters.at("properties");
|
|
730
1305
|
for (auto it = properties.begin(); it != properties.end(); ++it) {
|
|
@@ -746,12 +1321,13 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
|
|
|
746
1321
|
});
|
|
747
1322
|
if (has_raw_python) {
|
|
748
1323
|
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
|
|
749
|
-
data.grammar_triggers.push_back({"<|python_tag|>"
|
|
1324
|
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
|
1325
|
+
data.preserved_tokens.push_back("<|python_tag|>");
|
|
750
1326
|
}
|
|
751
1327
|
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
|
|
752
1328
|
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
|
753
|
-
data.grammar_triggers.push_back({"<function="
|
|
754
|
-
}
|
|
1329
|
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
|
|
1330
|
+
});
|
|
755
1331
|
|
|
756
1332
|
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
|
757
1333
|
// TODO: if (has_raw_python)
|
|
@@ -760,38 +1336,37 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
|
|
|
760
1336
|
}
|
|
761
1337
|
static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
|
|
762
1338
|
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
|
|
763
|
-
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
|
1339
|
+
static const std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
|
764
1340
|
std::smatch match;
|
|
765
1341
|
if (std::regex_search(input, match, python_tag_regex)) {
|
|
766
1342
|
auto code = match[1].str();
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
}
|
|
777
|
-
};
|
|
1343
|
+
common_chat_msg msg;
|
|
1344
|
+
msg.role = "assistant";
|
|
1345
|
+
msg.content = match.prefix().str();
|
|
1346
|
+
msg.tool_calls.push_back({
|
|
1347
|
+
/* .name = */ "python",
|
|
1348
|
+
/* .arguments = */ (json {{"code", code}}).dump(),
|
|
1349
|
+
/* .id = */ "",
|
|
1350
|
+
});
|
|
1351
|
+
return msg;
|
|
778
1352
|
}
|
|
779
|
-
static std::regex function_regex(R"(<function=(\w+)>)");
|
|
780
|
-
static std::regex close_regex(R"(</function>)");
|
|
1353
|
+
static const std::regex function_regex(R"(<function=(\w+)>)");
|
|
1354
|
+
static const std::regex close_regex(R"(</function>)");
|
|
781
1355
|
// TODO: tighten & simplify.
|
|
782
1356
|
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
|
|
783
1357
|
}
|
|
784
1358
|
|
|
785
|
-
static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct
|
|
1359
|
+
static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
|
786
1360
|
common_chat_params data;
|
|
787
1361
|
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
|
788
|
-
data.grammar_lazy = inputs.tool_choice !=
|
|
1362
|
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
|
789
1363
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
|
790
1364
|
std::vector<std::string> tool_rules;
|
|
1365
|
+
std::vector<std::string> tool_call_alts;
|
|
791
1366
|
foreach_function(inputs.tools, [&](const json & tool) {
|
|
792
|
-
const auto & function = tool
|
|
793
|
-
std::string name = function
|
|
794
|
-
auto parameters = function
|
|
1367
|
+
const auto & function = tool.at("function");
|
|
1368
|
+
std::string name = function.at("name");
|
|
1369
|
+
auto parameters = function.at("parameters");
|
|
795
1370
|
builder.resolve_refs(parameters);
|
|
796
1371
|
tool_rules.push_back(builder.add_schema(name + "-call", {
|
|
797
1372
|
{"type", "object"},
|
|
@@ -801,73 +1376,190 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
|
|
|
801
1376
|
}},
|
|
802
1377
|
{"required", json::array({"name", "arguments"})},
|
|
803
1378
|
}));
|
|
1379
|
+
tool_call_alts.push_back(builder.add_rule(
|
|
1380
|
+
name + "-function-tag",
|
|
1381
|
+
"\"<function\" ( \"=" + name + "\" | \" name=\\\"" + name + "\\\"\" ) \">\" space " +
|
|
1382
|
+
builder.add_schema(name + "-args", parameters) + " "
|
|
1383
|
+
"\"</function>\" space"));
|
|
1384
|
+
|
|
1385
|
+
data.grammar_triggers.push_back({
|
|
1386
|
+
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
|
1387
|
+
"<function=" + name + ">",
|
|
1388
|
+
});
|
|
1389
|
+
auto escaped_name = regex_escape(name);
|
|
1390
|
+
data.grammar_triggers.push_back({
|
|
1391
|
+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
|
|
1392
|
+
"<function\\s+name\\s*=\\s*\"" + escaped_name + "\"",
|
|
1393
|
+
});
|
|
804
1394
|
});
|
|
805
|
-
auto
|
|
1395
|
+
auto any_tool_call = builder.add_rule("any_tool_call", "( " + string_join(tool_rules, " | ") + " ) space");
|
|
1396
|
+
std::vector<std::string> alt_tags {
|
|
1397
|
+
any_tool_call,
|
|
1398
|
+
"\"<tool_call>\" space " + any_tool_call + " \"</tool_call>\"",
|
|
1399
|
+
// The rest is just to accommodate common "good bad" outputs.
|
|
1400
|
+
"\"<function_call>\" space " + any_tool_call + " \"</function_call>\"",
|
|
1401
|
+
"\"<response>\" space " + any_tool_call + " \"</response>\"",
|
|
1402
|
+
"\"<tools>\" space " + any_tool_call + " \"</tools>\"",
|
|
1403
|
+
"\"<json>\" space " + any_tool_call + " \"</json>\"",
|
|
1404
|
+
"\"<xml>\" space " + any_tool_call + " \"</xml>\"",
|
|
1405
|
+
"\"<JSON>\" space " + any_tool_call + " \"</JSON>\"",
|
|
1406
|
+
};
|
|
1407
|
+
auto wrappable_tool_call = builder.add_rule("wrappable_tool_call", "( " + string_join(alt_tags, " | ") + " ) space");
|
|
1408
|
+
tool_call_alts.push_back(wrappable_tool_call);
|
|
1409
|
+
tool_call_alts.push_back(
|
|
1410
|
+
"( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space ");
|
|
1411
|
+
auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | "));
|
|
806
1412
|
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
|
807
|
-
data.grammar_triggers.push_back({"<tool_call>"
|
|
808
|
-
data.
|
|
809
|
-
|
|
1413
|
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<tool_call>"});
|
|
1414
|
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function"});
|
|
1415
|
+
// Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives)
|
|
1416
|
+
data.grammar_triggers.push_back({
|
|
1417
|
+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
|
1418
|
+
"(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?\\s*\\{\\s*\"", //name\"\\s*:\\s*\"" + escaped_name + "\"",
|
|
1419
|
+
});
|
|
1420
|
+
data.preserved_tokens = {
|
|
1421
|
+
"<think>",
|
|
1422
|
+
"</think>",
|
|
1423
|
+
"<tool_call>",
|
|
1424
|
+
"</tool_call>",
|
|
1425
|
+
"<function",
|
|
1426
|
+
"<tools>",
|
|
1427
|
+
"</tools>",
|
|
1428
|
+
"<response>",
|
|
1429
|
+
"</response>",
|
|
1430
|
+
"<function_call>",
|
|
1431
|
+
"</function_call>",
|
|
1432
|
+
"<json>",
|
|
1433
|
+
"</json>",
|
|
1434
|
+
"<JSON>",
|
|
1435
|
+
"</JSON>",
|
|
1436
|
+
"```",
|
|
1437
|
+
"```json",
|
|
1438
|
+
"```xml",
|
|
1439
|
+
};
|
|
1440
|
+
});
|
|
810
1441
|
|
|
811
1442
|
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
|
812
|
-
data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
|
|
1443
|
+
data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING : COMMON_CHAT_FORMAT_HERMES_2_PRO;
|
|
813
1444
|
return data;
|
|
814
1445
|
}
|
|
815
|
-
static common_chat_msg common_chat_parse_hermes_2_pro(const std::string
|
|
816
|
-
|
|
817
|
-
std::regex
|
|
818
|
-
|
|
819
|
-
|
|
1446
|
+
static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input, bool extract_reasoning) {
|
|
1447
|
+
return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) {
|
|
1448
|
+
static const std::regex open_regex(
|
|
1449
|
+
"(?:"
|
|
1450
|
+
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
|
|
1451
|
+
"(<tool_call>" // match 2 (open_tag)
|
|
1452
|
+
"|<function_call>"
|
|
1453
|
+
"|<tool>"
|
|
1454
|
+
"|<tools>"
|
|
1455
|
+
"|<response>"
|
|
1456
|
+
"|<json>"
|
|
1457
|
+
"|<xml>"
|
|
1458
|
+
"|<JSON>"
|
|
1459
|
+
")?"
|
|
1460
|
+
"(\\s*\\{\\s*\"name\"\\s*:[\\s\\S]*)" // match 3 (named tool call + rest)
|
|
1461
|
+
")"
|
|
1462
|
+
"|"
|
|
1463
|
+
"(?:<function=([^>]+)>" // match 4 (function name)
|
|
1464
|
+
"|<function name=\"([^\"]+)\">)" // match 5 (function name again)
|
|
1465
|
+
"([\\s\\S]*)" // match 6 (function arguments + rest)})"
|
|
1466
|
+
);
|
|
820
1467
|
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
if (rit == rend) {
|
|
825
|
-
return {
|
|
826
|
-
/* .role = */ "assistant",
|
|
827
|
-
/* .content = */ input,
|
|
828
|
-
/* .tool_calls = */ {},
|
|
829
|
-
};
|
|
830
|
-
}
|
|
1468
|
+
try {
|
|
1469
|
+
common_chat_msg msg;
|
|
1470
|
+
msg.role = "assistant";
|
|
831
1471
|
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
1472
|
+
std::string::const_iterator it = input.begin();
|
|
1473
|
+
const std::string::const_iterator end = input.end();
|
|
1474
|
+
std::smatch match;
|
|
835
1475
|
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
1476
|
+
while (it != end) {
|
|
1477
|
+
if (std::regex_search(it, end, match, open_regex)) {
|
|
1478
|
+
// Add content before the match
|
|
1479
|
+
msg.content += std::string(it, match[0].first);
|
|
1480
|
+
|
|
1481
|
+
auto block_start = match[1].str();
|
|
1482
|
+
std::string block_end = block_start.empty() ? "" : "```";
|
|
1483
|
+
|
|
1484
|
+
auto open_tag = match[2].str();
|
|
1485
|
+
std::string close_tag;
|
|
1486
|
+
|
|
1487
|
+
if (match[3].matched) {
|
|
1488
|
+
close_tag = open_tag.empty() ? "" : "</" + open_tag.substr(1);
|
|
1489
|
+
auto json_it = match[3].first;
|
|
1490
|
+
json tool_call;
|
|
1491
|
+
if (parse_json(json_it, end, tool_call) && tool_call.contains("name") && tool_call.contains("arguments")) {
|
|
1492
|
+
|
|
1493
|
+
msg.tool_calls.emplace_back(process_tool_call(tool_call));
|
|
1494
|
+
it = json_it; // Move iterator past parsed JSON
|
|
1495
|
+
|
|
1496
|
+
// Handle close tags
|
|
1497
|
+
consume_spaces(it, end);
|
|
1498
|
+
if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
|
|
1499
|
+
throw std::runtime_error("Failed to parse closing tag");
|
|
1500
|
+
}
|
|
1501
|
+
consume_spaces(it, end);
|
|
1502
|
+
if (!block_end.empty() && !parse_literal(it, end, block_end)) {
|
|
1503
|
+
throw std::runtime_error("Failed to parse block end");
|
|
1504
|
+
}
|
|
1505
|
+
consume_spaces(it, end);
|
|
1506
|
+
} else {
|
|
1507
|
+
// Not a valid tool call, treat as content
|
|
1508
|
+
msg.content += std::string(match[0].first, match[0].second);
|
|
1509
|
+
it = match[0].second;
|
|
1510
|
+
}
|
|
1511
|
+
} else {
|
|
1512
|
+
auto function_name = match[4].str();
|
|
1513
|
+
if (function_name.empty()) {
|
|
1514
|
+
function_name = match[5].str();
|
|
1515
|
+
}
|
|
1516
|
+
GGML_ASSERT(!function_name.empty());
|
|
1517
|
+
|
|
1518
|
+
close_tag = "</function>";
|
|
1519
|
+
// Start parsing from after the opening tags
|
|
1520
|
+
auto json_it = match[6].first;
|
|
1521
|
+
json arguments;
|
|
1522
|
+
if (parse_json(json_it, end, arguments)) {
|
|
1523
|
+
msg.tool_calls.emplace_back(process_tool_call({
|
|
1524
|
+
{"name", function_name},
|
|
1525
|
+
{"arguments", arguments},
|
|
1526
|
+
}));
|
|
1527
|
+
it = json_it; // Move iterator past parsed JSON
|
|
1528
|
+
|
|
1529
|
+
// Handle close tags
|
|
1530
|
+
consume_spaces(it, end);
|
|
1531
|
+
if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
|
|
1532
|
+
throw std::runtime_error("Failed to parse closing tag");
|
|
1533
|
+
}
|
|
1534
|
+
consume_spaces(it, end);
|
|
1535
|
+
if (!block_end.empty() && !parse_literal(it, end, block_end)) {
|
|
1536
|
+
throw std::runtime_error("Failed to parse block end");
|
|
1537
|
+
}
|
|
1538
|
+
consume_spaces(it, end);
|
|
1539
|
+
} else {
|
|
1540
|
+
// Not a valid tool call, treat as content
|
|
1541
|
+
msg.content += std::string(match[0].first, match[0].second);
|
|
1542
|
+
it = match[0].second;
|
|
1543
|
+
}
|
|
1544
|
+
}
|
|
1545
|
+
} else {
|
|
1546
|
+
// Add remaining content
|
|
1547
|
+
msg.content += std::string(it, end);
|
|
1548
|
+
break;
|
|
856
1549
|
}
|
|
857
|
-
break;
|
|
858
1550
|
}
|
|
1551
|
+
return msg;
|
|
1552
|
+
} catch (const std::exception & e) {
|
|
1553
|
+
LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
|
|
1554
|
+
common_chat_msg msg;
|
|
1555
|
+
msg.role = "assistant";
|
|
1556
|
+
msg.content = input;
|
|
1557
|
+
return msg;
|
|
859
1558
|
}
|
|
860
|
-
|
|
861
|
-
} catch (const std::exception & e) {
|
|
862
|
-
return {
|
|
863
|
-
/* .role = */ "assistant",
|
|
864
|
-
/* .content = */ input,
|
|
865
|
-
/* .tool_calls = */ {},
|
|
866
|
-
};
|
|
867
|
-
}
|
|
1559
|
+
});
|
|
868
1560
|
}
|
|
869
1561
|
|
|
870
|
-
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct
|
|
1562
|
+
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
|
871
1563
|
common_chat_params data;
|
|
872
1564
|
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
|
873
1565
|
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
|
@@ -878,62 +1570,177 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
|
|
|
878
1570
|
}
|
|
879
1571
|
data.grammar = json_schema_to_grammar(inputs.json_schema);
|
|
880
1572
|
} else {
|
|
881
|
-
data.grammar = inputs.grammar
|
|
1573
|
+
data.grammar = inputs.grammar;
|
|
882
1574
|
}
|
|
883
1575
|
return data;
|
|
884
1576
|
}
|
|
885
1577
|
|
|
886
|
-
common_chat_params
|
|
887
|
-
|
|
888
|
-
|
|
1578
|
+
static common_chat_params common_chat_templates_apply_jinja(
|
|
1579
|
+
const struct common_chat_templates * tmpls,
|
|
1580
|
+
const struct common_chat_templates_inputs & inputs)
|
|
1581
|
+
{
|
|
1582
|
+
templates_params params;
|
|
1583
|
+
params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
|
|
1584
|
+
const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
|
|
1585
|
+
? *tmpls->template_tool_use
|
|
1586
|
+
: *tmpls->template_default;
|
|
1587
|
+
const auto & src = tmpl.source();
|
|
1588
|
+
const auto & caps = tmpl.original_caps();
|
|
1589
|
+
params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
|
|
1590
|
+
params.add_generation_prompt = inputs.add_generation_prompt;
|
|
1591
|
+
params.extract_reasoning = inputs.extract_reasoning;
|
|
1592
|
+
params.tool_choice = inputs.tool_choice;
|
|
1593
|
+
params.grammar = inputs.grammar;
|
|
1594
|
+
if (!inputs.json_schema.empty()) {
|
|
1595
|
+
params.json_schema = json::parse(inputs.json_schema);
|
|
1596
|
+
}
|
|
889
1597
|
|
|
890
|
-
if (
|
|
891
|
-
|
|
1598
|
+
if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
|
|
1599
|
+
LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
|
|
1600
|
+
params.parallel_tool_calls = false;
|
|
1601
|
+
} else {
|
|
1602
|
+
params.parallel_tool_calls = inputs.parallel_tool_calls;
|
|
892
1603
|
}
|
|
893
1604
|
|
|
894
|
-
|
|
1605
|
+
if (params.tools.is_array()) {
|
|
1606
|
+
if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
|
|
1607
|
+
throw std::runtime_error("Cannot specify grammar with tools");
|
|
1608
|
+
}
|
|
1609
|
+
if (caps.supports_tool_calls && !caps.supports_tools) {
|
|
1610
|
+
LOG_WRN("Template supports tool calls but does not natively describe tools. The fallback behaviour used may produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
|
|
1611
|
+
}
|
|
1612
|
+
}
|
|
1613
|
+
|
|
1614
|
+
// DeepSeek R1: use handler in all cases except json schema (thinking / tools).
|
|
1615
|
+
if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) {
|
|
1616
|
+
return common_chat_params_init_deepseek_r1(tmpl, params);
|
|
1617
|
+
}
|
|
1618
|
+
|
|
1619
|
+
// Command R7B: use handler in all cases except json schema (thinking / tools).
|
|
1620
|
+
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) {
|
|
1621
|
+
return common_chat_params_init_command_r7b(tmpl, params);
|
|
1622
|
+
}
|
|
1623
|
+
|
|
1624
|
+
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
|
|
1625
|
+
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
|
|
1626
|
+
return common_chat_params_init_hermes_2_pro(tmpl, params);
|
|
1627
|
+
}
|
|
1628
|
+
|
|
1629
|
+
// Use generic handler when mixing tools + JSON schema.
|
|
1630
|
+
// TODO: support that mix in handlers below.
|
|
1631
|
+
if ((params.tools.is_array() && params.json_schema.is_object())) {
|
|
1632
|
+
return common_chat_params_init_generic(tmpl, params);
|
|
1633
|
+
}
|
|
1634
|
+
|
|
1635
|
+
// Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases.
|
|
895
1636
|
if (src.find(">>>all") != std::string::npos) {
|
|
896
|
-
|
|
897
|
-
return common_chat_params_init_functionary_v3_2(tmpl, inputs);
|
|
1637
|
+
return common_chat_params_init_functionary_v3_2(tmpl, params);
|
|
898
1638
|
}
|
|
1639
|
+
|
|
1640
|
+
// Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases.
|
|
899
1641
|
if (src.find(" functools[") != std::string::npos) {
|
|
900
|
-
|
|
901
|
-
return common_chat_params_init_firefunction_v2(tmpl, inputs);
|
|
1642
|
+
return common_chat_params_init_firefunction_v2(tmpl, params);
|
|
902
1643
|
}
|
|
903
1644
|
|
|
904
|
-
|
|
905
|
-
|
|
1645
|
+
// Plain handler (no tools)
|
|
1646
|
+
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
|
1647
|
+
return common_chat_params_init_without_tools(tmpl, params);
|
|
906
1648
|
}
|
|
907
1649
|
|
|
908
|
-
|
|
909
|
-
return common_chat_params_init_hermes_2_pro(tmpl, inputs);
|
|
910
|
-
}
|
|
1650
|
+
// Functionary v3.1 (w/ tools)
|
|
911
1651
|
if (src.find("<|start_header_id|>") != std::string::npos
|
|
912
1652
|
&& src.find("<function=") != std::string::npos) {
|
|
913
|
-
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl,
|
|
1653
|
+
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
|
|
914
1654
|
}
|
|
1655
|
+
|
|
1656
|
+
// Llama 3.1, 3.2, 3.3 (w/ tools)
|
|
915
1657
|
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
|
916
1658
|
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
|
|
917
|
-
return common_chat_params_init_llama_3_1_tool_calls(tmpl,
|
|
918
|
-
}
|
|
919
|
-
if (src.find("<|tool▁calls▁begin|>") != std::string::npos) {
|
|
920
|
-
return common_chat_params_init_deepseek_r1(tmpl, inputs);
|
|
1659
|
+
return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
|
|
921
1660
|
}
|
|
1661
|
+
|
|
1662
|
+
// Mistral Nemo (w/ tools)
|
|
922
1663
|
if (src.find("[TOOL_CALLS]") != std::string::npos) {
|
|
923
|
-
return common_chat_params_init_mistral_nemo(tmpl,
|
|
1664
|
+
return common_chat_params_init_mistral_nemo(tmpl, params);
|
|
1665
|
+
}
|
|
1666
|
+
|
|
1667
|
+
// Generic fallback
|
|
1668
|
+
return common_chat_params_init_generic(tmpl, params);
|
|
1669
|
+
}
|
|
1670
|
+
|
|
1671
|
+
// Legacy template route (adhoc C++ implementation of known templates), forward to llama_chat_apply_template.
|
|
1672
|
+
static common_chat_params common_chat_templates_apply_legacy(
|
|
1673
|
+
const struct common_chat_templates * tmpls,
|
|
1674
|
+
const struct common_chat_templates_inputs & inputs)
|
|
1675
|
+
{
|
|
1676
|
+
int alloc_size = 0;
|
|
1677
|
+
std::vector<llama_chat_message> chat;
|
|
1678
|
+
std::vector<std::string> contents;
|
|
1679
|
+
for (const auto & msg : inputs.messages) {
|
|
1680
|
+
auto content = msg.content;
|
|
1681
|
+
for (const auto & part : msg.content_parts) {
|
|
1682
|
+
if (part.type != "text") {
|
|
1683
|
+
LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
|
|
1684
|
+
continue;
|
|
1685
|
+
}
|
|
1686
|
+
if (!content.empty()) {
|
|
1687
|
+
content += "\n";;
|
|
1688
|
+
}
|
|
1689
|
+
content += part.text;
|
|
1690
|
+
}
|
|
1691
|
+
contents.emplace_back(std::move(content));
|
|
1692
|
+
}
|
|
1693
|
+
for (size_t i = 0; i < contents.size(); ++i) {
|
|
1694
|
+
const auto & msg = inputs.messages[i];
|
|
1695
|
+
const auto & content = contents[i];
|
|
1696
|
+
chat.push_back({msg.role.c_str(), content.c_str()});
|
|
1697
|
+
alloc_size += (msg.role.size() + content.size()) * 1.25;
|
|
924
1698
|
}
|
|
925
|
-
|
|
926
|
-
|
|
1699
|
+
|
|
1700
|
+
std::vector<char> buf(alloc_size);
|
|
1701
|
+
|
|
1702
|
+
// run the first time to get the total output length
|
|
1703
|
+
const auto & src = tmpls->template_default->source();
|
|
1704
|
+
int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
|
|
1705
|
+
|
|
1706
|
+
// error: chat template is not supported
|
|
1707
|
+
if (res < 0) {
|
|
1708
|
+
// if the custom "tmpl" is not supported, we throw an error
|
|
1709
|
+
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
|
|
1710
|
+
throw std::runtime_error("this custom template is not supported");
|
|
927
1711
|
}
|
|
928
|
-
|
|
1712
|
+
|
|
1713
|
+
// if it turns out that our buffer is too small, we resize it
|
|
1714
|
+
if ((size_t) res > buf.size()) {
|
|
1715
|
+
buf.resize(res);
|
|
1716
|
+
res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
|
|
1717
|
+
}
|
|
1718
|
+
|
|
1719
|
+
common_chat_params params;
|
|
1720
|
+
params.prompt = std::string(buf.data(), res);
|
|
1721
|
+
if (!inputs.json_schema.empty()) {
|
|
1722
|
+
params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema));
|
|
1723
|
+
} else {
|
|
1724
|
+
params.grammar = inputs.grammar;
|
|
1725
|
+
}
|
|
1726
|
+
return params;
|
|
1727
|
+
}
|
|
1728
|
+
|
|
1729
|
+
// Renders `inputs` with the given templates, dispatching on inputs.use_jinja:
// the Jinja route when it is set, the legacy route otherwise.
// `tmpls` must not be null (enforced with GGML_ASSERT).
common_chat_params common_chat_templates_apply(
    const struct common_chat_templates * tmpls,
    const struct common_chat_templates_inputs & inputs)
{
    GGML_ASSERT(tmpls != nullptr);
    if (inputs.use_jinja) {
        return common_chat_templates_apply_jinja(tmpls, inputs);
    }
    return common_chat_templates_apply_legacy(tmpls, inputs);
}
|
|
930
1738
|
|
|
931
1739
|
// Wraps raw model output in an "assistant" message; the input is copied
// verbatim into the content field and no other field is set here.
static common_chat_msg common_chat_parse_content_only(const std::string & input) {
    common_chat_msg result;
    result.content = input;
    result.role = "assistant";
    return result;
}
|
|
938
1745
|
|
|
939
1746
|
common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {
|
|
@@ -949,17 +1756,23 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
|
|
|
949
1756
|
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
|
|
950
1757
|
return common_chat_parse_llama_3_1(input, /* with_builtin_tools= */ true);
|
|
951
1758
|
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
|
|
952
|
-
return common_chat_parse_deepseek_r1(input);
|
|
1759
|
+
return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ false);
|
|
1760
|
+
case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING:
|
|
1761
|
+
return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ true);
|
|
953
1762
|
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
|
|
954
1763
|
return common_chat_parse_functionary_v3_2(input);
|
|
955
1764
|
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
|
|
956
1765
|
return common_chat_parse_functionary_v3_1_llama_3_1(input);
|
|
957
1766
|
case COMMON_CHAT_FORMAT_HERMES_2_PRO:
|
|
958
|
-
return common_chat_parse_hermes_2_pro(input);
|
|
1767
|
+
return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ false);
|
|
1768
|
+
case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING:
|
|
1769
|
+
return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ true);
|
|
959
1770
|
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
|
|
960
1771
|
return common_chat_parse_firefunction_v2(input);
|
|
961
1772
|
case COMMON_CHAT_FORMAT_COMMAND_R7B:
|
|
962
|
-
return common_chat_parse_command_r7b(input);
|
|
1773
|
+
return common_chat_parse_command_r7b(input, /* extract_reasoning= */ false);
|
|
1774
|
+
case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING:
|
|
1775
|
+
return common_chat_parse_command_r7b(input, /* extract_reasoning= */ true);
|
|
963
1776
|
default:
|
|
964
1777
|
throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
|
|
965
1778
|
}
|