@fugood/llama.node 0.3.13 → 0.3.14
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +1 -1
- package/package.json +1 -1
- package/src/LlamaContext.cpp +98 -76
- package/src/LlamaContext.h +1 -1
- package/src/common.hpp +1 -2
- package/src/llama.cpp/.github/workflows/build.yml +60 -10
- package/src/llama.cpp/.github/workflows/server.yml +2 -0
- package/src/llama.cpp/common/CMakeLists.txt +3 -3
- package/src/llama.cpp/common/arg.cpp +112 -11
- package/src/llama.cpp/common/chat.cpp +960 -266
- package/src/llama.cpp/common/chat.h +135 -0
- package/src/llama.cpp/common/common.cpp +27 -171
- package/src/llama.cpp/common/common.h +27 -67
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
- package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
- package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +37 -5
- package/src/llama.cpp/common/ngram-cache.cpp +1 -0
- package/src/llama.cpp/common/sampling.cpp +45 -7
- package/src/llama.cpp/common/speculative.cpp +6 -5
- package/src/llama.cpp/common/speculative.h +1 -1
- package/src/llama.cpp/docs/build.md +45 -7
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +3 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +1 -0
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +2 -3
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +1 -1
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +373 -107
- package/src/llama.cpp/examples/llava/clip.h +19 -3
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
- package/src/llama.cpp/examples/llava/llava.cpp +4 -2
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -0
- package/src/llama.cpp/examples/main/main.cpp +73 -28
- package/src/llama.cpp/examples/parallel/parallel.cpp +1 -0
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -0
- package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
- package/src/llama.cpp/examples/run/run.cpp +110 -67
- package/src/llama.cpp/examples/server/server.cpp +82 -87
- package/src/llama.cpp/examples/server/utils.hpp +94 -107
- package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
- package/src/llama.cpp/examples/tts/tts.cpp +251 -142
- package/src/llama.cpp/ggml/CMakeLists.txt +13 -1
- package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-cpu.h +3 -0
- package/src/llama.cpp/ggml/include/ggml.h +5 -1
- package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
- package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +132 -17
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +151 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1396 -386
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1432 -151
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +22 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +15 -2
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +235 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -2
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +220 -116
- package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +51 -10
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +136 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +308 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +168 -721
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -77
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +146 -42
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +13 -3
- package/src/llama.cpp/ggml/src/ggml.c +8 -3
- package/src/llama.cpp/include/llama.h +19 -5
- package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +1 -0
- package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
- package/src/llama.cpp/requirements.txt +1 -0
- package/src/llama.cpp/src/llama-arch.cpp +21 -0
- package/src/llama.cpp/src/llama-arch.h +1 -0
- package/src/llama.cpp/src/llama-chat.cpp +1 -0
- package/src/llama.cpp/src/llama-grammar.cpp +182 -182
- package/src/llama.cpp/src/llama-grammar.h +12 -3
- package/src/llama.cpp/src/llama-kv-cache.h +1 -0
- package/src/llama.cpp/src/llama-mmap.cpp +11 -1
- package/src/llama.cpp/src/llama-model.cpp +69 -5
- package/src/llama.cpp/src/llama-sampling.cpp +43 -10
- package/src/llama.cpp/src/llama-vocab.cpp +12 -0
- package/src/llama.cpp/src/llama.cpp +147 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +166 -110
- package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
- package/src/llama.cpp/tests/test-chat.cpp +593 -395
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
- package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
- package/src/llama.cpp/Sources/llama/llama.h +0 -4
- package/src/llama.cpp/common/chat.hpp +0 -55
- /package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +0 -0
|
@@ -10,38 +10,12 @@
|
|
|
10
10
|
#include <json.hpp>
|
|
11
11
|
#include <string>
|
|
12
12
|
|
|
13
|
-
#include "chat
|
|
14
|
-
#include "chat.hpp"
|
|
13
|
+
#include "chat.h"
|
|
15
14
|
#include "llama-grammar.h"
|
|
16
15
|
#include "unicode.h"
|
|
17
16
|
|
|
18
17
|
using json = nlohmann::ordered_json;
|
|
19
18
|
|
|
20
|
-
static common_chat_msg msg_from_json(const json & message) {
|
|
21
|
-
common_chat_msg ret;
|
|
22
|
-
ret.role = "assistant";
|
|
23
|
-
if (message.contains("content") && !message.at("content").is_null()) {
|
|
24
|
-
ret.content = message.at("content");
|
|
25
|
-
}
|
|
26
|
-
if (message.contains("tool_plan")) {
|
|
27
|
-
ret.reasoning_content = message.at("tool_plan");
|
|
28
|
-
}
|
|
29
|
-
if (message.contains("reasoning_content")) {
|
|
30
|
-
ret.reasoning_content = message.at("reasoning_content");
|
|
31
|
-
}
|
|
32
|
-
auto has_tool_calls = message.contains("tool_calls");
|
|
33
|
-
if (has_tool_calls) {
|
|
34
|
-
for (const auto & tc : message.at("tool_calls")) {
|
|
35
|
-
const auto & arguments = tc.at("function").at("arguments");
|
|
36
|
-
ret.tool_calls.push_back({
|
|
37
|
-
tc.at("function").at("name").get<std::string>(),
|
|
38
|
-
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
|
39
|
-
tc.contains("id") ? tc.at("id").get<std::string>() : "",
|
|
40
|
-
});
|
|
41
|
-
}
|
|
42
|
-
}
|
|
43
|
-
return ret;
|
|
44
|
-
}
|
|
45
19
|
|
|
46
20
|
template <class T> static void assert_equals(const T & expected, const T & actual) {
|
|
47
21
|
if (expected != actual) {
|
|
@@ -53,7 +27,7 @@ template <class T> static void assert_equals(const T & expected, const T & actua
|
|
|
53
27
|
}
|
|
54
28
|
|
|
55
29
|
static std::string read_file(const std::string & path) {
|
|
56
|
-
std::cerr << "# Reading: " << path <<
|
|
30
|
+
std::cerr << "# Reading: " << path << '\n' << std::flush;
|
|
57
31
|
std::ifstream fs(path, std::ios_base::binary);
|
|
58
32
|
if (!fs.is_open()) {
|
|
59
33
|
fs = std::ifstream("../" + path, std::ios_base::binary);
|
|
@@ -66,10 +40,14 @@ static std::string read_file(const std::string & path) {
|
|
|
66
40
|
fs.seekg(0);
|
|
67
41
|
std::string out;
|
|
68
42
|
out.resize(static_cast<size_t>(size));
|
|
69
|
-
fs.read(
|
|
43
|
+
fs.read(out.data(), static_cast<std::streamsize>(size));
|
|
70
44
|
return out;
|
|
71
45
|
}
|
|
72
46
|
|
|
47
|
+
static common_chat_templates_ptr read_templates(const std::string & path) {
|
|
48
|
+
return common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, read_file(path)));
|
|
49
|
+
}
|
|
50
|
+
|
|
73
51
|
static std::unique_ptr<llama_grammar> build_grammar(const std::string & grammar_str) {
|
|
74
52
|
return std::unique_ptr<llama_grammar>(
|
|
75
53
|
llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0));
|
|
@@ -90,110 +68,102 @@ static bool match_string(const std::string & input, llama_grammar * grammar) {
|
|
|
90
68
|
}
|
|
91
69
|
}
|
|
92
70
|
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
return true;
|
|
97
|
-
}
|
|
71
|
+
if (std::any_of(stacks_cur.begin(), stacks_cur.end(), [](const auto & stack) { return stack.empty(); })) {
|
|
72
|
+
// An empty stack means that the grammar has been completed
|
|
73
|
+
return true;
|
|
98
74
|
}
|
|
99
75
|
|
|
100
76
|
return false;
|
|
101
77
|
}
|
|
102
78
|
|
|
103
|
-
// Dumps `{"a": 1}` as `"{\"a\": 1}"`, unlike nlohmann::json::dump which would dump it as `"{\"a\":1}"`.
|
|
104
|
-
static std::string dump(const json & j) {
|
|
105
|
-
return minja::Value(j).dump(-1, /* to_json= */ true);
|
|
106
|
-
}
|
|
107
|
-
|
|
108
79
|
static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
|
|
109
80
|
assert_equals(expected.role, actual.role);
|
|
110
81
|
assert_equals(expected.content, actual.content);
|
|
82
|
+
assert_equals(expected.content_parts.size(), actual.content_parts.size());
|
|
83
|
+
for (size_t i = 0; i < expected.content_parts.size(); i++) {
|
|
84
|
+
const auto & expected_part = expected.content_parts[i];
|
|
85
|
+
const auto & actual_part = actual.content_parts[i];
|
|
86
|
+
assert_equals(expected_part.type, actual_part.type);
|
|
87
|
+
assert_equals(expected_part.text, actual_part.text);
|
|
88
|
+
}
|
|
111
89
|
assert_equals(expected.reasoning_content, actual.reasoning_content);
|
|
112
90
|
assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
|
|
113
91
|
for (size_t i = 0; i < expected.tool_calls.size(); i++) {
|
|
114
92
|
const auto & expected_tool_call = expected.tool_calls[i];
|
|
115
93
|
const auto & actual_tool_call = actual.tool_calls[i];
|
|
116
94
|
assert_equals(expected_tool_call.name, actual_tool_call.name);
|
|
117
|
-
assert_equals(
|
|
95
|
+
assert_equals(json::parse(expected_tool_call.arguments).dump(), json::parse(actual_tool_call.arguments).dump());
|
|
118
96
|
assert_equals(expected_tool_call.id, actual_tool_call.id);
|
|
119
97
|
}
|
|
120
98
|
}
|
|
121
99
|
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
"
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
"description": "Python code to execute."
|
|
167
|
-
}
|
|
168
|
-
},
|
|
169
|
-
"required": ["code"]
|
|
170
|
-
}
|
|
171
|
-
}
|
|
172
|
-
})");
|
|
173
|
-
const json tools = { special_function_tool, python_tool };
|
|
174
|
-
const json llama_3_1_tools = { special_function_tool, code_interpreter_tool };
|
|
100
|
+
common_chat_tool special_function_tool {
|
|
101
|
+
/* .name = */ "special_function",
|
|
102
|
+
/* .description = */ "I'm special",
|
|
103
|
+
/* .parameters = */ R"({
|
|
104
|
+
"type": "object",
|
|
105
|
+
"properties": {
|
|
106
|
+
"arg1": {
|
|
107
|
+
"type": "integer",
|
|
108
|
+
"description": "The arg."
|
|
109
|
+
}
|
|
110
|
+
},
|
|
111
|
+
"required": ["arg1"]
|
|
112
|
+
})",
|
|
113
|
+
};
|
|
114
|
+
common_chat_tool python_tool {
|
|
115
|
+
/* .name = */ "python",
|
|
116
|
+
/* .description = */ "an ipython interpreter",
|
|
117
|
+
/* .parameters = */ R"({
|
|
118
|
+
"type": "object",
|
|
119
|
+
"properties": {
|
|
120
|
+
"code": {
|
|
121
|
+
"type": "string",
|
|
122
|
+
"description": "Python code to execute."
|
|
123
|
+
}
|
|
124
|
+
},
|
|
125
|
+
"required": ["code"]
|
|
126
|
+
})",
|
|
127
|
+
};
|
|
128
|
+
common_chat_tool code_interpreter_tool {
|
|
129
|
+
/* .name = */ "code_interpreter",
|
|
130
|
+
/* .description = */ "an ipython interpreter",
|
|
131
|
+
/* .parameters = */ R"({
|
|
132
|
+
"type": "object",
|
|
133
|
+
"properties": {
|
|
134
|
+
"code": {
|
|
135
|
+
"type": "string",
|
|
136
|
+
"description": "Python code to execute."
|
|
137
|
+
}
|
|
138
|
+
},
|
|
139
|
+
"required": ["code"]
|
|
140
|
+
})",
|
|
141
|
+
};
|
|
142
|
+
std::vector<common_chat_tool> tools { special_function_tool, python_tool };
|
|
143
|
+
std::vector<common_chat_tool> llama_3_1_tools { special_function_tool, code_interpreter_tool };
|
|
175
144
|
|
|
176
145
|
struct delta_data {
|
|
177
146
|
std::string delta;
|
|
178
147
|
common_chat_params params;
|
|
179
148
|
};
|
|
180
149
|
|
|
181
|
-
static delta_data init_delta(const
|
|
182
|
-
const
|
|
183
|
-
const
|
|
150
|
+
static delta_data init_delta(const struct common_chat_templates * tmpls, const std::vector<std::string> & end_tokens,
|
|
151
|
+
const common_chat_msg & user_message,
|
|
152
|
+
const common_chat_msg & delta_message,
|
|
153
|
+
const std::vector<common_chat_tool> & tools,
|
|
154
|
+
const common_chat_tool_choice & tool_choice,
|
|
184
155
|
bool think = false) {
|
|
185
|
-
|
|
156
|
+
common_chat_templates_inputs inputs;
|
|
186
157
|
inputs.parallel_tool_calls = true;
|
|
187
|
-
inputs.messages = json::array();
|
|
188
158
|
inputs.messages.push_back(user_message);
|
|
189
159
|
inputs.tools = tools;
|
|
190
160
|
inputs.tool_choice = tool_choice;
|
|
191
161
|
inputs.extract_reasoning = think;
|
|
192
|
-
auto params_prefix =
|
|
162
|
+
auto params_prefix = common_chat_templates_apply(tmpls, inputs);
|
|
193
163
|
|
|
194
164
|
inputs.messages.push_back(delta_message);
|
|
195
165
|
inputs.add_generation_prompt = false;
|
|
196
|
-
auto params_full =
|
|
166
|
+
auto params_full = common_chat_templates_apply(tmpls, inputs);
|
|
197
167
|
|
|
198
168
|
std::string prefix = params_prefix.prompt;
|
|
199
169
|
std::string full = params_full.prompt;
|
|
@@ -234,30 +204,29 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
|
|
|
234
204
|
gets the diff, removes any end tokens and parses the result w/ the grammar, checking that
|
|
235
205
|
the parsed message is the same as the test_message
|
|
236
206
|
*/
|
|
237
|
-
static void
|
|
238
|
-
const
|
|
207
|
+
static void test_templates(const struct common_chat_templates * tmpls, const std::vector<std::string> & end_tokens,
|
|
208
|
+
const common_chat_msg & test_message,
|
|
209
|
+
const std::vector<common_chat_tool> & tools = {},
|
|
210
|
+
const std::string & expected_delta = "",
|
|
239
211
|
bool expect_grammar_triggered = true,
|
|
240
212
|
bool test_grammar_if_triggered = true,
|
|
241
213
|
bool think = false) {
|
|
242
|
-
common_chat_msg
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
{ "role", "user" },
|
|
246
|
-
{ "content", "Hello, world!" }
|
|
247
|
-
};
|
|
214
|
+
common_chat_msg user_message;
|
|
215
|
+
user_message.role = "user";
|
|
216
|
+
user_message.content = "Hello, world!";
|
|
248
217
|
|
|
249
|
-
for (const auto & tool_choice :
|
|
250
|
-
auto data = init_delta(
|
|
218
|
+
for (const auto & tool_choice : std::vector<common_chat_tool_choice> {COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED}) {
|
|
219
|
+
auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice, think);
|
|
251
220
|
if (!expected_delta.empty()) {
|
|
252
221
|
assert_equals(expected_delta, data.delta);
|
|
253
222
|
}
|
|
254
223
|
|
|
255
224
|
if (expect_grammar_triggered) {
|
|
256
225
|
const auto msg = common_chat_parse(data.delta, data.params.format);
|
|
257
|
-
assert_msg_equals(
|
|
226
|
+
assert_msg_equals(test_message, msg);
|
|
258
227
|
}
|
|
259
228
|
|
|
260
|
-
if (!
|
|
229
|
+
if (!test_message.tool_calls.empty()) {
|
|
261
230
|
GGML_ASSERT(!data.params.grammar.empty());
|
|
262
231
|
}
|
|
263
232
|
if (!data.params.grammar.empty()) {
|
|
@@ -268,12 +237,35 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
|
|
268
237
|
auto earliest_trigger_pos = std::string::npos;
|
|
269
238
|
auto constrained = data.delta;
|
|
270
239
|
for (const auto & trigger : data.params.grammar_triggers) {
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
240
|
+
size_t pos = std::string::npos;
|
|
241
|
+
std::smatch match;
|
|
242
|
+
switch (trigger.type) {
|
|
243
|
+
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
|
|
244
|
+
{
|
|
245
|
+
const auto & word = trigger.value;
|
|
246
|
+
pos = constrained.find(word);
|
|
247
|
+
break;
|
|
248
|
+
}
|
|
249
|
+
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
|
250
|
+
{
|
|
251
|
+
const auto & pattern = trigger.value;
|
|
252
|
+
if (std::regex_search(constrained, match, std::regex(pattern))) {
|
|
253
|
+
pos = match.position();
|
|
254
|
+
}
|
|
255
|
+
break;
|
|
256
|
+
}
|
|
257
|
+
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
|
|
258
|
+
{
|
|
259
|
+
const auto & pattern = trigger.value;
|
|
260
|
+
if (std::regex_search(constrained, match, std::regex(pattern)) && match.position() == 0) {
|
|
261
|
+
pos = 0;
|
|
262
|
+
}
|
|
263
|
+
break;
|
|
264
|
+
}
|
|
265
|
+
default:
|
|
266
|
+
throw std::runtime_error("Unknown trigger type");
|
|
274
267
|
}
|
|
275
|
-
if (pos
|
|
276
|
-
fprintf(stderr, "Trigger %s not at start of message, skipping:\n\n%s\n\n", trigger.word.c_str(), constrained.c_str());
|
|
268
|
+
if (pos == std::string::npos) {
|
|
277
269
|
continue;
|
|
278
270
|
}
|
|
279
271
|
if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
|
|
@@ -291,252 +283,361 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
|
|
|
291
283
|
|
|
292
284
|
if (grammar_triggered && test_grammar_if_triggered && !match_string(constrained, grammar.get())) {
|
|
293
285
|
throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
|
|
294
|
-
|
|
286
|
+
"\n\nConstrained: " + constrained +
|
|
287
|
+
"\n\nGrammar: " + data.params.grammar);
|
|
295
288
|
}
|
|
296
289
|
}
|
|
297
290
|
}
|
|
298
291
|
}
|
|
299
292
|
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
}
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
{ "
|
|
315
|
-
{ "
|
|
316
|
-
}
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
}
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
}
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
293
|
+
const common_chat_msg message_user {
|
|
294
|
+
"user",
|
|
295
|
+
"Hey there!",
|
|
296
|
+
/* .content_parts = */ {},
|
|
297
|
+
/* .tool_calls = */ {},
|
|
298
|
+
/* .reasoning_content = */ "",
|
|
299
|
+
/* .tool_name = */ "",
|
|
300
|
+
/* .tool_call_id = */ "",
|
|
301
|
+
};
|
|
302
|
+
|
|
303
|
+
const common_chat_msg message_user_parts {
|
|
304
|
+
"user",
|
|
305
|
+
/* .content = */ "",
|
|
306
|
+
/* .content_parts = */ {
|
|
307
|
+
{ "text", "Hey" },
|
|
308
|
+
{ "text", "there" },
|
|
309
|
+
},
|
|
310
|
+
/* .tool_calls = */ {},
|
|
311
|
+
/* .reasoning_content = */ "",
|
|
312
|
+
/* .tool_name = */ "",
|
|
313
|
+
/* .tool_call_id = */ "",
|
|
314
|
+
};
|
|
315
|
+
const common_chat_msg message_assist {
|
|
316
|
+
"assistant",
|
|
317
|
+
"Hello, world!\nWhat's up?",
|
|
318
|
+
/* .content_parts = */ {},
|
|
319
|
+
/* .tool_calls = */ {},
|
|
320
|
+
/* .reasoning_content = */ "",
|
|
321
|
+
/* .tool_name = */ "",
|
|
322
|
+
/* .tool_call_id = */ "",
|
|
323
|
+
};
|
|
324
|
+
const common_chat_msg message_assist_thoughts_unparsed_think {
|
|
325
|
+
"assistant",
|
|
326
|
+
"<think>I'm thinking</think>Hello, world!\nWhat's up?",
|
|
327
|
+
/* .content_parts = */ {},
|
|
328
|
+
/* .tool_calls = */ {},
|
|
329
|
+
/* .reasoning_content = */ "",
|
|
330
|
+
/* .tool_name = */ "",
|
|
331
|
+
/* .tool_call_id = */ "",
|
|
332
|
+
};
|
|
333
|
+
const common_chat_msg message_assist_thoughts_unparsed_r7b {
|
|
334
|
+
"assistant",
|
|
335
|
+
"<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?",
|
|
336
|
+
/* .content_parts = */ {},
|
|
337
|
+
/* .tool_calls = */ {},
|
|
338
|
+
/* .reasoning_content = */ "",
|
|
339
|
+
/* .tool_name = */ "",
|
|
340
|
+
/* .tool_call_id = */ "",
|
|
341
|
+
};
|
|
342
|
+
const common_chat_msg message_assist_thoughts {
|
|
343
|
+
"assistant",
|
|
344
|
+
"Hello, world!\nWhat's up?",
|
|
345
|
+
/* .content_parts = */ {},
|
|
346
|
+
/* .tool_calls = */ {},
|
|
347
|
+
/* .reasoning_content = */ "I'm thinking",
|
|
348
|
+
/* .tool_name = */ "",
|
|
349
|
+
/* .tool_call_id = */ "",
|
|
350
|
+
};
|
|
351
|
+
const std::vector<common_chat_tool_call> tool_calls {
|
|
352
|
+
{ "special_function", "{\"arg1\": 1}", /* .id = */ "" },
|
|
353
|
+
};
|
|
354
|
+
const std::vector<common_chat_tool_call> tool_calls_idx {
|
|
355
|
+
{ "special_function", "{\"arg1\": 1}", /* .id = */ "0" },
|
|
356
|
+
};
|
|
357
|
+
const std::vector<common_chat_tool_call> tool_calls_id {
|
|
358
|
+
{ "special_function", "{\"arg1\": 1}", /* .id = */ "123456789" },
|
|
359
|
+
};
|
|
360
|
+
|
|
361
|
+
const common_chat_msg message_assist_call {
|
|
362
|
+
"assistant",
|
|
363
|
+
"",
|
|
364
|
+
/* .content_parts = */ {},
|
|
365
|
+
tool_calls,
|
|
366
|
+
/* .reasoning_content = */ "",
|
|
367
|
+
/* .tool_name = */ "",
|
|
368
|
+
/* .tool_call_id = */ "",
|
|
369
|
+
};
|
|
370
|
+
const common_chat_msg message_assist_call_thoughts = {
|
|
371
|
+
"assistant",
|
|
372
|
+
/* .content = */ "",
|
|
373
|
+
/* .content_parts = */ {},
|
|
374
|
+
tool_calls,
|
|
375
|
+
/* .reasoning_content = */ "I'm\nthinking",
|
|
376
|
+
/* .tool_name = */ "",
|
|
377
|
+
/* .tool_call_id = */ "",
|
|
378
|
+
};
|
|
379
|
+
const common_chat_msg message_assist_call_thoughts_unparsed = {
|
|
380
|
+
"assistant",
|
|
381
|
+
/* .content = */ "<think>I'm\nthinking</think>",
|
|
382
|
+
/* .content_parts = */ {},
|
|
383
|
+
tool_calls,
|
|
384
|
+
/* .reasoning_content = */ "",
|
|
385
|
+
/* .tool_name = */ "",
|
|
386
|
+
/* .tool_call_id = */ "",
|
|
387
|
+
};
|
|
388
|
+
const common_chat_msg message_assist_call_id {
|
|
389
|
+
"assistant",
|
|
390
|
+
"",
|
|
391
|
+
/* .content_parts = */ {},
|
|
392
|
+
tool_calls_id,
|
|
393
|
+
/* .reasoning_content = */ "",
|
|
394
|
+
/* .tool_name = */ "",
|
|
395
|
+
/* .tool_call_id = */ "",
|
|
396
|
+
};
|
|
397
|
+
const common_chat_msg message_assist_call_idx {
|
|
398
|
+
"assistant",
|
|
399
|
+
"",
|
|
400
|
+
/* .content_parts = */ {},
|
|
401
|
+
tool_calls_idx,
|
|
402
|
+
/* .reasoning_content = */ "",
|
|
403
|
+
/* .tool_name = */ "",
|
|
404
|
+
/* .tool_call_id = */ "",
|
|
405
|
+
};
|
|
406
|
+
const common_chat_msg message_assist_call_python {
|
|
407
|
+
"assistant",
|
|
408
|
+
"",
|
|
409
|
+
/* .content_parts = */ {},
|
|
410
|
+
{ { "python", "{\"code\": \"print('hey')\"}", /* .id = */ "" } },
|
|
411
|
+
/* .reasoning_content = */ "",
|
|
412
|
+
/* .tool_name = */ "",
|
|
413
|
+
/* .tool_call_id = */ "",
|
|
414
|
+
};
|
|
415
|
+
const common_chat_msg message_assist_call_code_interpreter {
|
|
416
|
+
"assistant",
|
|
417
|
+
"",
|
|
418
|
+
/* .content_parts = */ {},
|
|
419
|
+
{ { "code_interpreter", "{\"code\": \"print('hey')\"}", /* .id = */ "" } },
|
|
420
|
+
/* .reasoning_content = */ "",
|
|
421
|
+
/* .tool_name = */ "",
|
|
422
|
+
/* .tool_call_id = */ "",
|
|
423
|
+
};
|
|
424
|
+
|
|
425
|
+
static void test_msgs_oaicompat_json_conversion() {
|
|
426
|
+
std::vector<common_chat_msg> msgs{
|
|
427
|
+
message_user,
|
|
428
|
+
message_user_parts,
|
|
429
|
+
message_assist_call,
|
|
430
|
+
message_assist_call_thoughts,
|
|
431
|
+
message_assist_call_thoughts_unparsed,
|
|
432
|
+
message_assist_call_id,
|
|
433
|
+
message_assist_call_idx,
|
|
434
|
+
message_assist_call_python,
|
|
435
|
+
message_assist_call_code_interpreter,
|
|
419
436
|
};
|
|
420
|
-
auto
|
|
421
|
-
{
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
437
|
+
for (const auto & msg : msgs) {
|
|
438
|
+
auto oai_json = common_chat_msgs_to_json_oaicompat<json>({msg});
|
|
439
|
+
auto msgs2 = common_chat_msgs_parse_oaicompat(oai_json);
|
|
440
|
+
assert_equals((size_t) 1, msgs2.size());
|
|
441
|
+
auto msg2 = msgs2[0];
|
|
442
|
+
assert_msg_equals(msg, msg2);
|
|
443
|
+
}
|
|
444
|
+
assert_equals(
|
|
445
|
+
std::string(
|
|
446
|
+
"[\n"
|
|
447
|
+
" {\n"
|
|
448
|
+
" \"role\": \"user\",\n"
|
|
449
|
+
" \"content\": [\n"
|
|
450
|
+
" {\n"
|
|
451
|
+
" \"type\": \"text\",\n"
|
|
452
|
+
" \"text\": \"Hey\"\n"
|
|
453
|
+
" },\n"
|
|
454
|
+
" {\n"
|
|
455
|
+
" \"type\": \"text\",\n"
|
|
456
|
+
" \"text\": \"there\"\n"
|
|
457
|
+
" }\n"
|
|
458
|
+
" ]\n"
|
|
459
|
+
" }\n"
|
|
460
|
+
"]"
|
|
461
|
+
),
|
|
462
|
+
common_chat_msgs_to_json_oaicompat<json>({message_user_parts}).dump(2));
|
|
463
|
+
|
|
464
|
+
assert_equals(
|
|
465
|
+
std::string(
|
|
466
|
+
"[\n"
|
|
467
|
+
" {\n"
|
|
468
|
+
" \"role\": \"assistant\",\n"
|
|
469
|
+
" \"content\": null,\n"
|
|
470
|
+
" \"tool_calls\": [\n"
|
|
471
|
+
" {\n"
|
|
472
|
+
" \"type\": \"function\",\n"
|
|
473
|
+
" \"function\": {\n"
|
|
474
|
+
" \"name\": \"python\",\n"
|
|
475
|
+
" \"arguments\": \"{\\\"code\\\": \\\"print('hey')\\\"}\"\n"
|
|
476
|
+
" }\n"
|
|
477
|
+
" }\n"
|
|
478
|
+
" ]\n"
|
|
479
|
+
" }\n"
|
|
480
|
+
"]"
|
|
481
|
+
),
|
|
482
|
+
common_chat_msgs_to_json_oaicompat<json>({message_assist_call_python}).dump(2));
|
|
483
|
+
|
|
484
|
+
auto res = common_chat_msgs_parse_oaicompat(json::parse("[{\"role\": \"assistant\", \"tool_calls\": []}]"));
|
|
485
|
+
assert_equals<size_t>(1, res.size());
|
|
486
|
+
assert_equals<std::string>(res[0].role, "assistant");
|
|
487
|
+
assert_equals(true, res[0].content.empty());
|
|
488
|
+
assert_equals(true, res[0].tool_calls.empty());
|
|
489
|
+
|
|
490
|
+
try {
|
|
491
|
+
common_chat_msgs_parse_oaicompat(json::parse("[{\"role\": \"assistant\"}]"));
|
|
492
|
+
throw std::runtime_error("Expected exception");
|
|
493
|
+
} catch (const std::exception & e) {
|
|
494
|
+
if (std::string(e.what()).find("'content'") == std::string::npos) {
|
|
495
|
+
throw std::runtime_error("Expected exception about missing 'content'");
|
|
496
|
+
}
|
|
497
|
+
}
|
|
498
|
+
}
|
|
499
|
+
|
|
500
|
+
static void test_tools_oaicompat_json_conversion() {
|
|
501
|
+
std::vector<common_chat_tool> tools{
|
|
502
|
+
special_function_tool,
|
|
503
|
+
python_tool,
|
|
504
|
+
code_interpreter_tool,
|
|
434
505
|
};
|
|
435
506
|
|
|
436
|
-
|
|
437
|
-
|
|
507
|
+
for (const auto & tool : tools) {
|
|
508
|
+
auto oai_json = common_chat_tools_to_json_oaicompat<json>({tool});
|
|
509
|
+
auto tools2 = common_chat_tools_parse_oaicompat(oai_json);
|
|
510
|
+
assert_equals((size_t) 1, tools2.size());
|
|
511
|
+
auto tool2 = tools2[0];
|
|
512
|
+
assert_equals(tool.name, tool2.name);
|
|
513
|
+
assert_equals(tool.description, tool2.description);
|
|
514
|
+
assert_equals(json::parse(tool.parameters).dump(2), json::parse(tool2.parameters).dump(2));
|
|
515
|
+
}
|
|
516
|
+
|
|
517
|
+
assert_equals(
|
|
518
|
+
std::string(
|
|
519
|
+
"[\n"
|
|
520
|
+
" {\n"
|
|
521
|
+
" \"type\": \"function\",\n"
|
|
522
|
+
" \"function\": {\n"
|
|
523
|
+
" \"name\": \"special_function\",\n"
|
|
524
|
+
" \"description\": \"I'm special\",\n"
|
|
525
|
+
" \"parameters\": {\n"
|
|
526
|
+
" \"type\": \"object\",\n"
|
|
527
|
+
" \"properties\": {\n"
|
|
528
|
+
" \"arg1\": {\n"
|
|
529
|
+
" \"type\": \"integer\",\n"
|
|
530
|
+
" \"description\": \"The arg.\"\n"
|
|
531
|
+
" }\n"
|
|
532
|
+
" },\n"
|
|
533
|
+
" \"required\": [\n"
|
|
534
|
+
" \"arg1\"\n"
|
|
535
|
+
" ]\n"
|
|
536
|
+
" }\n"
|
|
537
|
+
" }\n"
|
|
538
|
+
" }\n"
|
|
539
|
+
"]"
|
|
540
|
+
),
|
|
541
|
+
common_chat_tools_to_json_oaicompat<json>({special_function_tool}).dump(2));
|
|
542
|
+
}
|
|
543
|
+
|
|
544
|
+
static void test_template_output_parsers() {
|
|
545
|
+
|
|
546
|
+
common_chat_templates_inputs inputs_no_tools;
|
|
547
|
+
inputs_no_tools.messages = {message_user};
|
|
438
548
|
inputs_no_tools.extract_reasoning = false;
|
|
439
549
|
|
|
440
|
-
|
|
441
|
-
inputs_no_tools_think.messages =
|
|
550
|
+
common_chat_templates_inputs inputs_no_tools_think;
|
|
551
|
+
inputs_no_tools_think.messages = {message_user};
|
|
442
552
|
inputs_no_tools_think.extract_reasoning = true;
|
|
443
553
|
|
|
444
|
-
|
|
445
|
-
inputs_tools.messages =
|
|
446
|
-
inputs_tools.tools =
|
|
554
|
+
common_chat_templates_inputs inputs_tools;
|
|
555
|
+
inputs_tools.messages = {message_user};
|
|
556
|
+
inputs_tools.tools = {special_function_tool};
|
|
447
557
|
inputs_tools.extract_reasoning = false;
|
|
448
558
|
|
|
449
|
-
|
|
450
|
-
inputs_tools_think.messages =
|
|
451
|
-
inputs_tools_think.tools =
|
|
559
|
+
common_chat_templates_inputs inputs_tools_think;
|
|
560
|
+
inputs_tools_think.messages = {message_user};
|
|
561
|
+
inputs_tools_think.tools = {special_function_tool};
|
|
452
562
|
inputs_tools_think.extract_reasoning = true;
|
|
453
563
|
|
|
454
|
-
|
|
455
|
-
inputs_tools_builtin.messages =
|
|
456
|
-
inputs_tools_builtin.tools =
|
|
564
|
+
common_chat_templates_inputs inputs_tools_builtin;
|
|
565
|
+
inputs_tools_builtin.messages = {message_user};
|
|
566
|
+
inputs_tools_builtin.tools = {python_tool};
|
|
457
567
|
inputs_tools_builtin.extract_reasoning = false;
|
|
458
568
|
|
|
459
569
|
{
|
|
460
570
|
// Not supported yet
|
|
461
|
-
|
|
462
|
-
assert_equals(COMMON_CHAT_FORMAT_GENERIC,
|
|
571
|
+
auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja");
|
|
572
|
+
assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
|
|
463
573
|
}
|
|
464
574
|
{
|
|
465
|
-
|
|
575
|
+
auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja");
|
|
466
576
|
std::vector<std::string> end_tokens{ "<|END_OF_TURN_TOKEN|>" };
|
|
467
577
|
|
|
468
|
-
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B,
|
|
469
|
-
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B,
|
|
470
|
-
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
|
|
578
|
+
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
|
|
579
|
+
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
|
|
580
|
+
assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
|
|
471
581
|
|
|
472
|
-
assert_msg_equals(
|
|
582
|
+
assert_msg_equals(message_assist,
|
|
473
583
|
common_chat_parse(
|
|
474
584
|
"Hello, world!\nWhat's up?",
|
|
475
585
|
COMMON_CHAT_FORMAT_COMMAND_R7B));
|
|
476
|
-
assert_msg_equals(
|
|
586
|
+
assert_msg_equals(message_assist,
|
|
477
587
|
common_chat_parse(
|
|
478
588
|
"Hello, world!\nWhat's up?<|END_RESPONSE|>",
|
|
479
589
|
COMMON_CHAT_FORMAT_COMMAND_R7B));
|
|
480
|
-
assert_msg_equals(
|
|
590
|
+
assert_msg_equals(message_assist,
|
|
481
591
|
common_chat_parse(
|
|
482
592
|
"<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
|
|
483
593
|
COMMON_CHAT_FORMAT_COMMAND_R7B));
|
|
484
|
-
assert_msg_equals(
|
|
594
|
+
assert_msg_equals(message_assist_thoughts_unparsed_r7b,
|
|
485
595
|
common_chat_parse(
|
|
486
596
|
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
|
|
487
597
|
"<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
|
|
488
598
|
COMMON_CHAT_FORMAT_COMMAND_R7B));
|
|
489
|
-
assert_msg_equals(
|
|
599
|
+
assert_msg_equals(message_assist_thoughts_unparsed_r7b,
|
|
490
600
|
common_chat_parse(
|
|
491
601
|
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
|
|
492
602
|
"Hello, world!\nWhat's up?<|END_RESPONSE|>",
|
|
493
603
|
COMMON_CHAT_FORMAT_COMMAND_R7B));
|
|
494
604
|
|
|
495
|
-
assert_msg_equals(
|
|
605
|
+
assert_msg_equals(message_assist_thoughts,
|
|
496
606
|
common_chat_parse(
|
|
497
607
|
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
|
|
498
608
|
"<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
|
|
499
609
|
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING));
|
|
500
610
|
|
|
501
|
-
|
|
611
|
+
test_templates(tmpls.get(), end_tokens, message_assist_call_idx, tools,
|
|
502
612
|
"<|START_THINKING|><|END_THINKING|>"
|
|
503
613
|
"<|START_ACTION|>[\n"
|
|
504
614
|
" {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
|
|
505
615
|
"]<|END_ACTION|>");
|
|
506
|
-
|
|
507
|
-
"<|START_THINKING|>I'm thinking<|END_THINKING|>"
|
|
508
|
-
"<|START_ACTION|>[\n"
|
|
509
|
-
" {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
|
|
510
|
-
"]<|END_ACTION|>",
|
|
511
|
-
/* expect_grammar_triggered= */ true,
|
|
512
|
-
/* test_grammar_if_triggered= */ true,
|
|
513
|
-
/* think= */ true);
|
|
514
|
-
test_template(tmpl, end_tokens, message_assist, tools,
|
|
616
|
+
test_templates(tmpls.get(), end_tokens, message_assist, tools,
|
|
515
617
|
"<|START_RESPONSE|>Hello, world!\n"
|
|
516
618
|
"What's up?<|END_RESPONSE|>",
|
|
517
619
|
/* expect_grammar_triggered= */ false);
|
|
518
620
|
}
|
|
519
621
|
{
|
|
520
|
-
|
|
622
|
+
auto tmpls = read_templates("models/templates/google-gemma-2-2b-it.jinja");
|
|
521
623
|
std::vector<std::string> end_tokens{ "<end_of_turn>" };
|
|
522
624
|
|
|
523
|
-
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
|
524
|
-
assert_equals(COMMON_CHAT_FORMAT_GENERIC,
|
|
625
|
+
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
|
|
626
|
+
assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
|
|
525
627
|
assert_equals(COMMON_CHAT_FORMAT_GENERIC,
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
"<s>", "</s>"),
|
|
628
|
+
common_chat_templates_apply(
|
|
629
|
+
read_templates("models/templates/microsoft-Phi-3.5-mini-instruct.jinja").get(),
|
|
529
630
|
inputs_tools)
|
|
530
631
|
.format);
|
|
531
632
|
|
|
532
633
|
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
|
533
634
|
|
|
534
|
-
assert_msg_equals(
|
|
635
|
+
assert_msg_equals(message_assist,
|
|
535
636
|
common_chat_parse("{\n"
|
|
536
637
|
" \"response\": \"Hello, world!\\nWhat's up?\"\n"
|
|
537
638
|
"}",
|
|
538
|
-
|
|
539
|
-
|
|
639
|
+
common_chat_templates_apply(tmpls.get(), inputs_tools).format));
|
|
640
|
+
test_templates(tmpls.get(), end_tokens, message_assist_call_id, tools,
|
|
540
641
|
"{\n"
|
|
541
642
|
" \"tool_calls\": [\n"
|
|
542
643
|
" {\n"
|
|
@@ -550,143 +651,233 @@ static void test_template_output_parsers() {
|
|
|
550
651
|
"}");
|
|
551
652
|
}
|
|
552
653
|
{
|
|
553
|
-
|
|
554
|
-
"</s>");
|
|
654
|
+
auto tmpls = read_templates("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja");
|
|
555
655
|
std::vector<std::string> end_tokens{ "</s>" };
|
|
556
656
|
|
|
557
|
-
assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO,
|
|
657
|
+
assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
|
|
558
658
|
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
659
|
+
test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
|
660
|
+
test_templates(
|
|
661
|
+
tmpls.get(), end_tokens, message_assist_call_id, tools,
|
|
562
662
|
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
|
|
563
663
|
}
|
|
564
664
|
{
|
|
565
|
-
|
|
566
|
-
read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
|
665
|
+
auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja");
|
|
567
666
|
std::vector<std::string> end_tokens{ "<|im_end|>" };
|
|
568
667
|
|
|
569
|
-
assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO,
|
|
668
|
+
assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
|
|
570
669
|
assert_equals(
|
|
571
670
|
COMMON_CHAT_FORMAT_HERMES_2_PRO,
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
"<s>", "</s>"),
|
|
671
|
+
common_chat_templates_apply(
|
|
672
|
+
read_templates("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja").get(),
|
|
575
673
|
inputs_tools)
|
|
576
674
|
.format);
|
|
577
675
|
assert_equals(
|
|
578
676
|
COMMON_CHAT_FORMAT_HERMES_2_PRO,
|
|
579
|
-
|
|
580
|
-
|
|
677
|
+
common_chat_templates_apply(
|
|
678
|
+
read_templates("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja").get(),
|
|
581
679
|
inputs_tools)
|
|
582
680
|
.format);
|
|
583
681
|
|
|
584
|
-
|
|
585
|
-
|
|
682
|
+
// Test parsing
|
|
683
|
+
assert_msg_equals(message_assist_call, common_chat_parse(
|
|
684
|
+
"<tool_call>\n"
|
|
685
|
+
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
|
686
|
+
"</tool_call>",
|
|
687
|
+
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
|
688
|
+
assert_msg_equals(message_assist_call, common_chat_parse(
|
|
689
|
+
"<function=special_function>{\"arg1\": 1}</function>",
|
|
690
|
+
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
|
691
|
+
assert_msg_equals(message_assist_call, common_chat_parse(
|
|
692
|
+
"<function name=\"special_function\">\n"
|
|
693
|
+
"{\"arg1\": 1}\n"
|
|
694
|
+
"</function>",
|
|
695
|
+
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
|
696
|
+
assert_msg_equals(message_assist_call, common_chat_parse(
|
|
697
|
+
"<tool>\n"
|
|
698
|
+
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
|
699
|
+
"</tool>",
|
|
700
|
+
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
|
701
|
+
assert_msg_equals(message_assist_call, common_chat_parse(
|
|
702
|
+
"<tools>\n"
|
|
703
|
+
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
|
704
|
+
"</tools>",
|
|
705
|
+
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
|
706
|
+
assert_msg_equals(message_assist_call, common_chat_parse(
|
|
707
|
+
"<response>\n"
|
|
708
|
+
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
|
709
|
+
"</response>",
|
|
710
|
+
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
|
711
|
+
assert_msg_equals(message_assist_call, common_chat_parse(
|
|
712
|
+
"```xml\n"
|
|
713
|
+
"<response>\n"
|
|
714
|
+
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
|
715
|
+
"</response>\n"
|
|
716
|
+
"```",
|
|
717
|
+
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
|
718
|
+
assert_msg_equals(message_assist_call, common_chat_parse(
|
|
719
|
+
"```xml\n"
|
|
720
|
+
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
|
721
|
+
"```",
|
|
722
|
+
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
|
723
|
+
assert_msg_equals(message_assist_call, common_chat_parse(
|
|
724
|
+
"```\n"
|
|
725
|
+
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
|
726
|
+
"```",
|
|
727
|
+
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
|
728
|
+
assert_msg_equals(message_assist_call, common_chat_parse(
|
|
729
|
+
"```\n"
|
|
730
|
+
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
|
731
|
+
"```",
|
|
732
|
+
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
|
733
|
+
assert_msg_equals(message_assist_call, common_chat_parse(
|
|
734
|
+
"```json\n"
|
|
735
|
+
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
|
736
|
+
"```",
|
|
737
|
+
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
|
738
|
+
assert_msg_equals(message_assist_call, common_chat_parse(
|
|
739
|
+
"```json\n"
|
|
740
|
+
"\n"
|
|
741
|
+
" <function_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}} \n"
|
|
742
|
+
" </function_call> \n"
|
|
743
|
+
"``` ",
|
|
744
|
+
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
|
745
|
+
assert_msg_equals(message_assist_call, common_chat_parse(
|
|
746
|
+
"<json>\n"
|
|
747
|
+
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
|
748
|
+
"</json>",
|
|
749
|
+
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
|
750
|
+
assert_msg_equals(message_assist_call, common_chat_parse(
|
|
751
|
+
"<xml>\n"
|
|
752
|
+
" {\n"
|
|
753
|
+
" \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}\n"
|
|
754
|
+
" }\n"
|
|
755
|
+
"</xml>",
|
|
756
|
+
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
|
757
|
+
assert_msg_equals(message_assist_call, common_chat_parse(
|
|
758
|
+
"<JSON>\n"
|
|
759
|
+
" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
|
760
|
+
"</JSON>",
|
|
761
|
+
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
|
762
|
+
assert_msg_equals(message_assist_call, common_chat_parse(
|
|
763
|
+
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
|
|
764
|
+
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
|
765
|
+
assert_msg_equals(message_assist_call, common_chat_parse(
|
|
766
|
+
"{\n \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
|
|
767
|
+
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
|
768
|
+
|
|
769
|
+
assert_msg_equals(message_assist_thoughts_unparsed_think,
|
|
770
|
+
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
|
|
771
|
+
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
|
772
|
+
assert_msg_equals(message_assist_thoughts_unparsed_think,
|
|
773
|
+
common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
|
|
774
|
+
COMMON_CHAT_FORMAT_HERMES_2_PRO));
|
|
775
|
+
assert_msg_equals(message_assist_thoughts,
|
|
776
|
+
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
|
|
777
|
+
COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING));
|
|
778
|
+
assert_msg_equals(message_assist_thoughts,
|
|
779
|
+
common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
|
|
780
|
+
COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING));
|
|
781
|
+
|
|
782
|
+
test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
|
783
|
+
test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
|
|
586
784
|
"<tool_call>\n"
|
|
587
785
|
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
|
|
588
786
|
"</tool_call>");
|
|
589
|
-
|
|
787
|
+
test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools,
|
|
590
788
|
"<tool_call>\n"
|
|
591
789
|
"{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
|
|
592
790
|
"</tool_call>");
|
|
593
791
|
}
|
|
594
792
|
{
|
|
595
|
-
|
|
596
|
-
"</s>");
|
|
793
|
+
auto tmpls = read_templates("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja");
|
|
597
794
|
std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
|
|
598
795
|
|
|
599
|
-
assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X,
|
|
796
|
+
assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
|
|
600
797
|
assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
|
|
601
|
-
|
|
798
|
+
common_chat_templates_apply(tmpls.get(), inputs_tools_builtin).format);
|
|
602
799
|
assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
"<s>", "</s>"),
|
|
800
|
+
common_chat_templates_apply(
|
|
801
|
+
read_templates("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja").get(),
|
|
606
802
|
inputs_tools_builtin)
|
|
607
803
|
.format);
|
|
608
804
|
|
|
609
|
-
//
|
|
610
|
-
|
|
805
|
+
// test_templates(tmpls.get(), end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false);
|
|
806
|
+
test_templates(tmpls.get(), end_tokens, message_assist_call_code_interpreter, llama_3_1_tools,
|
|
611
807
|
"<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
|
|
612
|
-
|
|
808
|
+
test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools,
|
|
613
809
|
"<|python_tag|>python.call(code=\"print('hey')\")");
|
|
614
|
-
|
|
810
|
+
test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
|
|
615
811
|
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
|
616
812
|
}
|
|
617
813
|
{
|
|
618
|
-
|
|
619
|
-
"</s>");
|
|
814
|
+
auto tmpls = read_templates("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja");
|
|
620
815
|
std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
|
|
621
816
|
|
|
622
|
-
assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X,
|
|
817
|
+
assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
|
|
623
818
|
|
|
624
|
-
|
|
625
|
-
|
|
819
|
+
test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
|
820
|
+
test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
|
|
626
821
|
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
|
627
822
|
}
|
|
628
823
|
{
|
|
629
|
-
|
|
630
|
-
"</s>");
|
|
824
|
+
auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.1.jinja");
|
|
631
825
|
std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
|
|
632
826
|
|
|
633
827
|
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
|
634
|
-
|
|
828
|
+
common_chat_templates_apply(tmpls.get(), inputs_tools).format);
|
|
635
829
|
|
|
636
|
-
|
|
637
|
-
|
|
830
|
+
test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
|
831
|
+
test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
|
|
638
832
|
"<function=special_function>{\"arg1\": 1}</function>");
|
|
639
833
|
}
|
|
640
834
|
{
|
|
641
|
-
|
|
642
|
-
"</s>");
|
|
835
|
+
auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.2.jinja");
|
|
643
836
|
std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
|
|
644
837
|
|
|
645
|
-
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
|
|
646
|
-
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
|
|
838
|
+
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
|
|
839
|
+
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
|
|
647
840
|
|
|
648
|
-
|
|
841
|
+
test_templates(tmpls.get(), end_tokens, message_assist, {},
|
|
649
842
|
"all\n"
|
|
650
843
|
"Hello, world!\n"
|
|
651
844
|
"What's up?",
|
|
652
845
|
/* expect_grammar_triggered= */ false);
|
|
653
|
-
|
|
846
|
+
test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
|
|
654
847
|
"special_function\n"
|
|
655
848
|
"{\"arg1\": 1}");
|
|
656
849
|
}
|
|
657
850
|
{
|
|
658
|
-
|
|
659
|
-
"</s>");
|
|
851
|
+
auto tmpls = read_templates("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja");
|
|
660
852
|
std::vector<std::string> end_tokens{ "<|eot_id|>" };
|
|
661
853
|
|
|
662
|
-
assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
|
|
854
|
+
assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
|
|
663
855
|
|
|
664
|
-
|
|
665
|
-
|
|
856
|
+
test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
|
857
|
+
test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
|
|
666
858
|
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
|
|
667
859
|
}
|
|
668
860
|
{
|
|
669
861
|
// Original DeepSeek R1 template. Leaves <|tool▁calls▁begin|> and others unclosed. Our logic fixes the prompt.
|
|
670
|
-
|
|
671
|
-
"<s>", "</s>");
|
|
862
|
+
auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja");
|
|
672
863
|
std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };
|
|
673
864
|
|
|
674
|
-
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1,
|
|
675
|
-
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
|
|
865
|
+
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
|
|
866
|
+
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
|
|
676
867
|
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
assert_msg_equals(
|
|
868
|
+
test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
|
869
|
+
test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
|
870
|
+
assert_msg_equals(message_assist_thoughts_unparsed_think,
|
|
680
871
|
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
|
|
681
872
|
COMMON_CHAT_FORMAT_DEEPSEEK_R1));
|
|
682
|
-
assert_msg_equals(
|
|
873
|
+
assert_msg_equals(message_assist_thoughts,
|
|
683
874
|
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
|
|
684
875
|
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
|
|
685
|
-
assert_msg_equals(
|
|
876
|
+
assert_msg_equals(message_assist_thoughts,
|
|
686
877
|
// Latest template update (ast of 20250209) adds a trailing <think>\n if add_generation_prompt is true.
|
|
687
878
|
common_chat_parse("I'm thinking</think>Hello, world!\nWhat's up?",
|
|
688
879
|
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
|
|
689
|
-
//
|
|
880
|
+
// test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
|
|
690
881
|
// "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
|
691
882
|
// "```json\n"
|
|
692
883
|
// "{\"arg1\": 1}\n"
|
|
@@ -697,23 +888,22 @@ static void test_template_output_parsers() {
|
|
|
697
888
|
}
|
|
698
889
|
{
|
|
699
890
|
// Replacement DeepSeek R1 template. Makes the Distill Qwen 7B/32B models happy to call tools and all.
|
|
700
|
-
|
|
701
|
-
"<s>", "</s>");
|
|
891
|
+
auto tmpls = read_templates("models/templates/llama-cpp-deepseek-r1.jinja");
|
|
702
892
|
std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };
|
|
703
893
|
|
|
704
|
-
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1,
|
|
705
|
-
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
|
|
894
|
+
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
|
|
895
|
+
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
|
|
706
896
|
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
assert_msg_equals(
|
|
897
|
+
test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
|
898
|
+
test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
|
899
|
+
assert_msg_equals(message_assist_thoughts_unparsed_think,
|
|
710
900
|
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
|
|
711
901
|
COMMON_CHAT_FORMAT_DEEPSEEK_R1));
|
|
712
|
-
assert_msg_equals(
|
|
902
|
+
assert_msg_equals(message_assist_thoughts,
|
|
713
903
|
common_chat_parse("<think>I'm thinking</think>Hello, world!\nWhat's up?",
|
|
714
904
|
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
|
|
715
905
|
|
|
716
|
-
assert_msg_equals(
|
|
906
|
+
assert_msg_equals(message_assist_call_thoughts_unparsed,
|
|
717
907
|
common_chat_parse(
|
|
718
908
|
"<think>I'm\nthinking</think>\n\n"
|
|
719
909
|
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
|
@@ -721,7 +911,7 @@ static void test_template_output_parsers() {
|
|
|
721
911
|
"{\"arg1\": 1}\n"
|
|
722
912
|
"```<|tool▁call▁end|><|tool▁calls▁end|>",
|
|
723
913
|
COMMON_CHAT_FORMAT_DEEPSEEK_R1));
|
|
724
|
-
assert_msg_equals(
|
|
914
|
+
assert_msg_equals(message_assist_call_thoughts,
|
|
725
915
|
common_chat_parse(
|
|
726
916
|
"<think>I'm\nthinking</think>\n\n"
|
|
727
917
|
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
|
@@ -729,7 +919,7 @@ static void test_template_output_parsers() {
|
|
|
729
919
|
"{\"arg1\": 1}\n"
|
|
730
920
|
"```<|tool▁call▁end|><|tool▁calls▁end|>",
|
|
731
921
|
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
|
|
732
|
-
|
|
922
|
+
test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
|
|
733
923
|
"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
|
|
734
924
|
"```json\n"
|
|
735
925
|
"{\"arg1\": 1}\n"
|
|
@@ -738,38 +928,46 @@ static void test_template_output_parsers() {
|
|
|
738
928
|
}
|
|
739
929
|
|
|
740
930
|
int main(int argc, char ** argv) {
|
|
931
|
+
// try {
|
|
741
932
|
#ifndef _WIN32
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
933
|
+
if (argc > 1) {
|
|
934
|
+
common_chat_templates_inputs inputs;
|
|
935
|
+
common_chat_msg msg;
|
|
936
|
+
msg.role = "user";
|
|
937
|
+
msg.content = "Hey";
|
|
938
|
+
inputs.messages = {msg};
|
|
939
|
+
inputs.tools = { special_function_tool };
|
|
940
|
+
|
|
941
|
+
std::cout << "| Template | Format |\n";
|
|
942
|
+
std::cout << "|----------|--------|\n";
|
|
943
|
+
|
|
944
|
+
for (int i = 1; i < argc; i++) {
|
|
945
|
+
try {
|
|
946
|
+
std::string path = argv[i];
|
|
947
|
+
if (path.rfind(".jinja") != path.size() - 6) {
|
|
948
|
+
std::cerr << "Skipping non-jinja file: " << path << '\n';
|
|
949
|
+
continue;
|
|
950
|
+
}
|
|
951
|
+
auto tmpls = read_templates(path);
|
|
952
|
+
auto parts = string_split(path, "/");
|
|
953
|
+
auto name = parts[parts.size() - 1];
|
|
954
|
+
auto format = common_chat_format_name(common_chat_templates_apply(tmpls.get(), inputs).format);
|
|
955
|
+
std::cout << "| " << name << " | " << format << " |\n";
|
|
956
|
+
} catch (const std::exception & e) {
|
|
957
|
+
std::cerr << "Failed to process " << argv[i] << ": " << e.what() << '\n';
|
|
758
958
|
}
|
|
759
|
-
common_chat_template tmpl(read_file(path), "", "");
|
|
760
|
-
auto parts = string_split(path, "/");
|
|
761
|
-
auto name = parts[parts.size() - 1];
|
|
762
|
-
auto format = common_chat_format_name(common_chat_params_init(tmpl, inputs).format);
|
|
763
|
-
std::cout << "| " << name << " | " << format << " |\n";
|
|
764
|
-
} catch (const std::exception & e) {
|
|
765
|
-
std::cerr << "Failed to process " << argv[i] << ": " << e.what() << std::endl;
|
|
766
959
|
}
|
|
767
|
-
}
|
|
768
|
-
} else
|
|
960
|
+
} else
|
|
769
961
|
#endif
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
962
|
+
{
|
|
963
|
+
test_msgs_oaicompat_json_conversion();
|
|
964
|
+
test_tools_oaicompat_json_conversion();
|
|
965
|
+
test_template_output_parsers();
|
|
966
|
+
std::cout << "\n[chat] All tests passed!" << '\n';
|
|
967
|
+
}
|
|
968
|
+
return 0;
|
|
969
|
+
// } catch (const std::exception & e) {
|
|
970
|
+
// std::cerr << "Error: " << e.what() << '\n';
|
|
971
|
+
// return 1;
|
|
972
|
+
// }
|
|
775
973
|
}
|