cui-llama.rn 1.4.4 → 1.4.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +2 -2
- package/android/src/main/jni.cpp +12 -10
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/cpp/chat-template.hpp +529 -529
- package/cpp/chat.cpp +959 -265
- package/cpp/chat.h +135 -0
- package/cpp/common.cpp +2064 -1996
- package/cpp/common.h +700 -744
- package/cpp/ggml-alloc.c +1039 -1030
- package/cpp/ggml-alloc.h +1 -1
- package/cpp/ggml-backend-impl.h +255 -255
- package/cpp/ggml-backend-reg.cpp +586 -582
- package/cpp/ggml-backend.cpp +2004 -2002
- package/cpp/ggml-backend.h +354 -354
- package/cpp/ggml-common.h +1851 -1851
- package/cpp/ggml-cpp.h +39 -39
- package/cpp/ggml-cpu-aarch64.cpp +4248 -4247
- package/cpp/ggml-cpu-aarch64.h +8 -8
- package/cpp/ggml-cpu-impl.h +531 -380
- package/cpp/ggml-cpu-quants.c +12527 -11517
- package/cpp/ggml-cpu-traits.cpp +36 -36
- package/cpp/ggml-cpu-traits.h +38 -38
- package/cpp/ggml-cpu.c +15766 -14485
- package/cpp/ggml-cpu.cpp +655 -633
- package/cpp/ggml-cpu.h +138 -135
- package/cpp/ggml-impl.h +567 -567
- package/cpp/ggml-metal-impl.h +235 -0
- package/cpp/ggml-metal.h +66 -66
- package/cpp/ggml-metal.m +5146 -5002
- package/cpp/ggml-opt.cpp +854 -854
- package/cpp/ggml-opt.h +216 -216
- package/cpp/ggml-quants.c +5238 -5238
- package/cpp/ggml-threading.h +14 -14
- package/cpp/ggml.c +6529 -6524
- package/cpp/ggml.h +2198 -2194
- package/cpp/gguf.cpp +1329 -1329
- package/cpp/gguf.h +202 -202
- package/cpp/json-schema-to-grammar.cpp +1024 -1025
- package/cpp/json-schema-to-grammar.h +21 -22
- package/cpp/json.hpp +24766 -24766
- package/cpp/llama-adapter.cpp +347 -347
- package/cpp/llama-adapter.h +74 -74
- package/cpp/llama-arch.cpp +1513 -1492
- package/cpp/llama-arch.h +403 -402
- package/cpp/llama-batch.cpp +368 -368
- package/cpp/llama-batch.h +88 -88
- package/cpp/llama-chat.cpp +588 -587
- package/cpp/llama-chat.h +53 -53
- package/cpp/llama-context.cpp +1775 -1775
- package/cpp/llama-context.h +128 -128
- package/cpp/llama-cparams.cpp +1 -1
- package/cpp/llama-cparams.h +37 -37
- package/cpp/llama-cpp.h +30 -30
- package/cpp/llama-grammar.cpp +1219 -1219
- package/cpp/llama-grammar.h +173 -164
- package/cpp/llama-hparams.cpp +71 -71
- package/cpp/llama-hparams.h +139 -139
- package/cpp/llama-impl.cpp +167 -167
- package/cpp/llama-impl.h +61 -61
- package/cpp/llama-kv-cache.cpp +718 -718
- package/cpp/llama-kv-cache.h +219 -218
- package/cpp/llama-mmap.cpp +600 -590
- package/cpp/llama-mmap.h +68 -68
- package/cpp/llama-model-loader.cpp +1124 -1124
- package/cpp/llama-model-loader.h +167 -167
- package/cpp/llama-model.cpp +4087 -4023
- package/cpp/llama-model.h +370 -370
- package/cpp/llama-sampling.cpp +2558 -2525
- package/cpp/llama-sampling.h +32 -32
- package/cpp/llama-vocab.cpp +3264 -3252
- package/cpp/llama-vocab.h +125 -125
- package/cpp/llama.cpp +10284 -10137
- package/cpp/llama.h +1354 -1340
- package/cpp/log.cpp +393 -423
- package/cpp/log.h +132 -132
- package/cpp/minja/chat-template.hpp +529 -0
- package/cpp/minja/minja.hpp +2915 -0
- package/cpp/minja.hpp +2915 -2883
- package/cpp/rn-llama.cpp +20 -37
- package/cpp/rn-llama.h +12 -2
- package/cpp/sampling.cpp +570 -532
- package/cpp/sgemm.cpp +2598 -2598
- package/cpp/sgemm.h +14 -14
- package/cpp/speculative.cpp +278 -277
- package/cpp/speculative.h +28 -28
- package/package.json +1 -1
- package/android/src/main/build-arm64/CMakeCache.txt +0 -429
- package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCCompiler.cmake +0 -81
- package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCXXCompiler.cmake +0 -101
- package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_C.bin +0 -0
- package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_CXX.bin +0 -0
- package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeSystem.cmake +0 -15
- package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.c +0 -904
- package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.cpp +0 -919
- package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/CMakeConfigureLog.yaml +0 -431
- package/android/src/main/build-arm64/CMakeFiles/CMakeDirectoryInformation.cmake +0 -16
- package/android/src/main/build-arm64/CMakeFiles/Makefile.cmake +0 -165
- package/android/src/main/build-arm64/CMakeFiles/Makefile2 +0 -297
- package/android/src/main/build-arm64/CMakeFiles/Progress/1 +0 -1
- package/android/src/main/build-arm64/CMakeFiles/Progress/2 +0 -1
- package/android/src/main/build-arm64/CMakeFiles/Progress/3 +0 -1
- package/android/src/main/build-arm64/CMakeFiles/Progress/4 +0 -1
- package/android/src/main/build-arm64/CMakeFiles/Progress/5 +0 -1
- package/android/src/main/build-arm64/CMakeFiles/Progress/6 +0 -1
- package/android/src/main/build-arm64/CMakeFiles/Progress/count.txt +0 -1
- package/android/src/main/build-arm64/CMakeFiles/TargetDirectories.txt +0 -8
- package/android/src/main/build-arm64/CMakeFiles/cmake.check_cache +0 -1
- package/android/src/main/build-arm64/CMakeFiles/progress.marks +0 -1
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o.d +0 -58
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o.d +0 -756
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o.d +0 -709
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o.d +0 -714
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o.d +0 -62
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o.d +0 -708
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o.d +0 -113
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o.d +0 -713
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o.d +0 -763
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o.d +0 -61
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o.d +0 -707
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o.d +0 -104
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o.d +0 -714
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o +0 -0
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o.d +0 -723
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/DependInfo.cmake +0 -62
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/build.make +0 -722
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/cmake_clean.cmake +0 -89
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.make +0 -2
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.ts +0 -2
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/depend.make +0 -2
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/flags.make +0 -17
- package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/progress.make +0 -41
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/DependInfo.cmake +0 -62
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/build.make +0 -722
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/cmake_clean.cmake +0 -89
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.make +0 -2
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.ts +0 -2
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/depend.make +0 -2
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/flags.make +0 -17
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/progress.make +0 -41
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/DependInfo.cmake +0 -62
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/build.make +0 -722
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/cmake_clean.cmake +0 -89
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.make +0 -2
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.ts +0 -2
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/depend.make +0 -2
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/flags.make +0 -17
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/progress.make +0 -41
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/DependInfo.cmake +0 -62
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/build.make +0 -722
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/cmake_clean.cmake +0 -89
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.make +0 -2
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.ts +0 -2
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/depend.make +0 -2
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/flags.make +0 -17
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/progress.make +0 -41
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/DependInfo.cmake +0 -62
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/build.make +0 -722
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/cmake_clean.cmake +0 -89
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.make +0 -2
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.ts +0 -2
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/depend.make +0 -2
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/flags.make +0 -17
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/progress.make +0 -41
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/DependInfo.cmake +0 -62
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/build.make +0 -722
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/cmake_clean.cmake +0 -89
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.make +0 -2
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.ts +0 -2
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/depend.make +0 -2
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/flags.make +0 -17
- package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/progress.make +0 -41
- package/android/src/main/build-arm64/Makefile +0 -1862
- package/android/src/main/build-arm64/cmake_install.cmake +0 -66
- package/cpp/chat.hpp +0 -55
- package/cpp/rn-llama.hpp +0 -913
package/cpp/chat.cpp
CHANGED
@@ -1,9 +1,437 @@
|
|
1
|
-
#include "chat.hpp"
|
2
|
-
#include "chat-template.hpp"
|
1
|
+
#include "chat.h"
|
3
2
|
#include "json-schema-to-grammar.h"
|
4
3
|
#include "log.h"
|
4
|
+
#include "chat-template.hpp"
|
5
5
|
#include "minja.hpp"
|
6
6
|
|
7
|
+
#include <optional>
|
8
|
+
|
9
|
+
typedef minja::chat_template common_chat_template;
|
10
|
+
|
11
|
+
struct common_chat_templates {
|
12
|
+
bool has_explicit_template; // Model had builtin template or template override was specified.
|
13
|
+
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
|
14
|
+
std::unique_ptr<common_chat_template> template_tool_use;
|
15
|
+
};
|
16
|
+
|
17
|
+
struct templates_params {
|
18
|
+
json messages;
|
19
|
+
json tools;
|
20
|
+
common_chat_tool_choice tool_choice;
|
21
|
+
json json_schema;
|
22
|
+
bool parallel_tool_calls;
|
23
|
+
bool stream;
|
24
|
+
std::string grammar;
|
25
|
+
bool add_generation_prompt = true;
|
26
|
+
bool extract_reasoning = true;
|
27
|
+
};
|
28
|
+
|
29
|
+
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
|
30
|
+
if (tool_choice == "auto") {
|
31
|
+
return COMMON_CHAT_TOOL_CHOICE_AUTO;
|
32
|
+
}
|
33
|
+
if (tool_choice == "none") {
|
34
|
+
return COMMON_CHAT_TOOL_CHOICE_NONE;
|
35
|
+
}
|
36
|
+
if (tool_choice == "required") {
|
37
|
+
return COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
38
|
+
}
|
39
|
+
throw std::runtime_error("Invalid tool_choice: " + tool_choice);
|
40
|
+
}
|
41
|
+
|
42
|
+
template <>
|
43
|
+
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
|
44
|
+
std::vector<common_chat_msg> msgs;
|
45
|
+
|
46
|
+
try {
|
47
|
+
|
48
|
+
if (!messages.is_array()) {
|
49
|
+
throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump());
|
50
|
+
}
|
51
|
+
|
52
|
+
for (const auto & message : messages) {
|
53
|
+
if (!message.is_object()) {
|
54
|
+
throw std::runtime_error("Expected 'message' to be an object, got " + message.dump());
|
55
|
+
}
|
56
|
+
|
57
|
+
common_chat_msg msg;
|
58
|
+
if (!message.contains("role")) {
|
59
|
+
throw std::runtime_error("Missing 'role' in message: " + message.dump());
|
60
|
+
}
|
61
|
+
msg.role = message.at("role");
|
62
|
+
|
63
|
+
auto has_content = message.contains("content");
|
64
|
+
auto has_tool_calls = message.contains("tool_calls");
|
65
|
+
if (has_content) {
|
66
|
+
const auto & content = message.at("content");
|
67
|
+
if (content.is_string()) {
|
68
|
+
msg.content = content;
|
69
|
+
} else if (content.is_array()) {
|
70
|
+
for (const auto & part : content) {
|
71
|
+
if (!part.contains("type")) {
|
72
|
+
throw std::runtime_error("Missing content part type: " + part.dump());
|
73
|
+
}
|
74
|
+
const auto & type = part.at("type");
|
75
|
+
if (type != "text") {
|
76
|
+
throw std::runtime_error("Unsupported content part type: " + type.dump());
|
77
|
+
}
|
78
|
+
common_chat_msg_content_part msg_part;
|
79
|
+
msg_part.type = type;
|
80
|
+
msg_part.text = part.at("text");
|
81
|
+
msg.content_parts.push_back(msg_part);
|
82
|
+
}
|
83
|
+
} else if (!content.is_null()) {
|
84
|
+
throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
|
85
|
+
}
|
86
|
+
}
|
87
|
+
if (has_tool_calls) {
|
88
|
+
for (const auto & tool_call : message.at("tool_calls")) {
|
89
|
+
common_chat_tool_call tc;
|
90
|
+
if (!tool_call.contains("type")) {
|
91
|
+
throw std::runtime_error("Missing tool call type: " + tool_call.dump());
|
92
|
+
}
|
93
|
+
const auto & type = tool_call.at("type");
|
94
|
+
if (type != "function") {
|
95
|
+
throw std::runtime_error("Unsupported tool call type: " + tool_call.dump());
|
96
|
+
}
|
97
|
+
if (!tool_call.contains("function")) {
|
98
|
+
throw std::runtime_error("Missing tool call function: " + tool_call.dump());
|
99
|
+
}
|
100
|
+
const auto & fc = tool_call.at("function");
|
101
|
+
if (!fc.contains("name")) {
|
102
|
+
throw std::runtime_error("Missing tool call name: " + tool_call.dump());
|
103
|
+
}
|
104
|
+
tc.name = fc.at("name");
|
105
|
+
tc.arguments = fc.at("arguments");
|
106
|
+
if (tool_call.contains("id")) {
|
107
|
+
tc.id = tool_call.at("id");
|
108
|
+
}
|
109
|
+
msg.tool_calls.push_back(tc);
|
110
|
+
}
|
111
|
+
}
|
112
|
+
if (!has_content && !has_tool_calls) {
|
113
|
+
throw std::runtime_error("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)");
|
114
|
+
}
|
115
|
+
if (message.contains("reasoning_content")) {
|
116
|
+
msg.reasoning_content = message.at("reasoning_content");
|
117
|
+
}
|
118
|
+
if (message.contains("name")) {
|
119
|
+
msg.tool_name = message.at("name");
|
120
|
+
}
|
121
|
+
if (message.contains("tool_call_id")) {
|
122
|
+
msg.tool_call_id = message.at("tool_call_id");
|
123
|
+
}
|
124
|
+
|
125
|
+
msgs.push_back(msg);
|
126
|
+
}
|
127
|
+
} catch (const std::exception & e) {
|
128
|
+
throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2));
|
129
|
+
}
|
130
|
+
|
131
|
+
return msgs;
|
132
|
+
}
|
133
|
+
|
134
|
+
template <>
|
135
|
+
json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
|
136
|
+
json messages = json::array();
|
137
|
+
for (const auto & msg : msgs) {
|
138
|
+
if (!msg.content.empty() && !msg.content_parts.empty()) {
|
139
|
+
throw std::runtime_error("Cannot specify both content and content_parts");
|
140
|
+
}
|
141
|
+
json jmsg {
|
142
|
+
{"role", msg.role},
|
143
|
+
};
|
144
|
+
if (!msg.content.empty()) {
|
145
|
+
jmsg["content"] = msg.content;
|
146
|
+
} else if (!msg.content_parts.empty()) {
|
147
|
+
if (concat_typed_text) {
|
148
|
+
std::string text;
|
149
|
+
for (const auto & part : msg.content_parts) {
|
150
|
+
if (part.type != "text") {
|
151
|
+
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
|
152
|
+
continue;
|
153
|
+
}
|
154
|
+
if (!text.empty()) {
|
155
|
+
text += '\n';
|
156
|
+
}
|
157
|
+
text += part.text;
|
158
|
+
}
|
159
|
+
jmsg["content"] = text;
|
160
|
+
} else {
|
161
|
+
auto & parts = jmsg["content"] = json::array();
|
162
|
+
for (const auto & part : msg.content_parts) {
|
163
|
+
parts.push_back({
|
164
|
+
{"type", part.type},
|
165
|
+
{"text", part.text},
|
166
|
+
});
|
167
|
+
}
|
168
|
+
}
|
169
|
+
} else {
|
170
|
+
jmsg["content"] = json(); // null
|
171
|
+
}
|
172
|
+
if (!msg.reasoning_content.empty()) {
|
173
|
+
jmsg["reasoning_content"] = msg.reasoning_content;
|
174
|
+
}
|
175
|
+
if (!msg.tool_name.empty()) {
|
176
|
+
jmsg["name"] = msg.tool_name;
|
177
|
+
}
|
178
|
+
if (!msg.tool_call_id.empty()) {
|
179
|
+
jmsg["tool_call_id"] = msg.tool_call_id;
|
180
|
+
}
|
181
|
+
if (!msg.tool_calls.empty()) {
|
182
|
+
auto & tool_calls = jmsg["tool_calls"] = json::array();
|
183
|
+
for (const auto & tool_call : msg.tool_calls) {
|
184
|
+
json tc {
|
185
|
+
{"type", "function"},
|
186
|
+
{"function", {
|
187
|
+
{"name", tool_call.name},
|
188
|
+
{"arguments", tool_call.arguments},
|
189
|
+
}},
|
190
|
+
};
|
191
|
+
if (!tool_call.id.empty()) {
|
192
|
+
tc["id"] = tool_call.id;
|
193
|
+
}
|
194
|
+
tool_calls.push_back(tc);
|
195
|
+
}
|
196
|
+
}
|
197
|
+
messages.push_back(jmsg);
|
198
|
+
}
|
199
|
+
return messages;
|
200
|
+
}
|
201
|
+
|
202
|
+
template <>
|
203
|
+
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const std::string & messages) {
|
204
|
+
return common_chat_msgs_parse_oaicompat(json::parse(messages));
|
205
|
+
}
|
206
|
+
|
207
|
+
template <>
|
208
|
+
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
|
209
|
+
std::vector<common_chat_tool> result;
|
210
|
+
|
211
|
+
try {
|
212
|
+
if (!tools.is_null()) {
|
213
|
+
if (!tools.is_array()) {
|
214
|
+
throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump());
|
215
|
+
}
|
216
|
+
for (const auto & tool : tools) {
|
217
|
+
if (!tool.contains("type")) {
|
218
|
+
throw std::runtime_error("Missing tool type: " + tool.dump());
|
219
|
+
}
|
220
|
+
const auto & type = tool.at("type");
|
221
|
+
if (!type.is_string() || type != "function") {
|
222
|
+
throw std::runtime_error("Unsupported tool type: " + tool.dump());
|
223
|
+
}
|
224
|
+
if (!tool.contains("function")) {
|
225
|
+
throw std::runtime_error("Missing tool function: " + tool.dump());
|
226
|
+
}
|
227
|
+
|
228
|
+
const auto & function = tool.at("function");
|
229
|
+
result.push_back({
|
230
|
+
/* .name = */ function.at("name"),
|
231
|
+
/* .description = */ function.at("description"),
|
232
|
+
/* .parameters = */ function.at("parameters").dump(),
|
233
|
+
});
|
234
|
+
}
|
235
|
+
}
|
236
|
+
} catch (const std::exception & e) {
|
237
|
+
throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2));
|
238
|
+
}
|
239
|
+
|
240
|
+
return result;
|
241
|
+
}
|
242
|
+
|
243
|
+
template <>
|
244
|
+
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const std::string & tools) {
|
245
|
+
return common_chat_tools_parse_oaicompat(json::parse(tools));
|
246
|
+
}
|
247
|
+
|
248
|
+
template <>
|
249
|
+
json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
|
250
|
+
if (tools.empty()) {
|
251
|
+
return json();
|
252
|
+
}
|
253
|
+
|
254
|
+
auto result = json::array();
|
255
|
+
for (const auto & tool : tools) {
|
256
|
+
result.push_back({
|
257
|
+
{"type", "function"},
|
258
|
+
{"function", {
|
259
|
+
{"name", tool.name},
|
260
|
+
{"description", tool.description},
|
261
|
+
{"parameters", json::parse(tool.parameters)},
|
262
|
+
}},
|
263
|
+
});
|
264
|
+
}
|
265
|
+
return result;
|
266
|
+
}
|
267
|
+
|
268
|
+
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
|
269
|
+
if (use_jinja) {
|
270
|
+
try {
|
271
|
+
common_chat_msg msg;
|
272
|
+
msg.role = "user";
|
273
|
+
msg.content = "test";
|
274
|
+
|
275
|
+
auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl);
|
276
|
+
|
277
|
+
common_chat_templates_inputs inputs;
|
278
|
+
inputs.messages = {msg};
|
279
|
+
|
280
|
+
common_chat_templates_apply(tmpls.get(), inputs);
|
281
|
+
return true;
|
282
|
+
} catch (const std::exception & e) {
|
283
|
+
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
|
284
|
+
return false;
|
285
|
+
}
|
286
|
+
}
|
287
|
+
llama_chat_message chat[] = {{"user", "test"}};
|
288
|
+
const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
|
289
|
+
return res >= 0;
|
290
|
+
}
|
291
|
+
|
292
|
+
std::string common_chat_format_single(
|
293
|
+
const struct common_chat_templates * tmpls,
|
294
|
+
const std::vector<common_chat_msg> & past_msg,
|
295
|
+
const common_chat_msg & new_msg,
|
296
|
+
bool add_ass,
|
297
|
+
bool use_jinja) {
|
298
|
+
|
299
|
+
common_chat_templates_inputs inputs;
|
300
|
+
inputs.use_jinja = use_jinja;
|
301
|
+
|
302
|
+
std::string fmt_past_msg;
|
303
|
+
if (!past_msg.empty()) {
|
304
|
+
inputs.messages = past_msg;
|
305
|
+
inputs.add_generation_prompt = false;
|
306
|
+
fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt;
|
307
|
+
}
|
308
|
+
std::ostringstream ss;
|
309
|
+
// if the past_msg ends with a newline, we must preserve it in the formatted version
|
310
|
+
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
|
311
|
+
ss << "\n";
|
312
|
+
};
|
313
|
+
// format chat with new_msg
|
314
|
+
inputs.messages.push_back(new_msg);
|
315
|
+
inputs.add_generation_prompt = add_ass;
|
316
|
+
auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt;
|
317
|
+
// get the diff part
|
318
|
+
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
|
319
|
+
return ss.str();
|
320
|
+
}
|
321
|
+
|
322
|
+
std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
|
323
|
+
common_chat_templates_inputs inputs;
|
324
|
+
inputs.use_jinja = use_jinja;
|
325
|
+
auto add_simple_msg = [&](auto role, auto content) {
|
326
|
+
common_chat_msg msg;
|
327
|
+
msg.role = role;
|
328
|
+
msg.content = content;
|
329
|
+
inputs.messages.push_back(msg);
|
330
|
+
};
|
331
|
+
add_simple_msg("system", "You are a helpful assistant");
|
332
|
+
add_simple_msg("user", "Hello");
|
333
|
+
add_simple_msg("assistant", "Hi there");
|
334
|
+
add_simple_msg("user", "How are you?");
|
335
|
+
return common_chat_templates_apply(tmpls, inputs).prompt;
|
336
|
+
}
|
337
|
+
|
338
|
+
#define CHATML_TEMPLATE_SRC \
|
339
|
+
"{%- for message in messages -%}\n" \
|
340
|
+
" {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
|
341
|
+
"{%- endfor -%}\n" \
|
342
|
+
"{%- if add_generation_prompt -%}\n" \
|
343
|
+
" {{- '<|im_start|>assistant\n' -}}\n" \
|
344
|
+
"{%- endif -%}"
|
345
|
+
|
346
|
+
void common_chat_templates_free(struct common_chat_templates * tmpls) {
|
347
|
+
delete tmpls;
|
348
|
+
}
|
349
|
+
|
350
|
+
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls) {
|
351
|
+
return tmpls->has_explicit_template;
|
352
|
+
}
|
353
|
+
|
354
|
+
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) {
|
355
|
+
if (variant != nullptr) {
|
356
|
+
if (strcmp(variant, "tool_use") == 0) {
|
357
|
+
if (tmpls->template_tool_use) {
|
358
|
+
return tmpls->template_tool_use->source().c_str();
|
359
|
+
}
|
360
|
+
return nullptr;
|
361
|
+
} else {
|
362
|
+
LOG_DBG("%s: unknown template variant: %s\n", __func__, variant);
|
363
|
+
}
|
364
|
+
}
|
365
|
+
return tmpls->template_default->source().c_str();
|
366
|
+
}
|
367
|
+
|
368
|
+
common_chat_templates_ptr common_chat_templates_init(
|
369
|
+
const struct llama_model * model,
|
370
|
+
const std::string & chat_template_override,
|
371
|
+
const std::string & bos_token_override,
|
372
|
+
const std::string & eos_token_override)
|
373
|
+
{
|
374
|
+
std::string default_template_src;
|
375
|
+
std::string template_tool_use_src;
|
376
|
+
|
377
|
+
bool has_explicit_template = !chat_template_override.empty();
|
378
|
+
if (chat_template_override.empty()) {
|
379
|
+
LM_GGML_ASSERT(model != nullptr);
|
380
|
+
const auto * str = llama_model_chat_template(model, /* name */ nullptr);
|
381
|
+
if (str) {
|
382
|
+
default_template_src = str;
|
383
|
+
has_explicit_template = true;
|
384
|
+
}
|
385
|
+
str = llama_model_chat_template(model, /* name */ "tool_use");
|
386
|
+
if (str) {
|
387
|
+
template_tool_use_src = str;
|
388
|
+
has_explicit_template = true;
|
389
|
+
}
|
390
|
+
} else {
|
391
|
+
default_template_src = chat_template_override;
|
392
|
+
}
|
393
|
+
if (default_template_src.empty() || default_template_src == "chatml") {
|
394
|
+
if (!template_tool_use_src.empty()) {
|
395
|
+
default_template_src = template_tool_use_src;
|
396
|
+
} else {
|
397
|
+
default_template_src = CHATML_TEMPLATE_SRC;
|
398
|
+
}
|
399
|
+
}
|
400
|
+
std::string token_bos = bos_token_override;
|
401
|
+
std::string token_eos = eos_token_override;
|
402
|
+
if (model) {
|
403
|
+
const auto * vocab = llama_model_get_vocab(model);
|
404
|
+
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
|
405
|
+
if (token == LLAMA_TOKEN_NULL) {
|
406
|
+
if (default_template_src.find(jinja_variable_name) != std::string::npos
|
407
|
+
|| template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
|
408
|
+
LOG_WRN("common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't work as intended.\n", name);
|
409
|
+
}
|
410
|
+
return std::string();
|
411
|
+
}
|
412
|
+
return common_token_to_piece(vocab, token, true);
|
413
|
+
};
|
414
|
+
token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
|
415
|
+
token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
|
416
|
+
}
|
417
|
+
common_chat_templates_ptr tmpls(new common_chat_templates());
|
418
|
+
tmpls->has_explicit_template = has_explicit_template;
|
419
|
+
try {
|
420
|
+
tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
|
421
|
+
} catch (const std::exception & e) {
|
422
|
+
LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what());
|
423
|
+
tmpls->template_default = std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
|
424
|
+
}
|
425
|
+
if (!template_tool_use_src.empty()) {
|
426
|
+
try {
|
427
|
+
tmpls->template_tool_use = std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos);
|
428
|
+
} catch (const std::exception & e) {
|
429
|
+
LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
|
430
|
+
}
|
431
|
+
}
|
432
|
+
return tmpls;
|
433
|
+
}
|
434
|
+
|
7
435
|
std::string common_chat_format_name(common_chat_format format) {
|
8
436
|
switch (format) {
|
9
437
|
case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only";
|
@@ -17,6 +445,7 @@ std::string common_chat_format_name(common_chat_format format) {
|
|
17
445
|
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
|
18
446
|
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
|
19
447
|
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
|
448
|
+
case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING: return "Hermes 2 Pro (extract reasoning)";
|
20
449
|
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
|
21
450
|
case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: return "Command R7B (extract reasoning)";
|
22
451
|
default:
|
@@ -24,12 +453,6 @@ std::string common_chat_format_name(common_chat_format format) {
|
|
24
453
|
}
|
25
454
|
}
|
26
455
|
|
27
|
-
const common_grammar_options grammar_options {
|
28
|
-
/* .dotall = */ false,
|
29
|
-
/* .compact_spaces = */ false,
|
30
|
-
// /* .compact_spaces = */ true,
|
31
|
-
};
|
32
|
-
|
33
456
|
static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
|
34
457
|
// // https://json.nlohmann.me/features/parsing/sax_interface/
|
35
458
|
struct json_error_locator : public nlohmann::json_sax<json> {
|
@@ -38,22 +461,22 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
|
|
38
461
|
|
39
462
|
json_error_locator() : position(0), found_error(false) {}
|
40
463
|
|
41
|
-
bool parse_error(std::size_t position, const std::string &, const json::exception &) override {
|
464
|
+
bool parse_error(std::size_t position, const std::string &, const json::exception &) override { // NOLINT
|
42
465
|
this->position = position - 1;
|
43
466
|
this->found_error = true;
|
44
467
|
return false;
|
45
468
|
}
|
46
|
-
bool null() override { return true; }
|
47
|
-
bool boolean(bool) override { return true; }
|
48
|
-
bool number_integer(number_integer_t) override { return true; }
|
49
|
-
bool number_unsigned(number_unsigned_t) override { return true; }
|
50
|
-
bool number_float(number_float_t, const string_t &) override { return true; }
|
51
|
-
bool string(string_t &) override { return true; }
|
52
|
-
bool binary(binary_t &) override { return true; }
|
53
|
-
bool start_object(std::size_t) override { return true; }
|
54
|
-
bool key(string_t &) override { return true; }
|
469
|
+
bool null() override { return true; } // NOLINT
|
470
|
+
bool boolean(bool) override { return true; } // NOLINT
|
471
|
+
bool number_integer(number_integer_t) override { return true; } // NOLINT
|
472
|
+
bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT
|
473
|
+
bool number_float(number_float_t, const string_t &) override { return true; } // NOLINT
|
474
|
+
bool string(string_t &) override { return true; } // NOLINT
|
475
|
+
bool binary(binary_t &) override { return true; } // NOLINT
|
476
|
+
bool start_object(std::size_t) override { return true; } // NOLINT
|
477
|
+
bool key(string_t &) override { return true; } // NOLINT
|
55
478
|
bool end_object() override { return true; }
|
56
|
-
bool start_array(std::size_t) override { return true; }
|
479
|
+
bool start_array(std::size_t) override { return true; } // NOLINT
|
57
480
|
bool end_array() override { return true; }
|
58
481
|
};
|
59
482
|
json_error_locator err_loc;
|
@@ -75,6 +498,34 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
|
|
75
498
|
}
|
76
499
|
}
|
77
500
|
|
501
|
+
// Tries to consume the exact text `expected` at the current position `it`.
//
// Returns true and moves `it` just past the literal when the input at `it`
// starts with `expected`; otherwise returns false and leaves `it` untouched.
// An empty `expected` always succeeds without moving `it`.
static bool parse_literal(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
    // Not enough input left to hold the whole literal.
    if (static_cast<size_t>(end - it) < expected.size()) {
        return false;
    }
    if (!std::equal(expected.begin(), expected.end(), it)) {
        return false;
    }
    // Commit: only advance the caller's iterator once the full literal matched.
    it += static_cast<std::string::difference_type>(expected.size());
    return true;
}
|
514
|
+
|
515
|
+
// Matches `expected` against the input range [it, end).
//
// std::regex_match only succeeds when the pattern covers the whole range, so
// on success the suffix is empty and `it` ends up equal to `end`; the match
// results (with their capture groups) are returned. On failure `it` is left
// unchanged and std::nullopt is returned.
static std::optional<std::smatch> parse_pattern(std::string::const_iterator & it, const std::string::const_iterator & end, const std::regex & expected) {
    std::smatch groups;
    const bool matched = std::regex_match(it, end, groups, expected);
    if (!matched) {
        return std::nullopt;
    }
    it = groups.suffix().first;
    return groups;
}
|
523
|
+
|
524
|
+
// Advances `it` past any run of whitespace characters (as classified by
// std::isspace), stopping at the first non-space character or at `end`.
static void consume_spaces(std::string::const_iterator & it, const std::string::const_iterator & end) {
    // std::isspace requires a value representable as unsigned char (or EOF).
    // Plain `char` may be signed, so bytes >= 0x80 (e.g. UTF-8 continuation
    // bytes in model output) must be converted first to avoid undefined behaviour.
    while (it != end && std::isspace(static_cast<unsigned char>(*it))) {
        ++it;
    }
}
|
78
529
|
|
79
530
|
/**
|
80
531
|
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
|
@@ -84,7 +535,8 @@ static common_chat_msg parse_json_tool_calls(
|
|
84
535
|
const std::string& input,
|
85
536
|
const std::optional<std::regex> & trigger_opt,
|
86
537
|
const std::regex & function_regex,
|
87
|
-
const std::regex & close_regex
|
538
|
+
const std::regex & close_regex,
|
539
|
+
bool allow_raw_python = false) {
|
88
540
|
std::smatch match;
|
89
541
|
|
90
542
|
common_chat_msg result;
|
@@ -115,14 +567,19 @@ static common_chat_msg parse_json_tool_calls(
|
|
115
567
|
it = rit->suffix().first;
|
116
568
|
|
117
569
|
json arguments;
|
118
|
-
if (
|
570
|
+
if (parse_json(it, end, arguments)) {
|
571
|
+
if (!std::regex_search(it, end, match, close_regex)) {
|
572
|
+
throw std::runtime_error("Malformed input, missing closing pattern: " + input);
|
573
|
+
}
|
574
|
+
it = match.suffix().first;
|
575
|
+
result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
|
576
|
+
} else {
|
577
|
+
if (allow_raw_python && name == "python") {
|
578
|
+
result.tool_calls.push_back({name, json({{"code", std::string(it, end)}}).dump(), /* id= */ ""});
|
579
|
+
break;
|
580
|
+
}
|
119
581
|
throw std::runtime_error("Failed to parse json tool call arguments: " + input);
|
120
582
|
}
|
121
|
-
if (!std::regex_search(it, end, match, close_regex)) {
|
122
|
-
throw std::runtime_error("Malformed input, missing closing pattern: " + input);
|
123
|
-
}
|
124
|
-
it = match.suffix().first;
|
125
|
-
result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
|
126
583
|
}
|
127
584
|
|
128
585
|
if (!result.tool_calls.empty()) {
|
@@ -134,29 +591,29 @@ static common_chat_msg parse_json_tool_calls(
|
|
134
591
|
return result;
|
135
592
|
}
|
136
593
|
|
594
|
+
// Converts one JSON tool-call object into a common_chat_tool_call.
//
// "arguments" and "name" are fetched with at(), so both keys are required;
// "id" is optional and defaults to an empty string. String-typed arguments are
// passed through verbatim, any other JSON value is serialized with dump().
static common_chat_tool_call process_tool_call(const json & tool_call) {
    // Look up "arguments" first, matching the order in which a missing key is reported.
    const json & raw_args = tool_call.at("arguments");
    std::string call_name = tool_call.at("name");
    std::string call_args = raw_args.is_string() ? raw_args.get<std::string>() : raw_args.dump();
    std::string call_id   = tool_call.contains("id") ? tool_call.at("id") : "";
    return {
        /* .name      = */ std::move(call_name),
        /* .arguments = */ std::move(call_args),
        /* .id        = */ std::move(call_id),
    };
}
|
137
602
|
// Splits `input` at the first occurrence of `prefix`: the text before it becomes
// the assistant message content, and the text after it is parsed as a JSON array
// of tool calls. `rstrip_prefix` is the number of trailing prefix characters that
// belong to the JSON payload itself (e.g. the '[' of " functools[").
// When `prefix` is absent the whole input is returned as plain content.
static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
    common_chat_msg msg;
    msg.role = "assistant";

    const size_t prefix_pos = input.find(prefix);
    if (prefix_pos == std::string::npos) {
        // No tool-call marker: everything is regular content.
        msg.content = input;
        return msg;
    }

    msg.content = input.substr(0, prefix_pos);

    // Start of the JSON payload, re-including the last `rstrip_prefix` prefix characters.
    const size_t json_start = prefix_pos + prefix.size() - rstrip_prefix;
    const auto calls = json::parse(input.substr(json_start));
    for (const auto & call : calls) {
        msg.tool_calls.push_back(process_tool_call(call));
    }
    return msg;
}
|
@@ -187,13 +644,20 @@ static std::string apply(
|
|
187
644
|
// tmpl_inputs.now = std::chrono::system_clock::now();
|
188
645
|
|
189
646
|
minja::chat_template_options tmpl_opts;
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
647
|
+
// To avoid double BOS / EOS tokens, we're manually removing beginning / trailing tokens
|
648
|
+
// instead of using `chat_template_options.use_bos_token = false`, since these tokens
|
649
|
+
// may be needed inside the template / between messages too.
|
650
|
+
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
|
651
|
+
if (string_starts_with(result, tmpl.bos_token())) {
|
652
|
+
result = result.substr(tmpl.bos_token().size());
|
653
|
+
}
|
654
|
+
if (string_ends_with(result, tmpl.eos_token())) {
|
655
|
+
result = result.substr(0, result.size() - tmpl.eos_token().size());
|
656
|
+
}
|
657
|
+
return result;
|
194
658
|
}
|
195
659
|
|
196
|
-
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct
|
660
|
+
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
197
661
|
common_chat_params data;
|
198
662
|
|
199
663
|
auto tool_call_schemas = json::array();
|
@@ -247,7 +711,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
|
|
247
711
|
{"required", json::array({"tool_call"})},
|
248
712
|
};
|
249
713
|
const auto schema =
|
250
|
-
inputs.tool_choice !=
|
714
|
+
inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED
|
251
715
|
? json {
|
252
716
|
{"anyOf", json::array({
|
253
717
|
tool_call,
|
@@ -268,7 +732,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
|
|
268
732
|
data.grammar_lazy = false;
|
269
733
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
270
734
|
builder.add_schema("root", schema);
|
271
|
-
}
|
735
|
+
});
|
272
736
|
|
273
737
|
auto tweaked_messages = common_chat_template::add_system(
|
274
738
|
inputs.messages,
|
@@ -303,9 +767,9 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) {
|
|
303
767
|
return result;
|
304
768
|
}
|
305
769
|
|
306
|
-
static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct
|
770
|
+
static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
307
771
|
common_chat_params data;
|
308
|
-
data.grammar_lazy = inputs.tool_choice !=
|
772
|
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
309
773
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
310
774
|
auto schemas = json::array();
|
311
775
|
foreach_function(inputs.tools, [&](const json & tool) {
|
@@ -338,8 +802,11 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
|
|
338
802
|
schema["maxItems"] = 1;
|
339
803
|
}
|
340
804
|
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
|
341
|
-
}
|
342
|
-
data.grammar_triggers.push_back({"[TOOL_CALLS]"
|
805
|
+
});
|
806
|
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"});
|
807
|
+
data.preserved_tokens = {
|
808
|
+
"[TOOL_CALLS]",
|
809
|
+
};
|
343
810
|
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
344
811
|
data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
|
345
812
|
return data;
|
@@ -348,9 +815,9 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input)
|
|
348
815
|
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
|
349
816
|
}
|
350
817
|
|
351
|
-
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct
|
818
|
+
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
352
819
|
common_chat_params data;
|
353
|
-
data.grammar_lazy = inputs.tool_choice !=
|
820
|
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
354
821
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
355
822
|
auto schemas = json::array();
|
356
823
|
foreach_function(inputs.tools, [&](const json & tool) {
|
@@ -381,14 +848,18 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
|
|
381
848
|
schema["maxItems"] = 1;
|
382
849
|
}
|
383
850
|
builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\"");
|
384
|
-
}
|
385
|
-
data.grammar_triggers.push_back({
|
851
|
+
});
|
852
|
+
data.grammar_triggers.push_back({
|
853
|
+
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
854
|
+
"<|START_ACTION|>",
|
855
|
+
});
|
386
856
|
data.preserved_tokens = {
|
857
|
+
"<|START_ACTION|>",
|
858
|
+
"<|END_ACTION|>",
|
387
859
|
"<|START_RESPONSE|>",
|
388
860
|
"<|END_RESPONSE|>",
|
389
861
|
"<|START_THINKING|>",
|
390
862
|
"<|END_THINKING|>",
|
391
|
-
"<|END_ACTION|>",
|
392
863
|
};
|
393
864
|
auto adjusted_messages = json::array();
|
394
865
|
for (const auto & msg : inputs.messages) {
|
@@ -408,9 +879,9 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
|
|
408
879
|
return data;
|
409
880
|
}
|
410
881
|
static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) {
|
411
|
-
static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S
|
412
|
-
static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S
|
413
|
-
static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S
|
882
|
+
static const std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S]*?)<\\|END_THINKING\\|>)([\\s\\S]*)");
|
883
|
+
static const std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S]*?)<\\|END_ACTION\\|>");
|
884
|
+
static const std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S]*?)<\\|END_RESPONSE\\|>");
|
414
885
|
|
415
886
|
std::smatch match;
|
416
887
|
|
@@ -455,10 +926,10 @@ static void expect_tool_parameters(const std::string & name, const json & parame
|
|
455
926
|
const auto & parameters_required = parameters.at("required");
|
456
927
|
for (const auto & prop : expected_properties) {
|
457
928
|
if (!parameters_properties.contains(prop)) {
|
458
|
-
throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop);
|
929
|
+
throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); // NOLINT
|
459
930
|
}
|
460
931
|
if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
|
461
|
-
throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop);
|
932
|
+
throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); // NOLINT
|
462
933
|
}
|
463
934
|
}
|
464
935
|
if (parameters_properties.size() != expected_properties.size()) {
|
@@ -466,18 +937,16 @@ static void expect_tool_parameters(const std::string & name, const json & parame
|
|
466
937
|
}
|
467
938
|
}
|
468
939
|
|
469
|
-
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct
|
940
|
+
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
|
470
941
|
auto builtin_tools = json::array();
|
471
942
|
common_chat_params data;
|
472
|
-
data.grammar_lazy = inputs.tool_choice !=
|
943
|
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
473
944
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
474
945
|
std::vector<std::string> tool_rules;
|
475
946
|
|
476
947
|
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
|
477
|
-
if (name == "wolfram_alpha") {
|
948
|
+
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
|
478
949
|
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
|
479
|
-
expect_tool_parameters(name, parameters, {"query"});
|
480
|
-
} else if (name == "web_search" || name == "brave_search") {
|
481
950
|
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
|
482
951
|
expect_tool_parameters(name, parameters, {"query"});
|
483
952
|
} else if (name == "python" || name == "code_interpreter") {
|
@@ -489,7 +958,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
|
|
489
958
|
|
490
959
|
std::vector<std::string> kvs;
|
491
960
|
for (const auto & [key, value] : parameters.at("properties").items()) {
|
492
|
-
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value));
|
961
|
+
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
|
493
962
|
}
|
494
963
|
|
495
964
|
tool_rules.push_back(
|
@@ -515,23 +984,23 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
|
|
515
984
|
builder.add_rule(
|
516
985
|
name + "-call",
|
517
986
|
"\"{\" space "
|
518
|
-
"( \"\\\"type\\\":\" space \"\\\"function\\\",\" space )? "
|
519
|
-
"\"\\\"name\\\"
|
520
|
-
|
521
|
-
"
|
522
|
-
|
987
|
+
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
|
988
|
+
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
|
989
|
+
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
|
990
|
+
"\"}\" space"));
|
991
|
+
});
|
992
|
+
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
|
993
|
+
data.grammar_triggers.push_back({
|
994
|
+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
995
|
+
"\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
|
523
996
|
});
|
524
|
-
data.grammar_triggers.push_back({"{\"name\":", /* .at_start = */ true});
|
525
|
-
data.grammar_triggers.push_back({"{\n \"name\":", /* .at_start = */ true});
|
526
|
-
data.grammar_triggers.push_back({"{\n \"name\":", /* .at_start = */ true});
|
527
|
-
data.grammar_triggers.push_back({"{\"type\": \"function\"", /* .at_start = */ true});
|
528
|
-
data.grammar_triggers.push_back({"{\n \"type\": \"function\"", /* .at_start = */ true});
|
529
|
-
data.grammar_triggers.push_back({"{\n \"type\": \"function\"", /* .at_start = */ true});
|
530
997
|
if (!builtin_tools.empty()) {
|
531
|
-
data.grammar_triggers.push_back({"<|python_tag|>"
|
998
|
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
999
|
+
data.preserved_tokens.push_back("<|python_tag|>");
|
532
1000
|
}
|
1001
|
+
// Allow a few empty lines on top of the usual constrained json schema space rule.
|
533
1002
|
builder.add_rule("root", string_join(tool_rules, " | "));
|
534
|
-
}
|
1003
|
+
});
|
535
1004
|
data.additional_stops.push_back("<|eom_id|>");
|
536
1005
|
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
|
537
1006
|
{"tools_in_user_message", false},
|
@@ -544,54 +1013,53 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
|
|
544
1013
|
}
|
545
1014
|
// Parses Llama 3.1 style output into a chat message.
//
// When `with_builtin_tools` is set and the whole input has the form
// `<|python_tag|>NAME.call(ARG=VALUE)`, a single tool call named NAME with
// arguments {ARG: VALUE} is returned (VALUE is parsed as JSON). If that VALUE
// fails to parse, a warning is logged and parsing falls through to the generic
// JSON tool-call parser, which is also used for every other input.
static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
    // TODO: tighten & simplify the parser, don't accept leading text context.
    static const std::regex function_regex(
        "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
    static const std::regex close_regex("\\}\\s*");
    static const std::regex builtin_call_regex("<\\|python_tag\\|>\\s*([^.(]+)\\s*\\.\\s*call\\s*\\(\\s*([\\w]+)\\s*=\\s*([\\s\\S]*?)\\)");

    std::smatch builtin;
    if (with_builtin_tools && std::regex_match(input, builtin, builtin_call_regex)) {
        try {
            auto tool_name = builtin[1].str();
            auto key       = builtin[2].str();
            auto value     = json::parse(builtin[3].str());

            common_chat_msg call_msg;
            call_msg.role = "assistant";
            call_msg.tool_calls.push_back({
                /* .name      = */ tool_name,
                /* .arguments = */ (json {
                    {key, value},
                }).dump(),
                /* .id        = */ "",
            });
            return call_msg;
        } catch (const std::exception & e) {
            // Best effort: report the problem and let the generic parser below handle the input.
            LOG_WRN("Failed to parse builtin tool call arguments (%s): %s", e.what(), input.c_str());
        }
    }
    return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
}
|
580
1047
|
|
581
|
-
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct
|
1048
|
+
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
582
1049
|
common_chat_params data;
|
583
1050
|
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
584
|
-
data.grammar_lazy = inputs.tool_choice !=
|
1051
|
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
|
585
1052
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
586
1053
|
std::vector<std::string> tool_rules;
|
587
1054
|
foreach_function(inputs.tools, [&](const json & tool) {
|
588
1055
|
const auto & function = tool.at("function");
|
589
1056
|
std::string name = function.at("name");
|
590
1057
|
auto parameters = function.at("parameters");
|
591
|
-
|
1058
|
+
builder.resolve_refs(parameters);
|
592
1059
|
tool_rules.push_back(builder.add_rule(name + "-call",
|
593
1060
|
"\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
|
594
|
-
"```json\\n\" " +
|
1061
|
+
"```json\\n\" " + builder.add_schema(name + "-args", parameters) + " "
|
1062
|
+
"\"```<|tool▁call▁end|>\""));
|
595
1063
|
});
|
596
1064
|
// Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
|
597
1065
|
// so we accept common variants (then it's all constrained)
|
@@ -600,18 +1068,20 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
|
|
600
1068
|
"(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
|
601
1069
|
"\"<|tool▁calls▁end|>\""
|
602
1070
|
" space");
|
603
|
-
data.grammar_triggers.push_back({"<|tool▁calls▁begin|>"
|
604
|
-
data.grammar_triggers.push_back({"<|tool_calls_begin|>"
|
605
|
-
data.grammar_triggers.push_back({"<|tool calls begin|>"
|
606
|
-
data.grammar_triggers.push_back({"<|tool\\_calls\\_begin|>"
|
1071
|
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool▁calls▁begin|>"});
|
1072
|
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_calls_begin|>"});
|
1073
|
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool calls begin|>"});
|
1074
|
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool\\_calls\\_begin|>"});
|
607
1075
|
data.preserved_tokens = {
|
608
1076
|
"<think>",
|
609
1077
|
"</think>",
|
1078
|
+
"<|tool▁calls▁begin|>",
|
1079
|
+
"<|tool▁call▁begin|>",
|
610
1080
|
"<|tool▁sep|>",
|
611
|
-
"<|tool▁calls▁end|",
|
612
1081
|
"<|tool▁call▁end|>",
|
1082
|
+
"<|tool▁calls▁end|",
|
613
1083
|
};
|
614
|
-
}
|
1084
|
+
});
|
615
1085
|
}
|
616
1086
|
auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
617
1087
|
|
@@ -636,45 +1106,53 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
|
|
636
1106
|
data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING : COMMON_CHAT_FORMAT_DEEPSEEK_R1;
|
637
1107
|
return data;
|
638
1108
|
}
|
639
|
-
static common_chat_msg
|
640
|
-
static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
|
641
|
-
static std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
|
642
|
-
static std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)");
|
643
|
-
static std::regex tool_calls_regex("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)([\\s\\S\\r\\n]*?)<|tool▁calls▁end|>");
|
644
|
-
common_chat_msg msg;
|
645
|
-
msg.role = "assistant";
|
1109
|
+
// Peels an optional leading reasoning section (text terminated by `</think>`,
// with an optional opening `<think>`) off `input` and hands the remainder to
// `rest_parser`.
//
// The stripped reasoning text is stored in `reasoning_content` when
// `extract_reasoning` is true; otherwise, if it is non-empty, it is put back in
// front of the parsed content wrapped in <think>...</think>.
static common_chat_msg handle_think_tag_prelude(const std::string & input, bool extract_reasoning, const std::function<common_chat_msg(const std::string &)> & rest_parser) {
    static const std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)");

    std::smatch parts;
    if (!std::regex_match(input, parts, reasoning_content_regex)) {
        // Safety net: the pattern's trailing catch-all group accepts any text.
        return rest_parser(input);
    }

    // parts[2] = reasoning text, parts[3] = everything after the reasoning section.
    auto parsed = rest_parser(parts[3].str());
    const auto thoughts = string_strip(parts[2].str());

    if (extract_reasoning) {
        parsed.reasoning_content = thoughts;
    } else if (!thoughts.empty()) {
        parsed.content = "<think>" + thoughts + "</think>" + parsed.content;
    }
    return parsed;
}
|
1127
|
+
// Parses DeepSeek R1 output: the optional <think> prelude is handled by
// handle_think_tag_prelude, and the remaining text is either a tool-calls
// section (parsed into tool_calls) or plain assistant content.
static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) {
    const auto parse_after_think = [](const std::string & body) {
        // NOTE(review): the `|` characters inside these token patterns are unescaped.
        // If they are ASCII bars rather than the full-width bar used by the model's
        // special tokens, std::regex treats them as alternation — confirm the encoding.
        static const std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
        static const std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
        static const std::regex tool_calls_regex("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)([\\s\\S\\r\\n]*?)<|tool▁calls▁end|>");

        common_chat_msg parsed;
        parsed.role = "assistant";

        std::smatch found;
        if (!std::regex_search(body, found, tool_calls_regex)) {
            // No tool-calls section: the whole remainder is regular content.
            parsed.content = body;
            return parsed;
        }

        // found[1] is the text between the begin/end markers; only its tool calls are kept.
        auto section = found[1].str();
        auto inner = parse_json_tool_calls(section, std::nullopt, function_regex, close_regex);
        parsed.tool_calls = std::move(inner.tool_calls);
        return parsed;
    };
    return handle_think_tag_prelude(input, extract_reasoning, parse_after_think);
}
|
668
1146
|
|
669
|
-
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct
|
670
|
-
|
1147
|
+
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
1148
|
+
LOG_DBG("%s\n", __func__);
|
671
1149
|
common_chat_params data;
|
672
1150
|
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
|
673
1151
|
{"datetime", "Jan 29 2025 13:00:00 GMT"},
|
674
1152
|
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
675
1153
|
});
|
676
1154
|
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
677
|
-
data.grammar_lazy = inputs.tool_choice !=
|
1155
|
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
678
1156
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
679
1157
|
auto schemas = json::array();
|
680
1158
|
foreach_function(inputs.tools, [&](const json & tool) {
|
@@ -700,8 +1178,11 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
|
|
700
1178
|
schema["maxItems"] = 1;
|
701
1179
|
}
|
702
1180
|
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
|
703
|
-
}
|
704
|
-
data.grammar_triggers.push_back({" functools["
|
1181
|
+
});
|
1182
|
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, " functools["});
|
1183
|
+
data.preserved_tokens = {
|
1184
|
+
" functools[",
|
1185
|
+
};
|
705
1186
|
data.format = COMMON_CHAT_FORMAT_FIREFUNCTION_V2;
|
706
1187
|
} else {
|
707
1188
|
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
@@ -712,14 +1193,14 @@ static common_chat_msg common_chat_parse_firefunction_v2(const std::string & inp
|
|
712
1193
|
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
|
713
1194
|
}
|
714
1195
|
|
715
|
-
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct
|
1196
|
+
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
716
1197
|
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
717
1198
|
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
718
1199
|
common_chat_params data;
|
719
1200
|
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
720
1201
|
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
|
721
1202
|
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
722
|
-
data.grammar_lazy = inputs.tool_choice !=
|
1203
|
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
723
1204
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
724
1205
|
std::vector<std::string> first_tool_rules;
|
725
1206
|
std::vector<std::string> subsequent_tool_rules;
|
@@ -727,12 +1208,30 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
|
727
1208
|
const auto & function = tool.at("function");
|
728
1209
|
std::string name = function.at("name");
|
729
1210
|
auto parameters = function.at("parameters");
|
1211
|
+
builder.resolve_refs(parameters);
|
730
1212
|
auto args_rule = builder.add_schema(name + "-args", parameters);
|
731
|
-
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
|
1213
|
+
first_tool_rules.push_back(builder.add_rule(name + "-call", "( \"assistant<|end_header_id|>\\n\" )? \"" + name + "\\n\" " + args_rule));
|
732
1214
|
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
|
733
|
-
data.grammar_triggers.push_back({
|
734
|
-
|
1215
|
+
data.grammar_triggers.push_back({
|
1216
|
+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
1217
|
+
regex_escape(name + "\n"),
|
1218
|
+
});
|
1219
|
+
data.grammar_triggers.push_back({
|
1220
|
+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
1221
|
+
regex_escape("assistant<|end_header_id|>\n" + name + "\n"),
|
1222
|
+
});
|
1223
|
+
data.grammar_triggers.push_back({
|
1224
|
+
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
1225
|
+
regex_escape(">>>" + name + "\n"),
|
1226
|
+
});
|
1227
|
+
data.grammar_triggers.push_back({
|
1228
|
+
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
1229
|
+
">>>assistant<|end_header_id|>\n" + name,
|
1230
|
+
});
|
735
1231
|
});
|
1232
|
+
data.preserved_tokens = {
|
1233
|
+
"<|end_header_id|>",
|
1234
|
+
};
|
736
1235
|
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
|
737
1236
|
if (inputs.parallel_tool_calls) {
|
738
1237
|
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
|
@@ -741,34 +1240,20 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
|
741
1240
|
builder.add_rule("root", first_rule);
|
742
1241
|
}
|
743
1242
|
|
744
|
-
}
|
1243
|
+
});
|
745
1244
|
}
|
746
1245
|
return data;
|
747
1246
|
}
|
748
1247
|
|
749
|
-
static bool consume(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
|
750
|
-
auto expected_it = expected.begin();
|
751
|
-
auto tmp_it = it;
|
752
|
-
while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) {
|
753
|
-
++tmp_it;
|
754
|
-
++expected_it;
|
755
|
-
}
|
756
|
-
if (expected_it == expected.end()) {
|
757
|
-
it = tmp_it;
|
758
|
-
return true;
|
759
|
-
}
|
760
|
-
return false;
|
761
|
-
}
|
762
|
-
|
763
1248
|
static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) {
|
764
|
-
static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
|
765
|
-
static std::regex close_regex(R"($|(?=>>>))");
|
1249
|
+
static const std::regex function_regex(R"((?:>>>)?(?:assistant<|end_header_id|>\n)?(\w+)\n)");
|
1250
|
+
static const std::regex close_regex(R"($|(?=>>>))");
|
766
1251
|
|
767
1252
|
std::string content;
|
768
1253
|
auto it = input.begin();
|
769
1254
|
const auto end = input.end();
|
770
1255
|
|
771
|
-
if (
|
1256
|
+
if (parse_literal(it, end, "all\n")) {
|
772
1257
|
std::smatch match;
|
773
1258
|
if (std::regex_search(it, end, match, function_regex)) {
|
774
1259
|
auto fun_it = match.prefix().second;
|
@@ -783,7 +1268,7 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
|
|
783
1268
|
}
|
784
1269
|
// TODO: tighten & simplify.
|
785
1270
|
try {
|
786
|
-
auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex);
|
1271
|
+
auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex, /* allow_raw_python= */ true);
|
787
1272
|
res.content = content + res.content;
|
788
1273
|
return res;
|
789
1274
|
} catch (const std::exception & e) {
|
@@ -795,14 +1280,14 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
|
|
795
1280
|
}
|
796
1281
|
}
|
797
1282
|
|
798
|
-
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct
|
1283
|
+
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
799
1284
|
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
800
1285
|
common_chat_params data;
|
801
1286
|
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
|
802
1287
|
std::string python_code_argument_name;
|
803
1288
|
auto has_raw_python = false;
|
804
1289
|
|
805
|
-
data.grammar_lazy = inputs.tool_choice !=
|
1290
|
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
806
1291
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
807
1292
|
std::vector<std::string> tool_rules;
|
808
1293
|
foreach_function(inputs.tools, [&](const json & tool) {
|
@@ -814,7 +1299,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
|
|
814
1299
|
throw std::runtime_error("Missing type in python tool");
|
815
1300
|
}
|
816
1301
|
has_raw_python = true;
|
817
|
-
auto type = parameters.at("type");
|
1302
|
+
const auto & type = parameters.at("type");
|
818
1303
|
if (type == "object") {
|
819
1304
|
auto properties = parameters.at("properties");
|
820
1305
|
for (auto it = properties.begin(); it != properties.end(); ++it) {
|
@@ -836,12 +1321,13 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
|
|
836
1321
|
});
|
837
1322
|
if (has_raw_python) {
|
838
1323
|
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
|
839
|
-
data.grammar_triggers.push_back({"<|python_tag|>"
|
1324
|
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
1325
|
+
data.preserved_tokens.push_back("<|python_tag|>");
|
840
1326
|
}
|
841
1327
|
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
|
842
1328
|
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
843
|
-
data.grammar_triggers.push_back({"<function="
|
844
|
-
}
|
1329
|
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
|
1330
|
+
});
|
845
1331
|
|
846
1332
|
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
847
1333
|
// TODO: if (has_raw_python)
|
@@ -850,34 +1336,33 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
|
|
850
1336
|
}
|
851
1337
|
static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
|
852
1338
|
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
|
853
|
-
static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
1339
|
+
static const std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
854
1340
|
std::smatch match;
|
855
1341
|
if (std::regex_search(input, match, python_tag_regex)) {
|
856
1342
|
auto code = match[1].str();
|
857
|
-
|
858
|
-
|
859
|
-
|
860
|
-
|
861
|
-
|
862
|
-
|
863
|
-
|
864
|
-
|
865
|
-
|
866
|
-
}
|
867
|
-
};
|
1343
|
+
common_chat_msg msg;
|
1344
|
+
msg.role = "assistant";
|
1345
|
+
msg.content = match.prefix().str();
|
1346
|
+
msg.tool_calls.push_back({
|
1347
|
+
/* .name = */ "python",
|
1348
|
+
/* .arguments = */ (json {{"code", code}}).dump(),
|
1349
|
+
/* .id = */ "",
|
1350
|
+
});
|
1351
|
+
return msg;
|
868
1352
|
}
|
869
|
-
static std::regex function_regex(R"(<function=(\w+)>)");
|
870
|
-
static std::regex close_regex(R"(</function>)");
|
1353
|
+
static const std::regex function_regex(R"(<function=(\w+)>)");
|
1354
|
+
static const std::regex close_regex(R"(</function>)");
|
871
1355
|
// TODO: tighten & simplify.
|
872
1356
|
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
|
873
1357
|
}
|
874
1358
|
|
875
|
-
static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct
|
1359
|
+
static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
876
1360
|
common_chat_params data;
|
877
1361
|
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
878
|
-
data.grammar_lazy = inputs.tool_choice !=
|
1362
|
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
879
1363
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
880
1364
|
std::vector<std::string> tool_rules;
|
1365
|
+
std::vector<std::string> tool_call_alts;
|
881
1366
|
foreach_function(inputs.tools, [&](const json & tool) {
|
882
1367
|
const auto & function = tool.at("function");
|
883
1368
|
std::string name = function.at("name");
|
@@ -891,73 +1376,190 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
|
|
891
1376
|
}},
|
892
1377
|
{"required", json::array({"name", "arguments"})},
|
893
1378
|
}));
|
1379
|
+
tool_call_alts.push_back(builder.add_rule(
|
1380
|
+
name + "-function-tag",
|
1381
|
+
"\"<function\" ( \"=" + name + "\" | \" name=\\\"" + name + "\\\"\" ) \">\" space " +
|
1382
|
+
builder.add_schema(name + "-args", parameters) + " "
|
1383
|
+
"\"</function>\" space"));
|
1384
|
+
|
1385
|
+
data.grammar_triggers.push_back({
|
1386
|
+
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
1387
|
+
"<function=" + name + ">",
|
1388
|
+
});
|
1389
|
+
auto escaped_name = regex_escape(name);
|
1390
|
+
data.grammar_triggers.push_back({
|
1391
|
+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
|
1392
|
+
"<function\\s+name\\s*=\\s*\"" + escaped_name + "\"",
|
1393
|
+
});
|
894
1394
|
});
|
895
|
-
auto
|
1395
|
+
auto any_tool_call = builder.add_rule("any_tool_call", "( " + string_join(tool_rules, " | ") + " ) space");
|
1396
|
+
std::vector<std::string> alt_tags {
|
1397
|
+
any_tool_call,
|
1398
|
+
"\"<tool_call>\" space " + any_tool_call + " \"</tool_call>\"",
|
1399
|
+
// The rest is just to accommodate common "good bad" outputs.
|
1400
|
+
"\"<function_call>\" space " + any_tool_call + " \"</function_call>\"",
|
1401
|
+
"\"<response>\" space " + any_tool_call + " \"</response>\"",
|
1402
|
+
"\"<tools>\" space " + any_tool_call + " \"</tools>\"",
|
1403
|
+
"\"<json>\" space " + any_tool_call + " \"</json>\"",
|
1404
|
+
"\"<xml>\" space " + any_tool_call + " \"</xml>\"",
|
1405
|
+
"\"<JSON>\" space " + any_tool_call + " \"</JSON>\"",
|
1406
|
+
};
|
1407
|
+
auto wrappable_tool_call = builder.add_rule("wrappable_tool_call", "( " + string_join(alt_tags, " | ") + " ) space");
|
1408
|
+
tool_call_alts.push_back(wrappable_tool_call);
|
1409
|
+
tool_call_alts.push_back(
|
1410
|
+
"( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space ");
|
1411
|
+
auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | "));
|
896
1412
|
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
897
|
-
data.grammar_triggers.push_back({"<tool_call>"
|
898
|
-
data.
|
899
|
-
|
1413
|
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<tool_call>"});
|
1414
|
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function"});
|
1415
|
+
// Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives)
|
1416
|
+
data.grammar_triggers.push_back({
|
1417
|
+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
1418
|
+
"(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?\\s*\\{\\s*\"", //name\"\\s*:\\s*\"" + escaped_name + "\"",
|
1419
|
+
});
|
1420
|
+
data.preserved_tokens = {
|
1421
|
+
"<think>",
|
1422
|
+
"</think>",
|
1423
|
+
"<tool_call>",
|
1424
|
+
"</tool_call>",
|
1425
|
+
"<function",
|
1426
|
+
"<tools>",
|
1427
|
+
"</tools>",
|
1428
|
+
"<response>",
|
1429
|
+
"</response>",
|
1430
|
+
"<function_call>",
|
1431
|
+
"</function_call>",
|
1432
|
+
"<json>",
|
1433
|
+
"</json>",
|
1434
|
+
"<JSON>",
|
1435
|
+
"</JSON>",
|
1436
|
+
"```",
|
1437
|
+
"```json",
|
1438
|
+
"```xml",
|
1439
|
+
};
|
1440
|
+
});
|
900
1441
|
|
901
1442
|
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
902
|
-
data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
|
1443
|
+
data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING : COMMON_CHAT_FORMAT_HERMES_2_PRO;
|
903
1444
|
return data;
|
904
1445
|
}
|
905
|
-
static common_chat_msg common_chat_parse_hermes_2_pro(const std::string
|
906
|
-
|
907
|
-
std::regex
|
908
|
-
|
909
|
-
|
910
|
-
|
911
|
-
|
912
|
-
|
913
|
-
|
914
|
-
|
915
|
-
|
916
|
-
|
917
|
-
|
918
|
-
|
919
|
-
|
920
|
-
|
921
|
-
|
922
|
-
|
923
|
-
|
924
|
-
|
925
|
-
|
926
|
-
|
927
|
-
|
928
|
-
|
929
|
-
|
930
|
-
|
931
|
-
|
932
|
-
const
|
933
|
-
|
934
|
-
|
935
|
-
|
936
|
-
|
937
|
-
|
938
|
-
|
939
|
-
|
940
|
-
|
941
|
-
|
942
|
-
|
943
|
-
|
944
|
-
|
945
|
-
|
1446
|
+
static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input, bool extract_reasoning) {
|
1447
|
+
return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) {
|
1448
|
+
static const std::regex open_regex(
|
1449
|
+
"(?:"
|
1450
|
+
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
|
1451
|
+
"(<tool_call>" // match 2 (open_tag)
|
1452
|
+
"|<function_call>"
|
1453
|
+
"|<tool>"
|
1454
|
+
"|<tools>"
|
1455
|
+
"|<response>"
|
1456
|
+
"|<json>"
|
1457
|
+
"|<xml>"
|
1458
|
+
"|<JSON>"
|
1459
|
+
")?"
|
1460
|
+
"(\\s*\\{\\s*\"name\"\\s*:[\\s\\S]*)" // match 3 (named tool call + rest)
|
1461
|
+
")"
|
1462
|
+
"|"
|
1463
|
+
"(?:<function=([^>]+)>" // match 4 (function name)
|
1464
|
+
"|<function name=\"([^\"]+)\">)" // match 5 (function name again)
|
1465
|
+
"([\\s\\S]*)" // match 6 (function arguments + rest)})"
|
1466
|
+
);
|
1467
|
+
|
1468
|
+
try {
|
1469
|
+
common_chat_msg msg;
|
1470
|
+
msg.role = "assistant";
|
1471
|
+
|
1472
|
+
std::string::const_iterator it = input.begin();
|
1473
|
+
const std::string::const_iterator end = input.end();
|
1474
|
+
std::smatch match;
|
1475
|
+
|
1476
|
+
while (it != end) {
|
1477
|
+
if (std::regex_search(it, end, match, open_regex)) {
|
1478
|
+
// Add content before the match
|
1479
|
+
msg.content += std::string(it, match[0].first);
|
1480
|
+
|
1481
|
+
auto block_start = match[1].str();
|
1482
|
+
std::string block_end = block_start.empty() ? "" : "```";
|
1483
|
+
|
1484
|
+
auto open_tag = match[2].str();
|
1485
|
+
std::string close_tag;
|
1486
|
+
|
1487
|
+
if (match[3].matched) {
|
1488
|
+
close_tag = open_tag.empty() ? "" : "</" + open_tag.substr(1);
|
1489
|
+
auto json_it = match[3].first;
|
1490
|
+
json tool_call;
|
1491
|
+
if (parse_json(json_it, end, tool_call) && tool_call.contains("name") && tool_call.contains("arguments")) {
|
1492
|
+
|
1493
|
+
msg.tool_calls.emplace_back(process_tool_call(tool_call));
|
1494
|
+
it = json_it; // Move iterator past parsed JSON
|
1495
|
+
|
1496
|
+
// Handle close tags
|
1497
|
+
consume_spaces(it, end);
|
1498
|
+
if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
|
1499
|
+
throw std::runtime_error("Failed to parse closing tag");
|
1500
|
+
}
|
1501
|
+
consume_spaces(it, end);
|
1502
|
+
if (!block_end.empty() && !parse_literal(it, end, block_end)) {
|
1503
|
+
throw std::runtime_error("Failed to parse block end");
|
1504
|
+
}
|
1505
|
+
consume_spaces(it, end);
|
1506
|
+
} else {
|
1507
|
+
// Not a valid tool call, treat as content
|
1508
|
+
msg.content += std::string(match[0].first, match[0].second);
|
1509
|
+
it = match[0].second;
|
1510
|
+
}
|
1511
|
+
} else {
|
1512
|
+
auto function_name = match[4].str();
|
1513
|
+
if (function_name.empty()) {
|
1514
|
+
function_name = match[5].str();
|
1515
|
+
}
|
1516
|
+
LM_GGML_ASSERT(!function_name.empty());
|
1517
|
+
|
1518
|
+
close_tag = "</function>";
|
1519
|
+
// Start parsing from after the opening tags
|
1520
|
+
auto json_it = match[6].first;
|
1521
|
+
json arguments;
|
1522
|
+
if (parse_json(json_it, end, arguments)) {
|
1523
|
+
msg.tool_calls.emplace_back(process_tool_call({
|
1524
|
+
{"name", function_name},
|
1525
|
+
{"arguments", arguments},
|
1526
|
+
}));
|
1527
|
+
it = json_it; // Move iterator past parsed JSON
|
1528
|
+
|
1529
|
+
// Handle close tags
|
1530
|
+
consume_spaces(it, end);
|
1531
|
+
if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
|
1532
|
+
throw std::runtime_error("Failed to parse closing tag");
|
1533
|
+
}
|
1534
|
+
consume_spaces(it, end);
|
1535
|
+
if (!block_end.empty() && !parse_literal(it, end, block_end)) {
|
1536
|
+
throw std::runtime_error("Failed to parse block end");
|
1537
|
+
}
|
1538
|
+
consume_spaces(it, end);
|
1539
|
+
} else {
|
1540
|
+
// Not a valid tool call, treat as content
|
1541
|
+
msg.content += std::string(match[0].first, match[0].second);
|
1542
|
+
it = match[0].second;
|
1543
|
+
}
|
1544
|
+
}
|
1545
|
+
} else {
|
1546
|
+
// Add remaining content
|
1547
|
+
msg.content += std::string(it, end);
|
1548
|
+
break;
|
946
1549
|
}
|
947
|
-
break;
|
948
1550
|
}
|
1551
|
+
return msg;
|
1552
|
+
} catch (const std::exception & e) {
|
1553
|
+
LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
|
1554
|
+
common_chat_msg msg;
|
1555
|
+
msg.role = "assistant";
|
1556
|
+
msg.content = input;
|
1557
|
+
return msg;
|
949
1558
|
}
|
950
|
-
|
951
|
-
} catch (const std::exception & e) {
|
952
|
-
return {
|
953
|
-
/* .role = */ "assistant",
|
954
|
-
/* .content = */ input,
|
955
|
-
/* .tool_calls = */ {},
|
956
|
-
};
|
957
|
-
}
|
1559
|
+
});
|
958
1560
|
}
|
959
1561
|
|
960
|
-
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct
|
1562
|
+
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
961
1563
|
common_chat_params data;
|
962
1564
|
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
963
1565
|
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
@@ -973,12 +1575,35 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
|
|
973
1575
|
return data;
|
974
1576
|
}
|
975
1577
|
|
976
|
-
common_chat_params
|
1578
|
+
static common_chat_params common_chat_templates_apply_jinja(
|
1579
|
+
const struct common_chat_templates * tmpls,
|
1580
|
+
const struct common_chat_templates_inputs & inputs)
|
1581
|
+
{
|
1582
|
+
templates_params params;
|
1583
|
+
params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
|
1584
|
+
const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
|
1585
|
+
? *tmpls->template_tool_use
|
1586
|
+
: *tmpls->template_default;
|
977
1587
|
const auto & src = tmpl.source();
|
978
1588
|
const auto & caps = tmpl.original_caps();
|
1589
|
+
params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
|
1590
|
+
params.add_generation_prompt = inputs.add_generation_prompt;
|
1591
|
+
params.extract_reasoning = inputs.extract_reasoning;
|
1592
|
+
params.tool_choice = inputs.tool_choice;
|
1593
|
+
params.grammar = inputs.grammar;
|
1594
|
+
if (!inputs.json_schema.empty()) {
|
1595
|
+
params.json_schema = json::parse(inputs.json_schema);
|
1596
|
+
}
|
1597
|
+
|
1598
|
+
if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
|
1599
|
+
LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
|
1600
|
+
params.parallel_tool_calls = false;
|
1601
|
+
} else {
|
1602
|
+
params.parallel_tool_calls = inputs.parallel_tool_calls;
|
1603
|
+
}
|
979
1604
|
|
980
|
-
if (
|
981
|
-
if (
|
1605
|
+
if (params.tools.is_array()) {
|
1606
|
+
if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
|
982
1607
|
throw std::runtime_error("Cannot specify grammar with tools");
|
983
1608
|
}
|
984
1609
|
if (caps.supports_tool_calls && !caps.supports_tools) {
|
@@ -987,68 +1612,135 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
|
|
987
1612
|
}
|
988
1613
|
|
989
1614
|
// DeepSeek R1: use handler in all cases except json schema (thinking / tools).
|
990
|
-
if (src.find("<|tool▁calls▁begin|>") != std::string::npos &&
|
991
|
-
return common_chat_params_init_deepseek_r1(tmpl,
|
1615
|
+
if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) {
|
1616
|
+
return common_chat_params_init_deepseek_r1(tmpl, params);
|
992
1617
|
}
|
993
1618
|
|
994
1619
|
// Command R7B: : use handler in all cases except json schema (thinking / tools).
|
995
|
-
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos &&
|
996
|
-
return common_chat_params_init_command_r7b(tmpl,
|
1620
|
+
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) {
|
1621
|
+
return common_chat_params_init_command_r7b(tmpl, params);
|
1622
|
+
}
|
1623
|
+
|
1624
|
+
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
|
1625
|
+
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
|
1626
|
+
return common_chat_params_init_hermes_2_pro(tmpl, params);
|
997
1627
|
}
|
998
1628
|
|
999
1629
|
// Use generic handler when mixing tools + JSON schema.
|
1000
1630
|
// TODO: support that mix in handlers below.
|
1001
|
-
if ((
|
1002
|
-
return common_chat_params_init_generic(tmpl,
|
1631
|
+
if ((params.tools.is_array() && params.json_schema.is_object())) {
|
1632
|
+
return common_chat_params_init_generic(tmpl, params);
|
1003
1633
|
}
|
1004
1634
|
|
1005
1635
|
// Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases.
|
1006
1636
|
if (src.find(">>>all") != std::string::npos) {
|
1007
|
-
return common_chat_params_init_functionary_v3_2(tmpl,
|
1637
|
+
return common_chat_params_init_functionary_v3_2(tmpl, params);
|
1008
1638
|
}
|
1009
1639
|
|
1010
1640
|
// Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases.
|
1011
1641
|
if (src.find(" functools[") != std::string::npos) {
|
1012
|
-
return common_chat_params_init_firefunction_v2(tmpl,
|
1642
|
+
return common_chat_params_init_firefunction_v2(tmpl, params);
|
1013
1643
|
}
|
1014
1644
|
|
1015
1645
|
// Plain handler (no tools)
|
1016
|
-
if (
|
1017
|
-
return common_chat_params_init_without_tools(tmpl,
|
1018
|
-
}
|
1019
|
-
|
1020
|
-
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
|
1021
|
-
if (src.find("<tool_call>") != std::string::npos) {
|
1022
|
-
return common_chat_params_init_hermes_2_pro(tmpl, inputs);
|
1646
|
+
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
1647
|
+
return common_chat_params_init_without_tools(tmpl, params);
|
1023
1648
|
}
|
1024
1649
|
|
1025
1650
|
// Functionary v3.1 (w/ tools)
|
1026
1651
|
if (src.find("<|start_header_id|>") != std::string::npos
|
1027
1652
|
&& src.find("<function=") != std::string::npos) {
|
1028
|
-
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl,
|
1653
|
+
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
|
1029
1654
|
}
|
1030
1655
|
|
1031
1656
|
// Llama 3.1, 3.2, 3.3 (w/ tools)
|
1032
1657
|
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
1033
1658
|
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
|
1034
|
-
return common_chat_params_init_llama_3_1_tool_calls(tmpl,
|
1659
|
+
return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
|
1035
1660
|
}
|
1036
1661
|
|
1037
1662
|
// Mistral Nemo (w/ tools)
|
1038
1663
|
if (src.find("[TOOL_CALLS]") != std::string::npos) {
|
1039
|
-
return common_chat_params_init_mistral_nemo(tmpl,
|
1664
|
+
return common_chat_params_init_mistral_nemo(tmpl, params);
|
1040
1665
|
}
|
1041
1666
|
|
1042
1667
|
// Generic fallback
|
1043
|
-
return common_chat_params_init_generic(tmpl,
|
1668
|
+
return common_chat_params_init_generic(tmpl, params);
|
1669
|
+
}
|
1670
|
+
|
1671
|
+
// Legacy template route (adhoc C++ implementation of known templates), forward to llama_chat_apply_template.
|
1672
|
+
static common_chat_params common_chat_templates_apply_legacy(
|
1673
|
+
const struct common_chat_templates * tmpls,
|
1674
|
+
const struct common_chat_templates_inputs & inputs)
|
1675
|
+
{
|
1676
|
+
int alloc_size = 0;
|
1677
|
+
std::vector<llama_chat_message> chat;
|
1678
|
+
std::vector<std::string> contents;
|
1679
|
+
for (const auto & msg : inputs.messages) {
|
1680
|
+
auto content = msg.content;
|
1681
|
+
for (const auto & part : msg.content_parts) {
|
1682
|
+
if (part.type != "text") {
|
1683
|
+
LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
|
1684
|
+
continue;
|
1685
|
+
}
|
1686
|
+
if (!content.empty()) {
|
1687
|
+
content += "\n";;
|
1688
|
+
}
|
1689
|
+
content += part.text;
|
1690
|
+
}
|
1691
|
+
contents.emplace_back(std::move(content));
|
1692
|
+
}
|
1693
|
+
for (size_t i = 0; i < contents.size(); ++i) {
|
1694
|
+
const auto & msg = inputs.messages[i];
|
1695
|
+
const auto & content = contents[i];
|
1696
|
+
chat.push_back({msg.role.c_str(), content.c_str()});
|
1697
|
+
alloc_size += (msg.role.size() + content.size()) * 1.25;
|
1698
|
+
}
|
1699
|
+
|
1700
|
+
std::vector<char> buf(alloc_size);
|
1701
|
+
|
1702
|
+
// run the first time to get the total output length
|
1703
|
+
const auto & src = tmpls->template_default->source();
|
1704
|
+
int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
|
1705
|
+
|
1706
|
+
// error: chat template is not supported
|
1707
|
+
if (res < 0) {
|
1708
|
+
// if the custom "tmpl" is not supported, we throw an error
|
1709
|
+
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
|
1710
|
+
throw std::runtime_error("this custom template is not supported");
|
1711
|
+
}
|
1712
|
+
|
1713
|
+
// if it turns out that our buffer is too small, we resize it
|
1714
|
+
if ((size_t) res > buf.size()) {
|
1715
|
+
buf.resize(res);
|
1716
|
+
res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
|
1717
|
+
}
|
1718
|
+
|
1719
|
+
common_chat_params params;
|
1720
|
+
params.prompt = std::string(buf.data(), res);
|
1721
|
+
if (!inputs.json_schema.empty()) {
|
1722
|
+
params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema));
|
1723
|
+
} else {
|
1724
|
+
params.grammar = inputs.grammar;
|
1725
|
+
}
|
1726
|
+
return params;
|
1727
|
+
}
|
1728
|
+
|
1729
|
+
common_chat_params common_chat_templates_apply(
|
1730
|
+
const struct common_chat_templates * tmpls,
|
1731
|
+
const struct common_chat_templates_inputs & inputs)
|
1732
|
+
{
|
1733
|
+
LM_GGML_ASSERT(tmpls != nullptr);
|
1734
|
+
return inputs.use_jinja
|
1735
|
+
? common_chat_templates_apply_jinja(tmpls, inputs)
|
1736
|
+
: common_chat_templates_apply_legacy(tmpls, inputs);
|
1044
1737
|
}
|
1045
1738
|
|
1046
1739
|
static common_chat_msg common_chat_parse_content_only(const std::string & input) {
|
1047
|
-
|
1048
|
-
|
1049
|
-
|
1050
|
-
|
1051
|
-
};
|
1740
|
+
common_chat_msg msg;
|
1741
|
+
msg.role = "assistant";
|
1742
|
+
msg.content = input;
|
1743
|
+
return msg;
|
1052
1744
|
}
|
1053
1745
|
|
1054
1746
|
common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {
|
@@ -1072,7 +1764,9 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
|
|
1072
1764
|
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
|
1073
1765
|
return common_chat_parse_functionary_v3_1_llama_3_1(input);
|
1074
1766
|
case COMMON_CHAT_FORMAT_HERMES_2_PRO:
|
1075
|
-
return common_chat_parse_hermes_2_pro(input);
|
1767
|
+
return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ false);
|
1768
|
+
case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING:
|
1769
|
+
return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ true);
|
1076
1770
|
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
|
1077
1771
|
return common_chat_parse_firefunction_v2(input);
|
1078
1772
|
case COMMON_CHAT_FORMAT_COMMAND_R7B:
|