@fugood/llama.node 0.3.8 → 0.3.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.js +2 -2
- package/lib/binding.ts +52 -8
- package/lib/index.ts +3 -1
- package/package.json +8 -1
- package/src/LlamaCompletionWorker.cpp +33 -6
- package/src/LlamaCompletionWorker.h +3 -1
- package/src/LlamaContext.cpp +387 -28
- package/src/LlamaContext.h +5 -0
- package/src/common.hpp +19 -2
- package/src/llama.cpp/.github/workflows/build.yml +289 -107
- package/src/llama.cpp/.github/workflows/close-issue.yml +1 -1
- package/src/llama.cpp/.github/workflows/docker.yml +2 -1
- package/src/llama.cpp/.github/workflows/server.yml +25 -2
- package/src/llama.cpp/CMakeLists.txt +10 -19
- package/src/llama.cpp/cmake/build-info.cmake +1 -1
- package/src/llama.cpp/common/CMakeLists.txt +32 -0
- package/src/llama.cpp/common/arg.cpp +66 -16
- package/src/llama.cpp/common/chat-template.hpp +515 -0
- package/src/llama.cpp/common/chat.cpp +966 -0
- package/src/llama.cpp/common/chat.hpp +52 -0
- package/src/llama.cpp/common/common.cpp +159 -36
- package/src/llama.cpp/common/common.h +56 -14
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +46 -66
- package/src/llama.cpp/common/json-schema-to-grammar.h +15 -1
- package/src/llama.cpp/common/llguidance.cpp +270 -0
- package/src/llama.cpp/common/log.cpp +1 -10
- package/src/llama.cpp/common/log.h +10 -0
- package/src/llama.cpp/common/minja.hpp +2868 -0
- package/src/llama.cpp/common/sampling.cpp +22 -1
- package/src/llama.cpp/common/sampling.h +3 -0
- package/src/llama.cpp/docs/build.md +54 -9
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +12 -2
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +1 -1
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +59 -0
- package/src/llama.cpp/examples/llava/clip.cpp +133 -14
- package/src/llama.cpp/examples/llava/clip.h +2 -0
- package/src/llama.cpp/examples/llava/llava.cpp +22 -8
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +9 -1
- package/src/llama.cpp/examples/main/main.cpp +26 -25
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +136 -137
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +18 -4
- package/src/llama.cpp/examples/run/run.cpp +224 -69
- package/src/llama.cpp/examples/server/server.cpp +252 -81
- package/src/llama.cpp/examples/server/utils.hpp +73 -21
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +6 -4
- package/src/llama.cpp/examples/simple-cmake-pkg/CMakeLists.txt +11 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +78 -1
- package/src/llama.cpp/ggml/include/ggml.h +1 -1
- package/src/llama.cpp/ggml/src/CMakeLists.txt +21 -4
- package/src/llama.cpp/ggml/src/ggml-alloc.c +1 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +91 -78
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +7 -7
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +46 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +16 -1
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +1 -1
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +28 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +5 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +33 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.hpp +1 -5
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +323 -121
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +13 -3
- package/src/llama.cpp/ggml/src/ggml.c +23 -13
- package/src/llama.cpp/include/llama.h +14 -1
- package/src/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.out +46 -0
- package/src/llama.cpp/src/CMakeLists.txt +1 -1
- package/src/llama.cpp/src/llama-arch.cpp +7 -2
- package/src/llama.cpp/src/llama-arch.h +3 -1
- package/src/llama.cpp/src/llama-chat.cpp +11 -2
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-grammar.cpp +86 -6
- package/src/llama.cpp/src/llama-grammar.h +22 -1
- package/src/llama.cpp/src/llama-mmap.cpp +1 -0
- package/src/llama.cpp/src/llama-model-loader.cpp +1 -1
- package/src/llama.cpp/src/llama-model.cpp +76 -6
- package/src/llama.cpp/src/llama-sampling.cpp +47 -4
- package/src/llama.cpp/src/llama-vocab.cpp +10 -4
- package/src/llama.cpp/src/llama.cpp +181 -123
- package/src/llama.cpp/tests/CMakeLists.txt +4 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +158 -57
- package/src/llama.cpp/tests/test-chat-template.cpp +154 -31
- package/src/llama.cpp/tests/test-chat.cpp +607 -0
- package/src/llama.cpp/tests/test-grammar-integration.cpp +2 -2
- package/src/llama.cpp/tests/test-grammar-llguidance.cpp +1140 -0
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +0 -32
package/src/llama.cpp/tests/test-chat.cpp
@@ -0,0 +1,607 @@
+// Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
+//
+// Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
+// e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
+//
+//   cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
+//
+#include <fstream>
+#include <iostream>
+#include <json.hpp>
+#include <string>
+
+#include "chat-template.hpp"
+#include "chat.hpp"
+#include "llama-grammar.h"
+#include "unicode.h"
+
+using json = nlohmann::ordered_json;
+
+static common_chat_msg msg_from_json(const json & message) {
+    common_chat_msg ret;
+    ret.role = "assistant";
+    if (message.contains("content") && !message.at("content").is_null()) {
+        ret.content = message.at("content");
+    }
+    if (message.contains("tool_plan")) {
+        ret.tool_plan = message.at("tool_plan");
+    }
+    auto has_tool_calls = message.contains("tool_calls");
+    if (has_tool_calls) {
+        for (const auto & tc : message.at("tool_calls")) {
+            const auto & arguments = tc.at("function").at("arguments");
+            ret.tool_calls.push_back({
+                tc.at("function").at("name").get<std::string>(),
+                arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
+                tc.contains("id") ? tc.at("id").get<std::string>() : "",
+            });
+        }
+    }
+    return ret;
+}
+
+template <class T> static void assert_equals(const T & expected, const T & actual) {
+    if (expected != actual) {
+        std::cerr << "Expected: " << expected << std::endl;
+        std::cerr << "Actual: " << actual << std::endl;
+        std::cerr << std::flush;
+        throw std::runtime_error("Test failed");
+    }
+}
+
+static std::string read_file(const std::string & path) {
+    std::cerr << "# Reading: " << path << std::endl << std::flush;
+    std::ifstream fs(path, std::ios_base::binary);
+    if (!fs.is_open()) {
+        fs = std::ifstream("../" + path, std::ios_base::binary);
+        if (!fs.is_open()) {
+            throw std::runtime_error("Failed to open file: " + path);
+        }
+    }
+    fs.seekg(0, std::ios_base::end);
+    auto size = fs.tellg();
+    fs.seekg(0);
+    std::string out;
+    out.resize(static_cast<size_t>(size));
+    fs.read(&out[0], static_cast<std::streamsize>(size));
+    return out;
+}
+
+static std::unique_ptr<llama_grammar> build_grammar(const std::string & grammar_str) {
+    return std::unique_ptr<llama_grammar>(
+        llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0));
+}
+
+// TODO: extract to common helper (copied from test-grammar-integration.cpp)
+static bool match_string(const std::string & input, llama_grammar * grammar) {
+    const auto cpts = unicode_cpts_from_utf8(input);
+
+    auto & stacks_cur = llama_grammar_get_stacks(grammar);
+
+    for (const auto & cpt : cpts) {
+        llama_grammar_accept(grammar, cpt);
+
+        if (stacks_cur.empty()) {
+            // no stacks means that the grammar failed to match at this point
+            return false;
+        }
+    }
+
+    for (const auto & stack : stacks_cur) {
+        if (stack.empty()) {
+            // An empty stack means that the grammar has been completed
+            return true;
+        }
+    }
+
+    return false;
+}
+
+// Dumps `{"a": 1}` as `"{\"a\": 1}"`, unlike nlohmann::json::dump which would dump it as `"{\"a\":1}"`.
+static std::string dump(const json & j) {
+    return minja::Value(j).dump(-1, /* to_json= */ true);
+}
+
+static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
+    assert_equals(expected.role, actual.role);
+    assert_equals(expected.content, actual.content);
+    assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
+    for (size_t i = 0; i < expected.tool_calls.size(); i++) {
+        const auto & expected_tool_call = expected.tool_calls[i];
+        const auto & actual_tool_call   = actual.tool_calls[i];
+        assert_equals(expected_tool_call.name, actual_tool_call.name);
+        assert_equals(dump(json::parse(expected_tool_call.arguments)), dump(json::parse(actual_tool_call.arguments)));
+        assert_equals(expected_tool_call.id, actual_tool_call.id);
+    }
+}
+
+const auto special_function_tool = json::parse(R"({
+  "type": "function",
+  "function": {
+    "name": "special_function",
+    "description": "I'm special",
+    "parameters": {
+      "type": "object",
+      "properties": {
+        "arg1": {
+          "type": "integer",
+          "description": "The arg."
+        }
+      },
+      "required": ["arg1"]
+    }
+  }
+})");
+const auto python_tool = json::parse(R"({
+  "type": "function",
+  "function": {
+    "name": "python",
+    "description": "an ipython interpreter",
+    "parameters": {
+      "type": "object",
+      "properties": {
+        "code": {
+          "type": "string",
+          "description": "Python code to execute."
+        }
+      },
+      "required": ["code"]
+    }
+  }
+})");
+const auto code_interpreter_tool = json::parse(R"({
+  "type": "function",
+  "function": {
+    "name": "code_interpreter",
+    "description": "an ipython interpreter",
+    "parameters": {
+      "type": "object",
+      "properties": {
+        "code": {
+          "type": "string",
+          "description": "Python code to execute."
+        }
+      },
+      "required": ["code"]
+    }
+  }
+})");
+const json tools = { special_function_tool, python_tool };
+const json llama_3_1_tools = { special_function_tool, code_interpreter_tool };
+
+struct delta_data {
+    std::string delta;
+    common_chat_params params;
+};
+
+static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
+                             const json & user_message, const json & delta_message, const json & tools,
+                             const json & tool_choice) {
+    common_chat_inputs inputs;
+    inputs.parallel_tool_calls = true;
+    inputs.messages = json::array();
+    inputs.messages.push_back(user_message);
+    inputs.tools = tools;
+    inputs.tool_choice = tool_choice;
+    auto params_prefix = common_chat_params_init(tmpl, inputs);
+
+    inputs.messages.push_back(delta_message);
+    inputs.add_generation_prompt = false;
+    auto params_full = common_chat_params_init(tmpl, inputs);
+
+    std::string prefix = params_prefix.prompt;
+    std::string full = params_full.prompt;
+
+    // Check full starts with prefix
+    if (full.find(prefix) != 0) {
+        fprintf(stderr, "Full:\n%s\n\nPrefix:\n%s\n\n", full.c_str(), prefix.c_str());
+        throw std::runtime_error("Full message does not start with prefix");
+    }
+
+    if (full == prefix) {
+        throw std::runtime_error("Full message is the same as the prefix");
+    }
+
+    auto delta = full.substr(prefix.size());
+
+    // Strip end tokens
+    for (const auto & end_token : end_tokens) {
+        // rfind to find the last occurrence
+        auto pos = delta.rfind(end_token);
+        if (pos != std::string::npos) {
+            delta = delta.substr(0, pos);
+            break;
+        }
+    }
+    return { delta, params_full };
+}
+
+/*
+  Applies the template to 1 user message w/ add_generation_prompt=true, then w/ the test message w/ add_generation_prompt=false,
+  gets the diff, removes any end tokens and parses the result w/ the grammar, checking that
+  the parsed message is the same as the test_message
+*/
+static void test_template(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens,
+                          const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
+                          bool expect_grammar_triggered = true) {
+    common_chat_msg expected_msg = msg_from_json(test_message);
+
+    auto user_message = json{
+        { "role", "user" },
+        { "content", "Hello, world!" }
+    };
+
+    for (const auto & tool_choice : json({ "auto", "required" })) {
+        auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice);
+        if (!expected_delta.empty()) {
+            assert_equals(expected_delta, data.delta);
+        }
+
+        if (expect_grammar_triggered) {
+            const auto msg = common_chat_parse(data.delta, data.params.format);
+            assert_msg_equals(expected_msg, msg);
+        }
+
+        if (!expected_msg.tool_calls.empty()) {
+            GGML_ASSERT(!data.params.grammar.empty());
+        }
+        if (!data.params.grammar.empty()) {
+            auto grammar = build_grammar(data.params.grammar);
+            if (!grammar) {
+                throw std::runtime_error("Failed to build grammar");
+            }
+            auto earliest_trigger_pos = std::string::npos;
+            auto constrained = data.delta;
+            for (const auto & trigger : data.params.grammar_triggers) {
+                auto pos = constrained.find(trigger.word);
+                if (pos == std::string::npos) {
+                    continue;
+                }
+                if (pos > 0 && trigger.at_start) {
+                    fprintf(stderr, "Trigger %s not at start of message, skipping:\n\n%s\n\n", trigger.word.c_str(), constrained.c_str());
+                    continue;
+                }
+                if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
+                    earliest_trigger_pos = pos;
+                }
+            }
+            auto grammar_triggered = false;
+            if (earliest_trigger_pos != std::string::npos) {
+                constrained = constrained.substr(earliest_trigger_pos);
+                grammar_triggered = true;
+            }
+            if (data.params.grammar_lazy) {
+                assert_equals(expect_grammar_triggered, grammar_triggered);
+            }
+
+            if (grammar_triggered && !match_string(constrained, grammar.get())) {
+                throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
+                                         "\n\nGrammar: " + data.params.grammar);
+            }
+        }
+    }
+}
+
+static void test_template_output_parsers() {
+    json text_message {
+        { "role", "assistant" },
+        { "content", "Hello, world!\nWhat's up?" },
+    };
+    json tool_calls = json::array({{
+        { "type", "function" },
+        { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
+    }});
+
+    json tool_call_message {
+        { "role", "assistant"},
+        { "content", {}},
+        { "tool_calls", {
+            {
+                { "type", "function" },
+                { "function", {
+                    { "name", "special_function" },
+                    { "arguments", "{\"arg1\": 1}" },
+                }},
+            },
+        }},
+    };
+    json tool_call_message_with_id {
+        { "role", "assistant"},
+        { "content", {}},
+        { "tool_calls", {
+            {
+                { "type", "function" },
+                { "function", {
+                    { "name", "special_function" },
+                    { "arguments", "{\"arg1\": 1}" },
+                }},
+                {"id", "123456789"},
+            },
+        }},
+        { "role", "assistant" },
+        { "content", {} },
+        { "tool_calls", tool_calls }
+    };
+    json tool_call_plan_message_with_idx {
+        { "role", "assistant"},
+        { "content", {}},
+        { "tool_plan", "I'm not so sure"},
+        { "tool_calls", {
+            {
+                { "type", "function" },
+                { "function", {
+                    { "name", "special_function" },
+                    { "arguments", "{\"arg1\": 1}" },
+                }},
+                // Index of the tool call in the tool_calls array
+                {"id", "0"},
+            },
+        }},
+        { "role", "assistant" },
+        { "content", {} },
+        { "tool_calls", tool_calls }
+    };
+
+    auto python_tool_call_message = json{
+        { "role", "assistant" },
+        { "content", {} },
+        { "tool_calls", json{ {
+            { "type", "function" },
+            { "function",
+              {
+                  { "name", "python" },
+                  { "arguments",
+                    {
+                        { "code", "print('hey')" },
+                    } },
+              } },
+        } } }
+    };
+    auto code_interpreter_tool_call_message = json{
+        { "role", "assistant" },
+        { "content", {} },
+        { "tool_calls", json{ {
+            { "type", "function" },
+            { "function",
+              {
+                  { "name", "code_interpreter" },
+                  { "arguments",
+                    {
+                        { "code", "print('hey')" },
+                    } },
+              } },
+        } } }
+    };
+
+    common_chat_inputs inputs_no_tools;
+    inputs_no_tools.messages = {
+        { { "role", "user" }, { "content", "Hey\nThere" } }
+    };
+
+    common_chat_inputs inputs_tools = inputs_no_tools;
+    inputs_tools.tools = json::array();
+    inputs_tools.tools.push_back(special_function_tool);
+
+    common_chat_inputs inputs_tools_builtin = inputs_no_tools;
+    inputs_tools_builtin.tools = json::array();
+    inputs_tools_builtin.tools.push_back(python_tool);
+
+    {
+        // Not supported yet
+        const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"), "<s>", "</s>");
+        assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
+    }
+    {
+        const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "<s>", "</s>");
+        std::vector<std::string> end_tokens{ "<|END_OF_TURN_TOKEN|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format);
+
+        test_template(tmpl, end_tokens, tool_call_plan_message_with_idx, tools,
+                      "<|START_THINKING|>I'm not so sure<|END_THINKING|>"
+                      "<|START_ACTION|>[\n"
+                      "    {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
+                      "]<|END_ACTION|>");
+        test_template(tmpl, end_tokens, text_message, tools,
+                      "<|START_RESPONSE|>Hello, world!\n"
+                      "What's up?<|END_RESPONSE|>",
+                      /* expect_grammar_triggered= */ false);
+    }
+    {
+        const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
+        std::vector<std::string> end_tokens{ "<end_of_turn>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_GENERIC,
+                      common_chat_params_init(
+                          common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"),
                                               "<s>", "</s>"),
+                          inputs_tools)
+                          .format);
+
+        // Generic tool calls doesn't generate / parse content-only messages symmetrically.
+
+        assert_msg_equals(msg_from_json(text_message),
+                          common_chat_parse("{\n"
+                                            "  \"response\": \"Hello, world!\\nWhat's up?\"\n"
+                                            "}",
+                                            common_chat_params_init(tmpl, inputs_tools).format));
+        test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
+                      "{\n"
+                      "  \"tool_calls\": [\n"
+                      "    {\n"
+                      "      \"name\": \"special_function\",\n"
+                      "      \"arguments\": {\n"
+                      "        \"arg1\": 1\n"
+                      "      },\n"
+                      "      \"id\": \"123456789\"\n"
+                      "    }\n"
+                      "  ]\n"
+                      "}");
+    }
+    {
+        const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>",
+                                        "</s>");
+        std::vector<std::string> end_tokens{ "</s>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);
+
+        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_template(
+            tmpl, end_tokens, tool_call_message_with_id, tools,
+            "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
+    }
+    {
+        const common_chat_template tmpl(
+            read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
+        std::vector<std::string> end_tokens{ "<|im_end|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(
+            COMMON_CHAT_FORMAT_HERMES_2_PRO,
+            common_chat_params_init(
+                common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"),
+                                     "<s>", "</s>"),
+                inputs_tools)
+                .format);
+        assert_equals(
+            COMMON_CHAT_FORMAT_HERMES_2_PRO,
+            common_chat_params_init(
+                common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"),
+                inputs_tools)
+                .format);
+
+        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_template(tmpl, end_tokens, tool_call_message, tools,
+                      "<tool_call>\n"
+                      "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                      "</tool_call>");
+        test_template(tmpl, end_tokens, python_tool_call_message, tools,
+                      "<tool_call>\n"
+                      "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
+                      "</tool_call>");
+    }
+    {
+        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>",
+                                        "</s>");
+        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
+                      common_chat_params_init(tmpl, inputs_tools_builtin).format);
+        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
+                      common_chat_params_init(
+                          common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"),
+                                               "<s>", "</s>"),
+                          inputs_tools_builtin)
+                          .format);
+
+        // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* expect_grammar_triggered= */ false);
+        test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
+                      "<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
+        test_template(tmpl, end_tokens, python_tool_call_message, tools,
+                      "<|python_tag|>python.call(code=\"print('hey')\")");
+        test_template(tmpl, end_tokens, tool_call_message, tools,
+                      "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
+    }
+    {
+        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>",
+                                        "</s>");
+        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
+
+        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_template(tmpl, end_tokens, tool_call_message, tools,
+                      "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
+    }
+    {
+        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>",
+                                        "</s>");
+        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
+                      common_chat_params_init(tmpl, inputs_tools).format);
+
+        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_template(tmpl, end_tokens, tool_call_message, tools,
+                      "<function=special_function>{\"arg1\": 1}</function>");
+    }
+    {
+        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>",
+                                        "</s>");
+        std::vector<std::string> end_tokens{ "<|eom_id|>", "<|eot_id|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format);
+
+        test_template(tmpl, end_tokens, text_message, {},
+                      "all\n"
+                      "Hello, world!\n"
+                      "What's up?",
+                      /* expect_grammar_triggered= */ false);
+        test_template(tmpl, end_tokens, tool_call_message, tools,
+                      "special_function\n"
+                      "{\"arg1\": 1}");
+    }
+    {
+        const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>",
+                                        "</s>");
+        std::vector<std::string> end_tokens{ "<|eot_id|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);
+
+        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_template(tmpl, end_tokens, tool_call_message, tools,
+                      " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
+    }
+    {
+        const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"),
+                                        "<s>", "</s>");
+        std::vector<std::string> end_tokens{ "<|end▁of▁sentence|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
+
+        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
+        test_template(tmpl, end_tokens, tool_call_message, tools,
+                      "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
+                      "```json\n"
+                      "{\"arg1\": 1}\n"
+                      "```<|tool▁call▁end|>");
+    }
+}
+
+int main(int argc, char ** argv) {
+#ifndef _WIN32
+    if (argc > 1) {
+        common_chat_inputs inputs;
+        inputs.messages = {
+            { { "role", "user" }, { "content", "Hey" } }
+        };
+        inputs.tools = json::array({ special_function_tool });
+
+        std::cout << "| Template | Format |\n";
+        std::cout << "|----------|--------|\n";
+
+        for (int i = 1; i < argc; i++) {
+            std::string path = argv[i];
+            if (path.rfind(".jinja") != path.size() - 6) {
+                std::cerr << "Skipping non-jinja file: " << path << std::endl;
+                continue;
+            }
+            common_chat_template tmpl(read_file(path), "", "");
+            auto parts = string_split(path, "/");
+            auto name = parts[parts.size() - 1];
+            std::cout << "| " << name << " | " << common_chat_format_name(common_chat_params_init(tmpl, inputs).format)
+                      << " |\n";
+        }
+    } else
+#endif
+    {
+        test_template_output_parsers();
+        std::cout << "\n[chat] All tests passed!" << std::endl;
+    }
+    return 0;
+}
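The new test file above doubles as the most compact reference for the chat API this release vendors into common/ (chat.hpp, chat.cpp, chat-template.hpp, minja.hpp). Below is a minimal sketch of the round trip the test exercises, assuming the helpers defined in the test itself (read_file, special_function_tool); the template path and the hard-coded model output are illustrative only:

    // Sketch, not part of the diff: render a prompt for a Jinja chat template,
    // obtain the matching grammar/format, and parse the model's reply back into
    // a structured message.
    common_chat_template tmpl(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>");

    common_chat_inputs inputs;
    inputs.messages = json::array({ { { "role", "user" }, { "content", "Hey" } } });
    inputs.tools    = json::array({ special_function_tool });

    auto params = common_chat_params_init(tmpl, inputs);
    // params.prompt  - text to feed the model
    // params.grammar - GBNF constraining tool-call output; if params.grammar_lazy,
    //                  it only kicks in once one of params.grammar_triggers is seen
    // params.format  - selects the parser for the model's reply

    // ... run generation; a canned Hermes-style reply stands in here ...
    std::string output = "<tool_call>\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n</tool_call>";
    common_chat_msg msg = common_chat_parse(output, params.format);
    // msg.content and msg.tool_calls now hold the parsed result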
package/src/llama.cpp/tests/test-grammar-integration.cpp
@@ -13,7 +13,7 @@
 using json = nlohmann::ordered_json;
 
 static llama_grammar * build_grammar(const std::string & grammar_str) {
-    return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
+    return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0);
 }
 
 static bool test_build_grammar_fails(const std::string & grammar_str) {
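This signature change tracks the lazy-grammar support added in this release (src/llama-grammar.h grows trigger fields, +22 -1 above, and test-chat.cpp checks params.grammar_lazy and params.grammar_triggers). A hedged reading of the five new arguments, with parameter names assumed from the call sites rather than confirmed against the header:

    // Assumed parameter names; every call site in this diff passes
    // (..., false, nullptr, 0, nullptr, 0), i.e. an eager grammar with no triggers.
    const char * trigger_words[] = { "<tool_call>" };  // hypothetical trigger
    llama_grammar * g = llama_grammar_init_impl(
        /* vocab              = */ nullptr,
        /* grammar_str        = */ grammar_str.c_str(),
        /* grammar_root       = */ "root",
        /* lazy               = */ true,   // enforce only after a trigger fires
        /* trigger_words      = */ trigger_words,
        /* num_trigger_words  = */ 1,
        /* trigger_tokens     = */ nullptr,
        /* num_trigger_tokens = */ 0);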
package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp
@@ -129,7 +129,7 @@ static void test_grammar(const std::string & test_desc, const std::string & gram
     test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings);
 }
 static void test_schema(const std::string & test_desc, const std::string & schema_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
-    test(test_desc + ". Schema: " + schema_str, json_schema_to_grammar(json::parse(schema_str)), passing_strings, failing_strings);
+    test(test_desc + ". Schema: " + schema_str, json_schema_to_grammar(json::parse(schema_str), true), passing_strings, failing_strings);
 }
 
 static void test_simple_grammar() {
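The extra true passed to json_schema_to_grammar lines up with the llguidance backend this release adds (common/llguidance.cpp, tests/test-grammar-llguidance.cpp, and a widened common/json-schema-to-grammar.h). The apparent intent, with the flag name assumed rather than confirmed: force classic GBNF output so these schema tests keep exercising the GBNF path even when llguidance is compiled in.

    // force_gbnf is an assumed name for the new flag; the test passes `true`
    // positionally so the schema is always lowered to a GBNF grammar string.
    std::string gbnf = json_schema_to_grammar(json::parse(R"({"type": "integer"})"),
                                              /* force_gbnf= */ true);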