@fugood/llama.node 0.3.12 → 0.3.14
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +2 -1
- package/package.json +1 -1
- package/src/LlamaCompletionWorker.cpp +14 -0
- package/src/LlamaContext.cpp +110 -79
- package/src/LlamaContext.h +1 -1
- package/src/common.hpp +1 -2
- package/src/llama.cpp/.github/workflows/build.yml +95 -13
- package/src/llama.cpp/.github/workflows/docker.yml +2 -0
- package/src/llama.cpp/.github/workflows/labeler.yml +1 -1
- package/src/llama.cpp/.github/workflows/server.yml +2 -0
- package/src/llama.cpp/common/CMakeLists.txt +23 -6
- package/src/llama.cpp/common/arg.cpp +292 -14
- package/src/llama.cpp/common/chat.cpp +1128 -315
- package/src/llama.cpp/common/chat.h +135 -0
- package/src/llama.cpp/common/common.cpp +27 -171
- package/src/llama.cpp/common/common.h +41 -73
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
- package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
- package/src/llama.cpp/common/llguidance.cpp +3 -3
- package/src/llama.cpp/common/log.cpp +1 -0
- package/src/llama.cpp/common/log.h +2 -1
- package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +21 -7
- package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +61 -14
- package/src/llama.cpp/common/ngram-cache.cpp +1 -0
- package/src/llama.cpp/common/sampling.cpp +93 -49
- package/src/llama.cpp/common/speculative.cpp +6 -5
- package/src/llama.cpp/common/speculative.h +1 -1
- package/src/llama.cpp/docs/build.md +47 -9
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +3 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +1 -0
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +4 -4
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +6 -5
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +1 -1
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +373 -107
- package/src/llama.cpp/examples/llava/clip.h +19 -3
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
- package/src/llama.cpp/examples/llava/llava.cpp +4 -2
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -0
- package/src/llama.cpp/examples/main/main.cpp +73 -28
- package/src/llama.cpp/examples/parallel/parallel.cpp +1 -0
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -0
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +1 -0
- package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
- package/src/llama.cpp/examples/run/run.cpp +115 -79
- package/src/llama.cpp/examples/server/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/server/httplib.h +381 -292
- package/src/llama.cpp/examples/server/server.cpp +134 -128
- package/src/llama.cpp/examples/server/utils.hpp +95 -106
- package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
- package/src/llama.cpp/examples/tts/tts.cpp +251 -142
- package/src/llama.cpp/ggml/CMakeLists.txt +13 -1
- package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-cpu.h +4 -1
- package/src/llama.cpp/ggml/include/ggml-metal.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +0 -2
- package/src/llama.cpp/ggml/include/ggml.h +6 -2
- package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
- package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
- package/src/llama.cpp/ggml/src/ggml-common.h +0 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +132 -17
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +156 -11
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +2235 -641
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1572 -198
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +24 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +9 -8
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +16 -3
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +235 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -2
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +246 -120
- package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +51 -10
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +136 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +308 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +174 -728
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -77
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +949 -602
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +37 -3
- package/src/llama.cpp/ggml/src/ggml.c +9 -4
- package/src/llama.cpp/include/llama.h +32 -14
- package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +1 -0
- package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
- package/src/llama.cpp/requirements.txt +1 -0
- package/src/llama.cpp/src/llama-arch.cpp +21 -0
- package/src/llama.cpp/src/llama-arch.h +1 -0
- package/src/llama.cpp/src/llama-chat.cpp +1 -0
- package/src/llama.cpp/src/llama-grammar.cpp +183 -183
- package/src/llama.cpp/src/llama-grammar.h +13 -4
- package/src/llama.cpp/src/llama-impl.h +6 -6
- package/src/llama.cpp/src/llama-kv-cache.h +2 -1
- package/src/llama.cpp/src/llama-mmap.cpp +11 -1
- package/src/llama.cpp/src/llama-mmap.h +1 -0
- package/src/llama.cpp/src/llama-model.cpp +70 -6
- package/src/llama.cpp/src/llama-sampling.cpp +174 -67
- package/src/llama.cpp/src/llama-vocab.cpp +12 -0
- package/src/llama.cpp/src/llama.cpp +154 -5
- package/src/llama.cpp/src/unicode.cpp +9 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +171 -115
- package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
- package/src/llama.cpp/tests/test-chat.cpp +691 -325
- package/src/llama.cpp/tests/test-gguf.cpp +4 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
- package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
- package/src/llama.cpp/tests/test-sampling.cpp +15 -0
- package/src/llama.cpp/Sources/llama/llama.h +0 -4
- package/src/llama.cpp/common/chat.hpp +0 -52
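The headline change in this range is the rework of llama.cpp's common chat/tool-call API: `common/chat.hpp` (+0 −52) is replaced by `common/chat.h` (+135 −0), the Jinja helpers move under `common/minja/`, and the template entry points become `common_chat_templates_init` / `common_chat_templates_apply`. The diff of `tests/test-chat.cpp` below exercises the new API end to end. As rough orientation only, here is a minimal sketch of driving that API, inferred from the calls visible in the test diff (not documented usage; exact signatures may differ):

```cpp
#include <string>
#include "chat.h"  // new header replacing chat.hpp

// Sketch only: the names below (common_chat_templates_init, common_chat_templates_apply,
// common_chat_parse, common_chat_templates_inputs, common_chat_msg) are taken from the
// test diff that follows; treat this as an illustration, not the package's documented API.
static void sketch(const std::string & template_source) {
    // Load a Jinja chat template; the test passes a null model pointer.
    common_chat_templates_ptr tmpls(common_chat_templates_init(/* model= */ nullptr, template_source));

    common_chat_templates_inputs inputs;
    common_chat_msg user;
    user.role    = "user";
    user.content = "Hey";
    inputs.messages = { user };

    // Render the prompt and discover which output format the template produces.
    auto params = common_chat_templates_apply(tmpls.get(), inputs);

    // params.prompt holds the rendered prompt; params.format selects the parser
    // used to turn raw model output back into a structured message.
    common_chat_msg reply = common_chat_parse("Hello!", params.format);
    (void) reply;
}
```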
package/src/llama.cpp/tests/test-chat.cpp

@@ -10,35 +10,12 @@
 #include <json.hpp>
 #include <string>

-#include "chat-template.hpp"
-#include "chat.hpp"
+#include "chat.h"
 #include "llama-grammar.h"
 #include "unicode.h"

 using json = nlohmann::ordered_json;

-static common_chat_msg msg_from_json(const json & message) {
-    common_chat_msg ret;
-    ret.role = "assistant";
-    if (message.contains("content") && !message.at("content").is_null()) {
-        ret.content = message.at("content");
-    }
-    if (message.contains("tool_plan")) {
-        ret.tool_plan = message.at("tool_plan");
-    }
-    auto has_tool_calls = message.contains("tool_calls");
-    if (has_tool_calls) {
-        for (const auto & tc : message.at("tool_calls")) {
-            const auto & arguments = tc.at("function").at("arguments");
-            ret.tool_calls.push_back({
-                tc.at("function").at("name").get<std::string>(),
-                arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
-                tc.contains("id") ? tc.at("id").get<std::string>() : "",
-            });
-        }
-    }
-    return ret;
-}

 template <class T> static void assert_equals(const T & expected, const T & actual) {
     if (expected != actual) {
@@ -50,7 +27,7 @@ template <class T> static void assert_equals(const T & expected, const T & actual) {
 }

 static std::string read_file(const std::string & path) {
-    std::cerr << "# Reading: " << path << std::endl;
+    std::cerr << "# Reading: " << path << '\n' << std::flush;
     std::ifstream fs(path, std::ios_base::binary);
     if (!fs.is_open()) {
         fs = std::ifstream("../" + path, std::ios_base::binary);
@@ -63,10 +40,14 @@ static std::string read_file(const std::string & path) {
     fs.seekg(0);
     std::string out;
     out.resize(static_cast<size_t>(size));
-    fs.read(&out[0], static_cast<std::streamsize>(size));
+    fs.read(out.data(), static_cast<std::streamsize>(size));
     return out;
 }

+static common_chat_templates_ptr read_templates(const std::string & path) {
+    return common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, read_file(path)));
+}
+
 static std::unique_ptr<llama_grammar> build_grammar(const std::string & grammar_str) {
     return std::unique_ptr<llama_grammar>(
         llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0));
@@ -87,122 +68,124 @@ static bool match_string(const std::string & input, llama_grammar * grammar) {
         }
     }

-    for (const auto & stack : stacks_cur) {
-        if (stack.empty()) {
-            // An empty stack means that the grammar has been completed
-            return true;
-        }
+    if (std::any_of(stacks_cur.begin(), stacks_cur.end(), [](const auto & stack) { return stack.empty(); })) {
+        // An empty stack means that the grammar has been completed
+        return true;
     }

     return false;
 }

-// Dumps `{"a": 1}` as `"{\"a\": 1}"`, unlike nlohmann::json::dump which would dump it as `"{\"a\":1}"`.
-static std::string dump(const json & j) {
-    return minja::Value(j).dump(-1, /* to_json= */ true);
-}
-
 static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
     assert_equals(expected.role, actual.role);
     assert_equals(expected.content, actual.content);
+    assert_equals(expected.content_parts.size(), actual.content_parts.size());
+    for (size_t i = 0; i < expected.content_parts.size(); i++) {
+        const auto & expected_part = expected.content_parts[i];
+        const auto & actual_part   = actual.content_parts[i];
+        assert_equals(expected_part.type, actual_part.type);
+        assert_equals(expected_part.text, actual_part.text);
+    }
+    assert_equals(expected.reasoning_content, actual.reasoning_content);
     assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
     for (size_t i = 0; i < expected.tool_calls.size(); i++) {
         const auto & expected_tool_call = expected.tool_calls[i];
         const auto & actual_tool_call   = actual.tool_calls[i];
         assert_equals(expected_tool_call.name, actual_tool_call.name);
-        assert_equals(dump(json::parse(expected_tool_call.arguments)), dump(json::parse(actual_tool_call.arguments)));
+        assert_equals(json::parse(expected_tool_call.arguments).dump(), json::parse(actual_tool_call.arguments).dump());
         assert_equals(expected_tool_call.id, actual_tool_call.id);
     }
 }

-const json special_function_tool = json::parse(R"({
-  "type": "function",
-  "function": {
-    "name": "special_function",
-    "description": "I'm special",
-    "parameters": {
-      "type": "object",
-      "properties": {
-        "arg1": {
-          "type": "integer",
-          "description": "The arg."
-        }
-      },
-      "required": ["arg1"]
-    }
-  }
-})");
-const json python_tool = json::parse(R"({
-  "type": "function",
-  "function": {
-    "name": "python",
-    "description": "an ipython interpreter",
-    "parameters": {
-      "type": "object",
-      "properties": {
-        "code": {
-          "type": "string",
-          "description": "Python code to execute."
-        }
-      },
-      "required": ["code"]
-    }
-  }
-})");
-const json code_interpreter_tool = json::parse(R"({
-  "type": "function",
-  "function": {
-    "name": "code_interpreter",
-    "description": "an ipython interpreter",
-    "parameters": {
-      "type": "object",
-      "properties": {
-        "code": {
-          "type": "string",
-          "description": "Python code to execute."
-        }
-      },
-      "required": ["code"]
-    }
-  }
-})");
-const json tools = { special_function_tool, python_tool };
-const json llama_3_1_tools = { special_function_tool, code_interpreter_tool };
+common_chat_tool special_function_tool {
+    /* .name = */ "special_function",
+    /* .description = */ "I'm special",
+    /* .parameters = */ R"({
+        "type": "object",
+        "properties": {
+            "arg1": {
+                "type": "integer",
+                "description": "The arg."
+            }
+        },
+        "required": ["arg1"]
+    })",
+};
+common_chat_tool python_tool {
+    /* .name = */ "python",
+    /* .description = */ "an ipython interpreter",
+    /* .parameters = */ R"({
+        "type": "object",
+        "properties": {
+            "code": {
+                "type": "string",
+                "description": "Python code to execute."
+            }
+        },
+        "required": ["code"]
+    })",
+};
+common_chat_tool code_interpreter_tool {
+    /* .name = */ "code_interpreter",
+    /* .description = */ "an ipython interpreter",
+    /* .parameters = */ R"({
+        "type": "object",
+        "properties": {
+            "code": {
+                "type": "string",
+                "description": "Python code to execute."
+            }
+        },
+        "required": ["code"]
+    })",
+};
+std::vector<common_chat_tool> tools { special_function_tool, python_tool };
+std::vector<common_chat_tool> llama_3_1_tools { special_function_tool, code_interpreter_tool };

 struct delta_data {
     std::string        delta;
     common_chat_params params;
 };

-static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
-                             const json & user_message, const json & delta_message, const json & tools,
-                             const json & tool_choice) {
-    common_chat_inputs inputs;
+static delta_data init_delta(const struct common_chat_templates * tmpls, const std::vector<std::string> & end_tokens,
+                             const common_chat_msg & user_message,
+                             const common_chat_msg & delta_message,
+                             const std::vector<common_chat_tool> & tools,
+                             const common_chat_tool_choice & tool_choice,
+                             bool think = false) {
+    common_chat_templates_inputs inputs;
     inputs.parallel_tool_calls = true;
-    inputs.messages = json::array();
     inputs.messages.push_back(user_message);
     inputs.tools = tools;
     inputs.tool_choice = tool_choice;
-    auto params_prefix = common_chat_params_init(tmpl, inputs);
+    inputs.extract_reasoning = think;
+    auto params_prefix = common_chat_templates_apply(tmpls, inputs);

     inputs.messages.push_back(delta_message);
     inputs.add_generation_prompt = false;
-    auto params_full = common_chat_params_init(tmpl, inputs);
+    auto params_full = common_chat_templates_apply(tmpls, inputs);

     std::string prefix = params_prefix.prompt;
     std::string full = params_full.prompt;

-    // Check full starts with prefix
-    if (full.find(prefix) != 0) {
-        fprintf(stderr, "Full:\n%s\n\nPrefix:\n%s\n\n", full.c_str(), prefix.c_str());
-        throw std::runtime_error("Full message does not start with prefix");
-    }
-
     if (full == prefix) {
         throw std::runtime_error("Full message is the same as the prefix");
     }

-    auto delta = full.substr(prefix.size());
+    size_t common_prefix_length = 0;
+    for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
+        if (prefix[i] != full[i]) {
+            break;
+        }
+        if (prefix[i] == '<') {
+            // DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
+            // but it removes thinking tags for past messages.
+            // The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
+            continue;
+        }
+        common_prefix_length = i + 1;
+    }
+    auto delta = full.substr(common_prefix_length);

     // Strip end tokens
     for (const auto & end_token : end_tokens) {
@@ -221,28 +204,29 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
   gets the diff, removes any end tokens and parses the result w/ the grammar, checking that
   the parsed message is the same as the test_message
 */
-static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
-                          const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
-                          bool expect_grammar_triggered = true) {
-    common_chat_msg expected_msg = msg_from_json(test_message);
-
-    auto user_message = json{
-        { "role", "user" },
-        { "content", "Hello, world!" }
-    };
-
-    for (const auto & tool_choice : json({ "auto", "required" })) {
-        auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice);
+static void test_templates(const struct common_chat_templates * tmpls, const std::vector<std::string> & end_tokens,
+                           const common_chat_msg & test_message,
+                           const std::vector<common_chat_tool> & tools = {},
+                           const std::string & expected_delta = "",
+                           bool expect_grammar_triggered = true,
+                           bool test_grammar_if_triggered = true,
+                           bool think = false) {
+    common_chat_msg user_message;
+    user_message.role = "user";
+    user_message.content = "Hello, world!";
+
+    for (const auto & tool_choice : std::vector<common_chat_tool_choice> {COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED}) {
+        auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice, think);
         if (!expected_delta.empty()) {
             assert_equals(expected_delta, data.delta);
         }

         if (expect_grammar_triggered) {
             const auto msg = common_chat_parse(data.delta, data.params.format);
-            assert_msg_equals(expected_msg, msg);
+            assert_msg_equals(test_message, msg);
         }

-        if (!expected_msg.tool_calls.empty()) {
+        if (!test_message.tool_calls.empty()) {
             GGML_ASSERT(!data.params.grammar.empty());
         }
         if (!data.params.grammar.empty()) {
@@ -253,12 +237,35 @@ static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
             auto earliest_trigger_pos = std::string::npos;
             auto constrained = data.delta;
             for (const auto & trigger : data.params.grammar_triggers) {
-                auto pos = constrained.find(trigger.word);
-                if (pos == std::string::npos) {
-                    continue;
+                size_t pos = std::string::npos;
+                std::smatch match;
+                switch (trigger.type) {
+                    case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
+                    {
+                        const auto & word = trigger.value;
+                        pos = constrained.find(word);
+                        break;
+                    }
+                    case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
+                    {
+                        const auto & pattern = trigger.value;
+                        if (std::regex_search(constrained, match, std::regex(pattern))) {
+                            pos = match.position();
+                        }
+                        break;
+                    }
+                    case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
+                    {
+                        const auto & pattern = trigger.value;
+                        if (std::regex_search(constrained, match, std::regex(pattern)) && match.position() == 0) {
+                            pos = 0;
+                        }
+                        break;
+                    }
+                    default:
+                        throw std::runtime_error("Unknown trigger type");
                 }
-                if (pos > 0 && trigger.at_start) {
-                    fprintf(stderr, "Trigger %s not at start of message, skipping:\n\n%s\n\n", trigger.word.c_str(), constrained.c_str());
+                if (pos == std::string::npos) {
                     continue;
                 }
                 if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
@@ -274,161 +281,363 @@ static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
                 assert_equals(expect_grammar_triggered, grammar_triggered);
             }

-            if (grammar_triggered && !match_string(constrained, grammar.get())) {
+            if (grammar_triggered && test_grammar_if_triggered && !match_string(constrained, grammar.get())) {
                 throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
-                                         "\n\nGrammar: " + data.params.grammar);
+                                         "\n\nConstrained: " + constrained +
+                                         "\n\nGrammar: " + data.params.grammar);
             }
         }
     }
 }

-static void test_template_output_parsers() {
-    json text_message {
-        { "role", "assistant" },
-        { "content", "Hello, world!\nWhat's up?" },
-    };
-    json tool_calls = json::array({{
-        { "type", "function" },
-        { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
-    }});
-
-    json tool_call_message {
-        { "role", "assistant"},
-        { "content", {}},
-        { "tool_calls", {
-            {
-                { "type", "function" },
-                { "function", {
-                    { "name", "special_function" },
-                    { "arguments", "{\"arg1\": 1}" },
-                }},
-            },
-        }},
-    };
-    json tool_call_message_with_id {
-        { "role", "assistant"},
-        { "content", {}},
-        { "tool_calls", {
-            {
-                { "type", "function" },
-                { "function", {
-                    { "name", "special_function" },
-                    { "arguments", "{\"arg1\": 1}" },
-                }},
-                {"id", "123456789"},
-            },
-        }},
-        { "role", "assistant" },
-        { "content", {} },
-        { "tool_calls", tool_calls }
-    };
-    json tool_call_plan_message_with_idx {
-        { "role", "assistant"},
-        { "content", {}},
-        { "tool_plan", "I'm not so sure"},
-        { "tool_calls", {
-            {
-                { "type", "function" },
-                { "function", {
-                    { "name", "special_function" },
-                    { "arguments", "{\"arg1\": 1}" },
-                }},
-                // Index of the tool call in the tool_calls array
-                {"id", "0"},
-            },
-        }},
-        { "role", "assistant" },
-        { "content", {} },
-        { "tool_calls", tool_calls }
-    };
+const common_chat_msg message_user {
+    "user",
+    "Hey there!",
+    /* .content_parts = */ {},
+    /* .tool_calls = */ {},
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};

-    json python_message_with_tool_call {
-        { "role", "assistant" },
-        { "content", {} },
-        { "tool_calls", json::array({
-            {
-                { "type", "function" },
-                { "function", {
-                    { "name", "python" },
-                    { "arguments", {
-                        { "code", "print('hey')" },
-                    }},
-                }},
-            },
-        })}
-    };
-    json code_interpreter_message_with_tool_call {
-        { "role", "assistant" },
-        { "content", {} },
-        { "tool_calls", json::array({
-            {
-                { "type", "function" },
-                { "function", {
-                    { "name", "code_interpreter" },
-                    { "arguments", {
-                        { "code", "print('hey')" },
-                    }},
-                }},
-            },
-        })}
+const common_chat_msg message_user_parts {
+    "user",
+    /* .content = */ "",
+    /* .content_parts = */ {
+        { "text", "Hey" },
+        { "text", "there" },
+    },
+    /* .tool_calls = */ {},
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist {
+    "assistant",
+    "Hello, world!\nWhat's up?",
+    /* .content_parts = */ {},
+    /* .tool_calls = */ {},
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_thoughts_unparsed_think {
+    "assistant",
+    "<think>I'm thinking</think>Hello, world!\nWhat's up?",
+    /* .content_parts = */ {},
+    /* .tool_calls = */ {},
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_thoughts_unparsed_r7b {
+    "assistant",
+    "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?",
+    /* .content_parts = */ {},
+    /* .tool_calls = */ {},
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_thoughts {
+    "assistant",
+    "Hello, world!\nWhat's up?",
+    /* .content_parts = */ {},
+    /* .tool_calls = */ {},
+    /* .reasoning_content = */ "I'm thinking",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+const std::vector<common_chat_tool_call> tool_calls {
+    { "special_function", "{\"arg1\": 1}", /* .id = */ "" },
+};
+const std::vector<common_chat_tool_call> tool_calls_idx {
+    { "special_function", "{\"arg1\": 1}", /* .id = */ "0" },
+};
+const std::vector<common_chat_tool_call> tool_calls_id {
+    { "special_function", "{\"arg1\": 1}", /* .id = */ "123456789" },
+};
+
+const common_chat_msg message_assist_call {
+    "assistant",
+    "",
+    /* .content_parts = */ {},
+    tool_calls,
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_call_thoughts = {
+    "assistant",
+    /* .content = */ "",
+    /* .content_parts = */ {},
+    tool_calls,
+    /* .reasoning_content = */ "I'm\nthinking",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_call_thoughts_unparsed = {
+    "assistant",
+    /* .content = */ "<think>I'm\nthinking</think>",
+    /* .content_parts = */ {},
+    tool_calls,
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_call_id {
+    "assistant",
+    "",
+    /* .content_parts = */ {},
+    tool_calls_id,
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_call_idx {
+    "assistant",
+    "",
+    /* .content_parts = */ {},
+    tool_calls_idx,
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_call_python {
+    "assistant",
+    "",
+    /* .content_parts = */ {},
+    { { "python", "{\"code\": \"print('hey')\"}", /* .id = */ "" } },
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+const common_chat_msg message_assist_call_code_interpreter {
+    "assistant",
+    "",
+    /* .content_parts = */ {},
+    { { "code_interpreter", "{\"code\": \"print('hey')\"}", /* .id = */ "" } },
+    /* .reasoning_content = */ "",
+    /* .tool_name = */ "",
+    /* .tool_call_id = */ "",
+};
+
+static void test_msgs_oaicompat_json_conversion() {
+    std::vector<common_chat_msg> msgs{
+        message_user,
+        message_user_parts,
+        message_assist_call,
+        message_assist_call_thoughts,
+        message_assist_call_thoughts_unparsed,
+        message_assist_call_id,
+        message_assist_call_idx,
+        message_assist_call_python,
+        message_assist_call_code_interpreter,
     };
+    for (const auto & msg : msgs) {
+        auto oai_json = common_chat_msgs_to_json_oaicompat<json>({msg});
+        auto msgs2 = common_chat_msgs_parse_oaicompat(oai_json);
+        assert_equals((size_t) 1, msgs2.size());
+        auto msg2 = msgs2[0];
+        assert_msg_equals(msg, msg2);
+    }
+    assert_equals(
+        std::string(
+            "[\n"
+            "  {\n"
+            "    \"role\": \"user\",\n"
+            "    \"content\": [\n"
+            "      {\n"
+            "        \"type\": \"text\",\n"
+            "        \"text\": \"Hey\"\n"
+            "      },\n"
+            "      {\n"
+            "        \"type\": \"text\",\n"
+            "        \"text\": \"there\"\n"
+            "      }\n"
+            "    ]\n"
+            "  }\n"
+            "]"
+        ),
+        common_chat_msgs_to_json_oaicompat<json>({message_user_parts}).dump(2));
+
+    assert_equals(
+        std::string(
+            "[\n"
+            "  {\n"
+            "    \"role\": \"assistant\",\n"
+            "    \"content\": null,\n"
+            "    \"tool_calls\": [\n"
+            "      {\n"
+            "        \"type\": \"function\",\n"
+            "        \"function\": {\n"
+            "          \"name\": \"python\",\n"
+            "          \"arguments\": \"{\\\"code\\\": \\\"print('hey')\\\"}\"\n"
+            "        }\n"
+            "      }\n"
+            "    ]\n"
+            "  }\n"
+            "]"
+        ),
+        common_chat_msgs_to_json_oaicompat<json>({message_assist_call_python}).dump(2));
+
+    auto res = common_chat_msgs_parse_oaicompat(json::parse("[{\"role\": \"assistant\", \"tool_calls\": []}]"));
+    assert_equals<size_t>(1, res.size());
+    assert_equals<std::string>(res[0].role, "assistant");
+    assert_equals(true, res[0].content.empty());
+    assert_equals(true, res[0].tool_calls.empty());
+
+    try {
+        common_chat_msgs_parse_oaicompat(json::parse("[{\"role\": \"assistant\"}]"));
+        throw std::runtime_error("Expected exception");
+    } catch (const std::exception & e) {
+        if (std::string(e.what()).find("'content'") == std::string::npos) {
+            throw std::runtime_error("Expected exception about missing 'content'");
+        }
+    }
+}

-    common_chat_inputs inputs_no_tools;
-    inputs_no_tools.messages = {
-        { { "role", "user" }, { "content", "Hey" } }
+static void test_tools_oaicompat_json_conversion() {
+    std::vector<common_chat_tool> tools{
+        special_function_tool,
+        python_tool,
+        code_interpreter_tool,
     };

-    common_chat_inputs inputs_tools = inputs_no_tools;
-    inputs_tools.tools = json::array();
-    inputs_tools.tools.push_back(special_function_tool);
+    for (const auto & tool : tools) {
+        auto oai_json = common_chat_tools_to_json_oaicompat<json>({tool});
+        auto tools2 = common_chat_tools_parse_oaicompat(oai_json);
+        assert_equals((size_t) 1, tools2.size());
+        auto tool2 = tools2[0];
+        assert_equals(tool.name, tool2.name);
+        assert_equals(tool.description, tool2.description);
+        assert_equals(json::parse(tool.parameters).dump(2), json::parse(tool2.parameters).dump(2));
+    }
+
+    assert_equals(
+        std::string(
+            "[\n"
+            "  {\n"
+            "    \"type\": \"function\",\n"
+            "    \"function\": {\n"
+            "      \"name\": \"special_function\",\n"
+            "      \"description\": \"I'm special\",\n"
+            "      \"parameters\": {\n"
+            "        \"type\": \"object\",\n"
+            "        \"properties\": {\n"
+            "          \"arg1\": {\n"
+            "            \"type\": \"integer\",\n"
+            "            \"description\": \"The arg.\"\n"
+            "          }\n"
+            "        },\n"
+            "        \"required\": [\n"
+            "          \"arg1\"\n"
+            "        ]\n"
+            "      }\n"
+            "    }\n"
+            "  }\n"
+            "]"
+        ),
+        common_chat_tools_to_json_oaicompat<json>({special_function_tool}).dump(2));
+}
+
+static void test_template_output_parsers() {
+
+    common_chat_templates_inputs inputs_no_tools;
+    inputs_no_tools.messages = {message_user};
+    inputs_no_tools.extract_reasoning = false;
+
+    common_chat_templates_inputs inputs_no_tools_think;
+    inputs_no_tools_think.messages = {message_user};
+    inputs_no_tools_think.extract_reasoning = true;
+
+    common_chat_templates_inputs inputs_tools;
+    inputs_tools.messages = {message_user};
+    inputs_tools.tools = {special_function_tool};
+    inputs_tools.extract_reasoning = false;

-    common_chat_inputs inputs_tools_builtin = inputs_no_tools;
-    inputs_tools_builtin.tools = json::array();
-    inputs_tools_builtin.tools.push_back(python_tool);
+    common_chat_templates_inputs inputs_tools_think;
+    inputs_tools_think.messages = {message_user};
+    inputs_tools_think.tools = {special_function_tool};
+    inputs_tools_think.extract_reasoning = true;
+
+    common_chat_templates_inputs inputs_tools_builtin;
+    inputs_tools_builtin.messages = {message_user};
+    inputs_tools_builtin.tools = {python_tool};
+    inputs_tools_builtin.extract_reasoning = false;

     {
         // Not supported yet
-        const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"), "<s>", "</s>");
-        assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
+        auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja");
+        assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "<s>", "</s>");
+        auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja");
         std::vector<std::string> end_tokens{ "<|END_OF_TURN_TOKEN|>" };

-        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_no_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format);
-
-        test_template(tmpl, end_tokens, tool_call_plan_message_with_idx, tools,
-            "<|START_THINKING|>I'm not so sure<|END_THINKING|>"
+        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
+
+        assert_msg_equals(message_assist,
+                          common_chat_parse(
+                              "Hello, world!\nWhat's up?",
+                              COMMON_CHAT_FORMAT_COMMAND_R7B));
+        assert_msg_equals(message_assist,
+                          common_chat_parse(
+                              "Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                              COMMON_CHAT_FORMAT_COMMAND_R7B));
+        assert_msg_equals(message_assist,
+                          common_chat_parse(
+                              "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                              COMMON_CHAT_FORMAT_COMMAND_R7B));
+        assert_msg_equals(message_assist_thoughts_unparsed_r7b,
+                          common_chat_parse(
+                              "<|START_THINKING|>I'm thinking<|END_THINKING|>"
+                              "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                              COMMON_CHAT_FORMAT_COMMAND_R7B));
+        assert_msg_equals(message_assist_thoughts_unparsed_r7b,
+                          common_chat_parse(
+                              "<|START_THINKING|>I'm thinking<|END_THINKING|>"
+                              "Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                              COMMON_CHAT_FORMAT_COMMAND_R7B));
+
+        assert_msg_equals(message_assist_thoughts,
+                          common_chat_parse(
+                              "<|START_THINKING|>I'm thinking<|END_THINKING|>"
+                              "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                              COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING));
+
+        test_templates(tmpls.get(), end_tokens, message_assist_call_idx, tools,
+            "<|START_THINKING|><|END_THINKING|>"
             "<|START_ACTION|>[\n"
             "    {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
             "]<|END_ACTION|>");
-        test_template(tmpl, end_tokens, text_message, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist, tools,
             "<|START_RESPONSE|>Hello, world!\n"
             "What's up?<|END_RESPONSE|>",
             /* expect_grammar_triggered= */ false);
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
+        auto tmpls = read_templates("models/templates/google-gemma-2-2b-it.jinja");
         std::vector<std::string> end_tokens{ "<end_of_turn>" };

-        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_GENERIC,
-                      common_chat_params_init(
-                          common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"),
-                                               "<s>", "</s>"),
+                      common_chat_templates_apply(
+                          read_templates("models/templates/microsoft-Phi-3.5-mini-instruct.jinja").get(),
                           inputs_tools)
                           .format);

         // Generic tool calls doesn't generate / parse content-only messages symmetrically.

-        assert_msg_equals(msg_from_json(text_message),
+        assert_msg_equals(message_assist,
                           common_chat_parse("{\n"
                                             "  \"response\": \"Hello, world!\\nWhat's up?\"\n"
                                             "}",
-                                            common_chat_params_init(tmpl, inputs_tools).format));
-        test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
+                                            common_chat_templates_apply(tmpls.get(), inputs_tools).format));
+        test_templates(tmpls.get(), end_tokens, message_assist_call_id, tools,
                       "{\n"
                       "  \"tool_calls\": [\n"
                       "    {\n"
@@ -442,166 +651,323 @@ static void test_template_output_parsers() {
                       "}");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>",
-                                        "</s>");
+        auto tmpls = read_templates("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja");
         std::vector<std::string> end_tokens{ "</s>" };

-        assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);

-        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        test_template(
-            tmpl, end_tokens, tool_call_message_with_id, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(
+            tmpls.get(), end_tokens, message_assist_call_id, tools,
             "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
     }
     {
-        const common_chat_template tmpl(
-            read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
+        auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja");
         std::vector<std::string> end_tokens{ "<|im_end|>" };

-        assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(
             COMMON_CHAT_FORMAT_HERMES_2_PRO,
-            common_chat_params_init(
-                common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"),
-                                     "<s>", "</s>"),
+            common_chat_templates_apply(
+                read_templates("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja").get(),
                 inputs_tools)
                 .format);
         assert_equals(
             COMMON_CHAT_FORMAT_HERMES_2_PRO,
-            common_chat_params_init(
-                common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"),
+            common_chat_templates_apply(
+                read_templates("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja").get(),
                 inputs_tools)
                 .format);

-        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        test_template(tmpl, end_tokens, tool_call_message, tools,
+        // Test parsing
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "<tool_call>\n"
+            "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "</tool_call>",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "<function=special_function>{\"arg1\": 1}</function>",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "<function name=\"special_function\">\n"
+            "{\"arg1\": 1}\n"
+            "</function>",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "<tool>\n"
+            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "</tool>",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "<tools>\n"
+            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "</tools>",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "<response>\n"
+            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "</response>",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "```xml\n"
+            "<response>\n"
+            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "</response>\n"
+            "```",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "```xml\n"
+            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "```",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "```\n"
+            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "```",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "```\n"
+            "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "```",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "```json\n"
+            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "```",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "```json\n"
+            "\n"
+            "  <function_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}} \n"
+            "  </function_call> \n"
+            "``` ",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "<json>\n"
+            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "</json>",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "<xml>\n"
+            "  {\n"
+            "    \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}\n"
+            "  }\n"
+            "</xml>",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "<JSON>\n"
+            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+            "</JSON>",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_call, common_chat_parse(
+            "{\n  \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
+            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+
+        assert_msg_equals(message_assist_thoughts_unparsed_think,
+                          common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
+                                            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_thoughts_unparsed_think,
+                          common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
+                                            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(message_assist_thoughts,
+                          common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
+                                            COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING));
+        assert_msg_equals(message_assist_thoughts,
+                          common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
+                                            COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING));
+
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "<tool_call>\n"
                       "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
                       "</tool_call>");
-        test_template(tmpl, end_tokens, python_message_with_tool_call, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools,
                       "<tool_call>\n"
                       "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
                       "</tool_call>");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>",
-                                        "</s>");
+        auto tmpls = read_templates("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja");
         std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

-        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
-                      common_chat_params_init(tmpl, inputs_tools_builtin).format);
+                      common_chat_templates_apply(tmpls.get(), inputs_tools_builtin).format);
         assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
-                      common_chat_params_init(
-                          common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"),
-                                               "<s>", "</s>"),
+                      common_chat_templates_apply(
+                          read_templates("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja").get(),
                           inputs_tools_builtin)
                           .format);

-        // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* expect_grammar_triggered= */ false);
-        test_template(tmpl, end_tokens, code_interpreter_message_with_tool_call, llama_3_1_tools,
+        // test_templates(tmpls.get(), end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_call_code_interpreter, llama_3_1_tools,
                       "<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
-        test_template(tmpl, end_tokens, python_message_with_tool_call, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools,
                       "<|python_tag|>python.call(code=\"print('hey')\")");
-        test_template(tmpl, end_tokens, tool_call_message, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>",
-                                        "</s>");
+        auto tmpls = read_templates("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja");
         std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

-        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);

-        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        test_template(tmpl, end_tokens, tool_call_message, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>",
-                                        "</s>");
+        auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.1.jinja");
         std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

         assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
-                      common_chat_params_init(tmpl, inputs_tools).format);
+                      common_chat_templates_apply(tmpls.get(), inputs_tools).format);

-        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        test_template(tmpl, end_tokens, tool_call_message, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "<function=special_function>{\"arg1\": 1}</function>");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>",
-                                        "</s>");
+        auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.2.jinja");
         std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };

-        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);

-        test_template(tmpl, end_tokens, text_message, {},
+        test_templates(tmpls.get(), end_tokens, message_assist, {},
                       "all\n"
                       "Hello, world!\n"
                       "What's up?",
                       /* expect_grammar_triggered= */ false);
-        test_template(tmpl, end_tokens, tool_call_message, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "special_function\n"
                       "{\"arg1\": 1}");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>",
-                                        "</s>");
+        auto tmpls = read_templates("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja");
         std::vector<std::string> end_tokens{ "<|eot_id|>" };

-        assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);

-        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        test_template(tmpl, end_tokens, tool_call_message, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"),
-                                        "<s>", "</s>");
+        // Original DeepSeek R1 template. Leaves <|tool▁calls▁begin|> and others unclosed. Our logic fixes the prompt.
+        auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja");
         std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };

-        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
+
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        assert_msg_equals(message_assist_thoughts_unparsed_think,
+                          common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
+                                            COMMON_CHAT_FORMAT_DEEPSEEK_R1));
+        assert_msg_equals(message_assist_thoughts,
+                          common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
+                                            COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
+        assert_msg_equals(message_assist_thoughts,
+                          // Latest template update (as of 20250209) adds a trailing <think>\n if add_generation_prompt is true.
+                          common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
+                                            COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
+        // test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
+        //     "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
+        //     "```json\n"
+        //     "{\"arg1\": 1}\n"
+        //     // Look what's not here: <|tool▁calls▁end|> (also missing the <|end▁of▁sentence|>, but that is removed lazily by the test's delta logic)
+        //     "```<|tool▁call▁end|>",
+        //     /* expect_grammar_triggered= */ true,
+        //     /* test_grammar_if_triggered= */ false);
+    }
+    {
+        // Replacement DeepSeek R1 template. Makes the Distill Qwen 7B/32B models happy to call tools and all.
+        auto tmpls = read_templates("models/templates/llama-cpp-deepseek-r1.jinja");
+        std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };

-        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        test_template(tmpl, end_tokens, tool_call_message, tools,
-                      "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
-                      "```json\n"
-                      "{\"arg1\": 1}\n"
-                      "```<|tool▁call▁end|><|tool▁calls▁end|>");
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
+
+        test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        assert_msg_equals(message_assist_thoughts_unparsed_think,
+                          common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
+                                            COMMON_CHAT_FORMAT_DEEPSEEK_R1));
+        assert_msg_equals(message_assist_thoughts,
+                          common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
+                                            COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
+
+        assert_msg_equals(message_assist_call_thoughts_unparsed,
+                          common_chat_parse(
+                              "<think>I'm\nthinking</think>\n\n"
+                              "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
+                              "```json\n"
+                              "{\"arg1\": 1}\n"
+                              "```<|tool▁call▁end|><|tool▁calls▁end|>",
+                              COMMON_CHAT_FORMAT_DEEPSEEK_R1));
+        assert_msg_equals(message_assist_call_thoughts,
+                          common_chat_parse(
+                              "<think>I'm\nthinking</think>\n\n"
+                              "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
+                              "```json\n"
+                              "{\"arg1\": 1}\n"
+                              "```<|tool▁call▁end|><|tool▁calls▁end|>",
+                              COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
+        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
+                       "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
+                       "```json\n"
+                       "{\"arg1\": 1}\n"
+                       "```<|tool▁call▁end|><|tool▁calls▁end|>");
     }
 }

 int main(int argc, char ** argv) {
+    // try {
 #ifndef _WIN32
-    if (argc > 1) {
-        common_chat_inputs inputs;
-        inputs.messages = {
-            { { "role", "user" }, { "content", "Hey" } }
-        };
-        inputs.tools = json::array({ special_function_tool });
-
-        std::cout << "| Template | Format |\n";
-        std::cout << "|----------|--------|\n";
-
-        for (int i = 1; i < argc; i++) {
-            std::string path = argv[i];
-            if (path.rfind(".jinja") != path.size() - 6) {
-                std::cerr << "Skipping non-jinja file: " << path << std::endl;
-                continue;
+    if (argc > 1) {
+        common_chat_templates_inputs inputs;
+        common_chat_msg msg;
+        msg.role = "user";
+        msg.content = "Hey";
+        inputs.messages = {msg};
+        inputs.tools = { special_function_tool };
+
+        std::cout << "| Template | Format |\n";
+        std::cout << "|----------|--------|\n";
+
+        for (int i = 1; i < argc; i++) {
+            try {
+                std::string path = argv[i];
+                if (path.rfind(".jinja") != path.size() - 6) {
+                    std::cerr << "Skipping non-jinja file: " << path << '\n';
+                    continue;
+                }
+                auto tmpls = read_templates(path);
+                auto parts = string_split(path, "/");
+                auto name = parts[parts.size() - 1];
+                auto format = common_chat_format_name(common_chat_templates_apply(tmpls.get(), inputs).format);
+                std::cout << "| " << name << " | " << format << " |\n";
+            } catch (const std::exception & e) {
+                std::cerr << "Failed to process " << argv[i] << ": " << e.what() << '\n';
             }
         }
-
-            const common_chat_template tmpl(read_file(path), "", "");
-            auto parts = string_split(path, "/");
-            auto name = parts[parts.size() - 1];
-            std::cout << "| " << name << " | " << common_chat_format_name(common_chat_params_init(tmpl, inputs).format)
-                      << " |\n";
-        }
-    } else
+    } else
 #endif
-    {
-        test_template_output_parsers();
-        std::cout << "\n[chat] All tests passed!" << std::endl;
-    }
-    return 0;
+    {
+        test_msgs_oaicompat_json_conversion();
+        test_tools_oaicompat_json_conversion();
+        test_template_output_parsers();
+        std::cout << "\n[chat] All tests passed!" << '\n';
+    }
+    return 0;
+    // } catch (const std::exception & e) {
+    //     std::cerr << "Error: " << e.what() << '\n';
+    //     return 1;
+    // }
 }