@fugood/llama.node 0.3.17 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +3 -1
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +39 -2
- package/lib/index.js +132 -1
- package/lib/index.ts +203 -3
- package/package.json +2 -1
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +366 -19
- package/src/LlamaCompletionWorker.h +30 -10
- package/src/LlamaContext.cpp +213 -5
- package/src/LlamaContext.h +12 -0
- package/src/common.hpp +15 -0
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +133 -24
- package/src/llama.cpp/.github/workflows/build.yml +41 -762
- package/src/llama.cpp/.github/workflows/docker.yml +5 -2
- package/src/llama.cpp/.github/workflows/release.yml +716 -0
- package/src/llama.cpp/.github/workflows/server.yml +12 -12
- package/src/llama.cpp/CMakeLists.txt +5 -17
- package/src/llama.cpp/cmake/build-info.cmake +8 -2
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
- package/src/llama.cpp/common/CMakeLists.txt +31 -3
- package/src/llama.cpp/common/arg.cpp +48 -29
- package/src/llama.cpp/common/chat.cpp +128 -106
- package/src/llama.cpp/common/chat.h +2 -0
- package/src/llama.cpp/common/common.cpp +37 -1
- package/src/llama.cpp/common/common.h +18 -9
- package/src/llama.cpp/common/llguidance.cpp +1 -0
- package/src/llama.cpp/common/minja/chat-template.hpp +9 -5
- package/src/llama.cpp/common/minja/minja.hpp +69 -36
- package/src/llama.cpp/common/regex-partial.cpp +204 -0
- package/src/llama.cpp/common/regex-partial.h +56 -0
- package/src/llama.cpp/common/sampling.cpp +57 -50
- package/src/llama.cpp/examples/CMakeLists.txt +2 -23
- package/src/llama.cpp/examples/embedding/embedding.cpp +2 -11
- package/src/llama.cpp/examples/parallel/parallel.cpp +86 -14
- package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/training/finetune.cpp +96 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +27 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
- package/src/llama.cpp/ggml/include/ggml.h +10 -7
- package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +20 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +306 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +4 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +29 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +0 -6
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +36 -11
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +0 -2
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +41 -27
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +9 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +121 -232
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +7 -15
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +0 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +338 -166
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
- package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -70
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +657 -193
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +20 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +123 -29
- package/src/llama.cpp/ggml/src/ggml.c +29 -20
- package/src/llama.cpp/ggml/src/gguf.cpp +33 -33
- package/src/llama.cpp/include/llama.h +52 -11
- package/src/llama.cpp/requirements/requirements-all.txt +3 -3
- package/src/llama.cpp/scripts/xxd.cmake +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +1 -0
- package/src/llama.cpp/src/llama-adapter.cpp +6 -0
- package/src/llama.cpp/src/llama-arch.cpp +3 -0
- package/src/llama.cpp/src/llama-batch.cpp +5 -1
- package/src/llama.cpp/src/llama-batch.h +2 -1
- package/src/llama.cpp/src/llama-chat.cpp +17 -7
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-context.cpp +389 -501
- package/src/llama.cpp/src/llama-context.h +44 -32
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +20 -38
- package/src/llama.cpp/src/llama-graph.h +12 -8
- package/src/llama.cpp/src/llama-kv-cache.cpp +1503 -389
- package/src/llama.cpp/src/llama-kv-cache.h +271 -85
- package/src/llama.cpp/src/llama-memory.h +11 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +24 -15
- package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/src/llama.cpp/src/llama-model-saver.h +37 -0
- package/src/llama.cpp/src/llama-model.cpp +316 -69
- package/src/llama.cpp/src/llama-model.h +8 -1
- package/src/llama.cpp/src/llama-quant.cpp +15 -13
- package/src/llama.cpp/src/llama-sampling.cpp +18 -6
- package/src/llama.cpp/src/llama-vocab.cpp +42 -4
- package/src/llama.cpp/src/llama-vocab.h +6 -0
- package/src/llama.cpp/src/llama.cpp +14 -0
- package/src/llama.cpp/tests/CMakeLists.txt +10 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +107 -47
- package/src/llama.cpp/tests/test-chat-template.cpp +10 -11
- package/src/llama.cpp/tests/test-chat.cpp +3 -1
- package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
- package/src/llama.cpp/tests/test-opt.cpp +33 -21
- package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
- package/src/llama.cpp/tests/test-sampling.cpp +1 -1
- package/src/llama.cpp/tools/CMakeLists.txt +39 -0
- package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +2 -2
- package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
- package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +495 -348
- package/src/llama.cpp/{examples → tools}/main/main.cpp +6 -9
- package/src/llama.cpp/{examples/llava → tools/mtmd}/CMakeLists.txt +1 -35
- package/src/llama.cpp/{examples/llava → tools/mtmd}/clip-impl.h +25 -5
- package/src/llama.cpp/{examples/llava → tools/mtmd}/clip.cpp +1440 -1349
- package/src/llama.cpp/tools/mtmd/clip.h +99 -0
- package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd-cli.cpp +70 -44
- package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
- package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd.cpp +251 -281
- package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
- package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +4 -2
- package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +13 -76
- package/src/llama.cpp/{examples → tools}/rpc/rpc-server.cpp +70 -74
- package/src/llama.cpp/{examples → tools}/run/run.cpp +18 -4
- package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
- package/src/llama.cpp/{examples → tools}/server/server.cpp +291 -76
- package/src/llama.cpp/{examples → tools}/server/utils.hpp +377 -5
- package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
- package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/infill.cpp +0 -590
- package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
- package/src/llama.cpp/examples/llava/clip.h +0 -135
- package/src/llama.cpp/examples/llava/llava.cpp +0 -586
- package/src/llama.cpp/examples/llava/llava.h +0 -49
- package/src/llama.cpp/examples/llava/mtmd.h +0 -168
- package/src/llama.cpp/examples/llava/qwen2vl-test.cpp +0 -636
- /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/deprecation-warning.cpp +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/rpc/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/server/httplib.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/tts.cpp +0 -0
package/src/llama.cpp/common/chat.cpp

@@ -6,6 +6,15 @@
 
 #include <optional>
 
+static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
+    auto time = std::chrono::system_clock::to_time_t(now);
+    auto local_time = *std::localtime(&time);
+    std::ostringstream ss;
+    ss << std::put_time(&local_time, format.c_str());
+    auto res = ss.str();
+    return res;
+}
+
 typedef minja::chat_template common_chat_template;
 
 struct common_chat_templates {
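Note: format_time is a plain strftime-style wrapper that the chat code now uses to inject the current date into templates. A standalone sketch of its behavior (the function body is copied from the hunk above; the main() driver is illustrative only):

    #include <chrono>
    #include <cstdio>
    #include <ctime>
    #include <iomanip>
    #include <sstream>
    #include <string>

    static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
        auto time = std::chrono::system_clock::to_time_t(now);
        auto local_time = *std::localtime(&time);
        std::ostringstream ss;
        ss << std::put_time(&local_time, format.c_str());
        return ss.str();
    }

    int main() {
        // "%d %b %Y" is the pattern chat.cpp uses for Llama 3.x's date_string,
        // e.g. "07 May 2025"; firefunction-v2 uses "%b %d %Y %H:%M:%S GMT".
        printf("%s\n", format_time(std::chrono::system_clock::now(), "%d %b %Y").c_str());
        return 0;
    }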
@@ -24,6 +33,7 @@ struct templates_params {
     std::string grammar;
     bool add_generation_prompt = true;
     bool extract_reasoning = true;
+    std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
 };
 
 common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {

@@ -125,7 +135,9 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
             msgs.push_back(msg);
         }
     } catch (const std::exception & e) {
-        throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2));
+        // @ngxson : disable otherwise it's bloating the API response
+        // printf("%s\n", std::string("; messages = ") + messages.dump(2));
+        throw std::runtime_error("Failed to parse messages: " + std::string(e.what()));
     }
 
     return msgs;
@@ -937,78 +949,83 @@ static void expect_tool_parameters(const std::string & name, const json & parame
     }
 }
 
-static common_chat_params
+static common_chat_params common_chat_params_init_llama_3_x(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
     auto builtin_tools = json::array();
     common_chat_params data;
-
-
-
+    if (!inputs.tools.is_null()) {
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            std::vector<std::string> tool_rules;
 
-
-
-
-
-
-
-
-
-
-
-
+            auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
+                if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
+                    // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
+                    // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
+                    expect_tool_parameters(name, parameters, {"query"});
+                } else if (name == "python" || name == "code_interpreter") {
+                    // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
+                    expect_tool_parameters(name, parameters, {"code"});
+                } else {
+                    return false;
+                }
 
-
-
-
-
+                std::vector<std::string> kvs;
+                for (const auto & [key, value] : parameters.at("properties").items()) {
+                    kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
+                }
 
-
-
-
-
-
+                tool_rules.push_back(
+                    builder.add_rule(
+                        name + "-call",
+                        "\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
+                builtin_tools.push_back(name);
 
-
-
+                return true;
+            };
 
-
-
-
-
-
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                std::string name = function.at("name");
+                auto parameters = function.at("parameters");
+                builder.resolve_refs(parameters);
 
-
-
-
+                // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
+                if (allow_python_tag_builtin_tools) {
+                    handle_builtin_tool(name, parameters);
+                }
+                tool_rules.push_back(
+                    builder.add_rule(
+                        name + "-call",
+                        "\"{\" space "
+                        "( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
+                        " \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
+                        " \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
+                        "\"}\" space"));
+            });
+            // Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
+            data.grammar_triggers.push_back({
+                COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
+                "\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
+            });
+            if (!builtin_tools.empty()) {
+                data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
+                data.preserved_tokens.push_back("<|python_tag|>");
             }
-
-
-
-                    "\"{\" space "
-                    "( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
-                    " \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
-                    " \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
-                    "\"}\" space"));
+            // Allow a few empty lines on top of the usual constrained json schema space rule.
+            builder.add_rule("root", string_join(tool_rules, " | "));
+            data.additional_stops.push_back("<|eom_id|>");
         });
-
-
-
-
-
-
-            data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
-            data.preserved_tokens.push_back("<|python_tag|>");
-        }
-        // Allow a few empty lines on top of the usual constrained json schema space rule.
-        builder.add_rule("root", string_join(tool_rules, " | "));
-    });
-    data.additional_stops.push_back("<|eom_id|>");
+        data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
+            ? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
+            : COMMON_CHAT_FORMAT_LLAMA_3_X;
+    } else {
+        data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+    }
     data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
+        {"date_string", format_time(inputs.now, "%d %b %Y")},
         {"tools_in_user_message", false},
         {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
     });
-    data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
-        ? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
-        : COMMON_CHAT_FORMAT_LLAMA_3_X;
     return data;
 }
 static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
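Note: the Llama 3.x handler now emits its lazy grammar only when tools are actually supplied. For a tool named get_weather (a hypothetical example, not from this diff), the generated rules accept completions shaped like:

    {"type": "function", "name": "get_weather", "parameters": {"location": "Paris"}}

and, when python-tag builtin tools are allowed, calls such as:

    <|python_tag|>brave_search.call(query="llama.cpp")

while the PATTERN_START trigger above fires on anything that starts like such a JSON call, so hallucinated tool names are still captured by the grammar.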
@@ -1148,7 +1165,7 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
     LOG_DBG("%s\n", __func__);
     common_chat_params data;
     data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
-        {"datetime", "
+        {"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
         {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
     });
     if (inputs.tools.is_array() && !inputs.tools.empty()) {
@@ -1283,55 +1300,59 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
 static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
     // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
     common_chat_params data;
-    json tools = inputs.tools.is_null() ? inputs.tools : json::array();
-    std::string python_code_argument_name;
-    auto has_raw_python = false;
 
-
-
-
-
-
-
-    std::string
-
-
-
-
-
-
-
-
-
-
-
-
+    if (!inputs.tools.is_null()) {
+        std::string python_code_argument_name;
+        auto has_raw_python = false;
+
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            std::vector<std::string> tool_rules;
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                const auto & parameters = function.at("parameters");
+                std::string name = function.at("name");
+                if (name == "python" || name == "ipython") {
+                    if (!parameters.contains("type")) {
+                        throw std::runtime_error("Missing type in python tool");
+                    }
+                    has_raw_python = true;
+                    const auto & type = parameters.at("type");
+                    if (type == "object") {
+                        auto properties = parameters.at("properties");
+                        for (auto it = properties.begin(); it != properties.end(); ++it) {
+                            if (it.value().at("type") == "string") {
+                                if (!python_code_argument_name.empty()) {
+                                    throw std::runtime_error("Multiple string arguments found in python tool");
+                                }
+                                python_code_argument_name = it.key();
                             }
-            python_code_argument_name = it.key();
                         }
+                        if (python_code_argument_name.empty()) {
+                            throw std::runtime_error("No string argument found in python tool");
+                        }
+                    } else if (type != "string") {
+                        throw std::runtime_error("Invalid type in python tool: " + type.dump());
                     }
-    if (python_code_argument_name.empty()) {
-        throw std::runtime_error("No string argument found in python tool");
-    }
-} else if (type != "string") {
-    throw std::runtime_error("Invalid type in python tool: " + type.dump());
                 }
+                tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
+            });
+            if (has_raw_python) {
+                tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
+                data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
+                data.preserved_tokens.push_back("<|python_tag|>");
             }
-
+            auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
+            builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
+            data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
         });
-
-
-
-
-        }
-    auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
-    builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
-    data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
-    });
+        data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
+    } else {
+        data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+    }
 
     data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     // TODO: if (has_raw_python)
-    data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
     return data;
 }
 static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
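Note: for comparison with the Llama 3.x JSON shape, the Functionary v3.1 grammar constrains calls to a tag format; with the same hypothetical get_weather tool it accepts output such as:

    <function=get_weather>{"location": "Paris"}</function>

and, when a raw python tool is declared, <|python_tag|> followed by arbitrary code (the "python-call" rule above).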
@@ -1591,6 +1612,7 @@ static common_chat_params common_chat_templates_apply_jinja(
     params.extract_reasoning = inputs.extract_reasoning;
     params.tool_choice = inputs.tool_choice;
     params.grammar = inputs.grammar;
+    params.now = inputs.now;
     if (!inputs.json_schema.empty()) {
         params.json_schema = json::parse(inputs.json_schema);
     }

@@ -1642,21 +1664,21 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_firefunction_v2(tmpl, params);
     }
 
-    // Plain handler (no tools)
-    if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
-        return common_chat_params_init_without_tools(tmpl, params);
-    }
-
     // Functionary v3.1 (w/ tools)
     if (src.find("<|start_header_id|>") != std::string::npos
         && src.find("<function=") != std::string::npos) {
         return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
     }
 
-    // Llama 3.1, 3.2, 3.3 (w/ tools)
+    // Llama 3.1, 3.2, 3.3 (also requires date_string so using it even w/o tools)
     if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
         auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
-        return
+        return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
+    }
+
+    // Plain handler (no tools)
+    if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
+        return common_chat_params_init_without_tools(tmpl, params);
     }
 
     // Mistral Nemo (w/ tools)
package/src/llama.cpp/common/chat.h

@@ -3,6 +3,7 @@
 #pragma once
 
 #include "common.h"
+#include <chrono>
 #include <string>
 #include <vector>
 

@@ -71,6 +72,7 @@ struct common_chat_templates_inputs {
     common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
     bool parallel_tool_calls = false;
     bool extract_reasoning = true;
+    std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
 };
 
 struct common_chat_params {
package/src/llama.cpp/common/common.cpp

@@ -443,6 +443,25 @@ void string_replace_all(std::string & s, const std::string & search, const std::
     s = std::move(builder);
 }
 
+bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
+    return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
+}
+size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
+    if (!str.empty() && !stop.empty()) {
+        const char text_last_char = str.back();
+        for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
+            if (stop[char_index] == text_last_char) {
+                const auto current_partial = stop.substr(0, char_index + 1);
+                if (string_ends_with(str, current_partial)) {
+                    return str.size() - char_index - 1;
+                }
+            }
+        }
+    }
+
+    return std::string::npos;
+}
+
 std::string regex_escape(const std::string & s) {
     static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
     return std::regex_replace(s, special_chars, "\\$0");
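Note: these helpers support partial stop-string handling in streaming paths (e.g. the server): string_find_partial_stop reports where a possible prefix of a stop string begins at the tail of the generated text, so that tail can be withheld until the match is confirmed or ruled out. A minimal self-contained check (definitions copied from the hunk above; the main() driver is illustrative only):

    #include <cstdint>
    #include <cstdio>
    #include <string>
    #include <string_view>

    bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
        return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
    }

    size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
        if (!str.empty() && !stop.empty()) {
            const char text_last_char = str.back();
            for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
                if (stop[char_index] == text_last_char) {
                    const auto current_partial = stop.substr(0, char_index + 1);
                    if (string_ends_with(str, current_partial)) {
                        return str.size() - char_index - 1;
                    }
                }
            }
        }
        return std::string::npos;
    }

    int main() {
        // "wor" at the tail could be the start of the stop word "world", so the
        // partial match is reported at index 7 and should not be emitted yet.
        printf("%zu\n", string_find_partial_stop("Hello, wor", "world")); // prints 7
        // No overlap with the stop word: npos, the whole buffer is safe to stream.
        printf("%d\n", string_find_partial_stop("Hello!", "world") == std::string::npos); // prints 1
        return 0;
    }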
@@ -1096,7 +1115,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.n_threads = params.cpuparams.n_threads;
     cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
         params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
-    cparams.logits_all = params.logits_all;
     cparams.embeddings = params.embedding;
     cparams.rope_scaling_type = params.rope_scaling_type;
     cparams.rope_freq_base = params.rope_freq_base;

@@ -1114,6 +1132,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.offload_kqv = !params.no_kv_offload;
     cparams.flash_attn = params.flash_attn;
     cparams.no_perf = params.no_perf;
+    cparams.op_offload = !params.no_op_offload;
 
     if (params.reranking) {
         cparams.embeddings = true;
@@ -1565,3 +1584,20 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
 
     return result;
 }
+
+ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride) {
+    const int64_t ne_datapoint = llama_n_ctx(ctx);
+    const int64_t ndata = (tokens.size() - ne_datapoint - 1) / stride;
+    ggml_opt_dataset_t result = ggml_opt_dataset_init(
+        GGML_TYPE_I32, GGML_TYPE_I32, ne_datapoint, ne_datapoint, ndata, /*ndata_shard =*/ 1);
+
+    llama_token * data = (llama_token *) ggml_opt_dataset_data(result)->data;
+    llama_token * labels = (llama_token *) ggml_opt_dataset_labels(result)->data;
+
+    for (int64_t idata = 0; idata < ndata; ++idata) {
+        memcpy(data + idata*ne_datapoint, tokens.data() + idata*stride + 0, ne_datapoint*sizeof(llama_token));
+        memcpy(labels + idata*ne_datapoint, tokens.data() + idata*stride + 1, ne_datapoint*sizeof(llama_token));
+    }
+
+    return result;
+}
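Note: common_opt_dataset_init pairs with the new examples/training/finetune.cpp entry above: each datapoint is a window of llama_n_ctx(ctx) tokens and its labels are the same window shifted one token to the right. A usage sketch (hypothetical: assumes an initialized llama_context * ctx and a loaded corpus_text string):

    // Tokenize a training corpus, then slice it into next-token-prediction
    // windows; datapoint i covers tokens [i*stride, i*stride + n_ctx) and its
    // labels cover tokens [i*stride + 1, i*stride + n_ctx + 1).
    std::vector<llama_token> tokens = common_tokenize(ctx, corpus_text, /* add_special */ true);
    // e.g. with n_ctx = 512, stride = 512 and a 10241-token corpus:
    // ndata = (10241 - 512 - 1) / 512 = 19 training windows.
    ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx, tokens, /* stride */ 512);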
package/src/llama.cpp/common/common.h

@@ -6,6 +6,7 @@
 
 #include <set>
 #include <string>
+#include <string_view>
 #include <vector>
 #include <sstream>
 

@@ -66,7 +67,6 @@ enum llama_example {
     LLAMA_EXAMPLE_COMMON,
     LLAMA_EXAMPLE_SPECULATIVE,
     LLAMA_EXAMPLE_MAIN,
-    LLAMA_EXAMPLE_INFILL,
     LLAMA_EXAMPLE_EMBEDDING,
     LLAMA_EXAMPLE_PERPLEXITY,
     LLAMA_EXAMPLE_RETRIEVAL,

@@ -96,6 +96,7 @@ enum common_sampler_type {
     COMMON_SAMPLER_TYPE_XTC         = 8,
     COMMON_SAMPLER_TYPE_INFILL      = 9,
     COMMON_SAMPLER_TYPE_PENALTIES   = 10,
+    COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
 };
 
 // dimensionality reduction methods, used by cvector-generator
@@ -161,6 +162,7 @@ struct common_params_sampling {
     std::vector<enum common_sampler_type> samplers = {
         COMMON_SAMPLER_TYPE_PENALTIES,
         COMMON_SAMPLER_TYPE_DRY,
+        COMMON_SAMPLER_TYPE_TOP_N_SIGMA,
         COMMON_SAMPLER_TYPE_TOP_K,
         COMMON_SAMPLER_TYPE_TYPICAL_P,
         COMMON_SAMPLER_TYPE_TOP_P,
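Note: COMMON_SAMPLER_TYPE_TOP_N_SIGMA wires the top-n-sigma sampler into the default chain ahead of top-k. As an illustration of the idea only (a toy sketch, not the llama.cpp implementation): tokens whose logit falls more than n standard deviations below the maximum logit are masked out before softmax:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Toy top-n-sigma filter: keep logits within n*sigma of the maximum logit.
    static void top_n_sigma_mask(std::vector<float> & logits, float n) {
        float max_l = logits[0], mean = 0.0f;
        for (float l : logits) { max_l = std::max(max_l, l); mean += l; }
        mean /= logits.size();
        float var = 0.0f;
        for (float l : logits) { var += (l - mean) * (l - mean); }
        const float sigma = std::sqrt(var / logits.size());
        for (float & l : logits) {
            if (l < max_l - n * sigma) { l = -INFINITY; } // token is filtered out
        }
    }

    int main() {
        std::vector<float> logits = { 8.0f, 7.5f, 2.0f, 1.0f, 0.5f };
        top_n_sigma_mask(logits, 1.0f); // sigma ~ 3.27, cutoff ~ 4.73: keeps 8.0 and 7.5
        for (float l : logits) { printf("%g\n", l); }
        return 0;
    }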
@@ -323,7 +325,6 @@ struct common_params {
     bool ctx_shift = true; // context shift on inifinite text generation
 
     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
-    bool logits_all = false; // return logits for all tokens in the batch
     bool use_mmap = true; // use mmap for faster loads
     bool use_mlock = false; // use mlock to keep model in memory
     bool verbose_prompt = false; // print prompt tokens before generation

@@ -332,6 +333,7 @@ struct common_params {
     bool no_kv_offload = false; // disable KV offloading
     bool warmup = true; // warmup run
     bool check_tensors = false; // validate tensor data
+    bool no_op_offload = false; // globally disable offload host tensor operations to device
 
     bool single_turn = false; // single turn chat conversation
 

@@ -340,7 +342,7 @@ struct common_params {
 
     common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
 
-    // multimodal models (see examples/llava)
+    // multimodal models (see tools/mtmd)
     struct common_params_model mmproj;
     bool mmproj_use_gpu = true; // use GPU for multimodal model
     bool no_mmproj = false; // explicitly disable multimodal model

@@ -366,6 +368,7 @@ struct common_params {
     bool use_jinja = false; // NOLINT
     bool enable_chat_template = true;
     common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
+    bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
 
     std::vector<std::string> api_keys;
 

@@ -409,13 +412,14 @@ struct common_params {
 
     bool process_output = false; // collect data for the output tensor
     bool compute_ppl = true; // whether to compute perplexity
+    bool parse_special = false; // whether to parse special tokens during imatrix tokenization
 
     // cvector-generator params
     int n_pca_batch = 100;
     int n_pca_iterations = 1000;
     dimre_method cvector_dimre_method = DIMRE_METHOD_PCA;
-    std::string cvector_positive_file = "examples/cvector-generator/positive.txt";
-    std::string cvector_negative_file = "examples/cvector-generator/negative.txt";
+    std::string cvector_positive_file = "tools/cvector-generator/positive.txt";
+    std::string cvector_negative_file = "tools/cvector-generator/negative.txt";
 
     bool spm_infill = false; // suffix/prefix/middle pattern for infill
 

@@ -501,10 +505,9 @@ static bool string_starts_with(const std::string & str,
     return str.rfind(prefix, 0) == 0;
 }
 
-
-
-
-}
+// While we wait for C++20's std::string::ends_with...
+bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
+size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
 
 bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
 void string_process_escapes(std::string & input);

@@ -664,3 +667,9 @@ const char * const LLM_KV_SPLIT_COUNT = "split.count";
 const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
 
 }
+
+//
+// training utils
+//
+
+ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);
package/src/llama.cpp/common/llguidance.cpp

@@ -189,6 +189,7 @@ static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab)
         /* .tokenize_fn = */ llama_sampler_llg_tokenize_fn,
         /* .use_approximate_greedy_tokenize_fn = */ false,
         /* .tokenize_user_data = */ vocab,
+        /* .slices = */ nullptr,
     };
 
     char error_buffer[1024];
package/src/llama.cpp/common/minja/chat-template.hpp

@@ -13,10 +13,12 @@
 #include <chrono>
 #include <cstddef>
 #include <cstdio>
+#include <ctime>
 #include <exception>
 #include <iomanip>
 #include <memory>
 #include <sstream>
+#include <stdexcept>
 #include <string>
 #include <vector>
 

@@ -393,8 +395,8 @@ class chat_template {
 
     for (const auto & message_ : adjusted_messages) {
         auto message = message_;
-        if (!message.contains("role") || !message.contains("content")) {
-            throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
+        if (!message.contains("role") || (!message.contains("content") && !message.contains("tool_calls"))) {
+            throw std::runtime_error("message must have 'role' and one of 'content' or 'tool_calls' fields: " + message.dump());
         }
         std::string role = message.at("role");
 

@@ -415,7 +417,6 @@ class chat_template {
         }
     }
     if (polyfill_tool_calls) {
-        auto content = message.at("content");
         auto tool_calls = json::array();
         for (const auto & tool_call : message.at("tool_calls")) {
             if (tool_call.at("type") != "function") {

@@ -434,8 +435,11 @@ class chat_template {
         auto obj = json {
             {"tool_calls", tool_calls},
         };
-        if (!content.is_null() && !content.empty()) {
-            obj["content"] = content;
+        if (message.contains("content")) {
+            auto content = message.at("content");
+            if (!content.is_null() && !content.empty()) {
+                obj["content"] = content;
+            }
         }
         message["content"] = obj.dump(2);
         message.erase("tool_calls");