cui-llama.rn 1.4.4 → 1.4.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. package/android/src/main/CMakeLists.txt +2 -2
  2. package/android/src/main/jni.cpp +12 -10
  3. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  11. package/cpp/chat-template.hpp +529 -529
  12. package/cpp/chat.cpp +959 -265
  13. package/cpp/chat.h +135 -0
  14. package/cpp/common.cpp +2064 -1996
  15. package/cpp/common.h +700 -744
  16. package/cpp/ggml-alloc.c +1039 -1030
  17. package/cpp/ggml-alloc.h +1 -1
  18. package/cpp/ggml-backend-impl.h +255 -255
  19. package/cpp/ggml-backend-reg.cpp +586 -582
  20. package/cpp/ggml-backend.cpp +2004 -2002
  21. package/cpp/ggml-backend.h +354 -354
  22. package/cpp/ggml-common.h +1851 -1851
  23. package/cpp/ggml-cpp.h +39 -39
  24. package/cpp/ggml-cpu-aarch64.cpp +4248 -4247
  25. package/cpp/ggml-cpu-aarch64.h +8 -8
  26. package/cpp/ggml-cpu-impl.h +531 -380
  27. package/cpp/ggml-cpu-quants.c +12527 -11517
  28. package/cpp/ggml-cpu-traits.cpp +36 -36
  29. package/cpp/ggml-cpu-traits.h +38 -38
  30. package/cpp/ggml-cpu.c +15766 -14485
  31. package/cpp/ggml-cpu.cpp +655 -633
  32. package/cpp/ggml-cpu.h +138 -135
  33. package/cpp/ggml-impl.h +567 -567
  34. package/cpp/ggml-metal-impl.h +235 -0
  35. package/cpp/ggml-metal.h +66 -66
  36. package/cpp/ggml-metal.m +5146 -5002
  37. package/cpp/ggml-opt.cpp +854 -854
  38. package/cpp/ggml-opt.h +216 -216
  39. package/cpp/ggml-quants.c +5238 -5238
  40. package/cpp/ggml-threading.h +14 -14
  41. package/cpp/ggml.c +6529 -6524
  42. package/cpp/ggml.h +2198 -2194
  43. package/cpp/gguf.cpp +1329 -1329
  44. package/cpp/gguf.h +202 -202
  45. package/cpp/json-schema-to-grammar.cpp +1024 -1025
  46. package/cpp/json-schema-to-grammar.h +21 -22
  47. package/cpp/json.hpp +24766 -24766
  48. package/cpp/llama-adapter.cpp +347 -347
  49. package/cpp/llama-adapter.h +74 -74
  50. package/cpp/llama-arch.cpp +1513 -1492
  51. package/cpp/llama-arch.h +403 -402
  52. package/cpp/llama-batch.cpp +368 -368
  53. package/cpp/llama-batch.h +88 -88
  54. package/cpp/llama-chat.cpp +588 -587
  55. package/cpp/llama-chat.h +53 -53
  56. package/cpp/llama-context.cpp +1775 -1775
  57. package/cpp/llama-context.h +128 -128
  58. package/cpp/llama-cparams.cpp +1 -1
  59. package/cpp/llama-cparams.h +37 -37
  60. package/cpp/llama-cpp.h +30 -30
  61. package/cpp/llama-grammar.cpp +1219 -1219
  62. package/cpp/llama-grammar.h +173 -164
  63. package/cpp/llama-hparams.cpp +71 -71
  64. package/cpp/llama-hparams.h +139 -139
  65. package/cpp/llama-impl.cpp +167 -167
  66. package/cpp/llama-impl.h +61 -61
  67. package/cpp/llama-kv-cache.cpp +718 -718
  68. package/cpp/llama-kv-cache.h +219 -218
  69. package/cpp/llama-mmap.cpp +600 -590
  70. package/cpp/llama-mmap.h +68 -68
  71. package/cpp/llama-model-loader.cpp +1124 -1124
  72. package/cpp/llama-model-loader.h +167 -167
  73. package/cpp/llama-model.cpp +4087 -4023
  74. package/cpp/llama-model.h +370 -370
  75. package/cpp/llama-sampling.cpp +2558 -2525
  76. package/cpp/llama-sampling.h +32 -32
  77. package/cpp/llama-vocab.cpp +3264 -3252
  78. package/cpp/llama-vocab.h +125 -125
  79. package/cpp/llama.cpp +10284 -10137
  80. package/cpp/llama.h +1354 -1340
  81. package/cpp/log.cpp +393 -423
  82. package/cpp/log.h +132 -132
  83. package/cpp/minja/chat-template.hpp +529 -0
  84. package/cpp/minja/minja.hpp +2915 -0
  85. package/cpp/minja.hpp +2915 -2883
  86. package/cpp/rn-llama.cpp +20 -37
  87. package/cpp/rn-llama.h +12 -2
  88. package/cpp/sampling.cpp +570 -532
  89. package/cpp/sgemm.cpp +2598 -2598
  90. package/cpp/sgemm.h +14 -14
  91. package/cpp/speculative.cpp +278 -277
  92. package/cpp/speculative.h +28 -28
  93. package/package.json +1 -1
  94. package/android/src/main/build-arm64/CMakeCache.txt +0 -429
  95. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCCompiler.cmake +0 -81
  96. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCXXCompiler.cmake +0 -101
  97. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_C.bin +0 -0
  98. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_CXX.bin +0 -0
  99. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeSystem.cmake +0 -15
  100. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.c +0 -904
  101. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.o +0 -0
  102. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.cpp +0 -919
  103. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.o +0 -0
  104. package/android/src/main/build-arm64/CMakeFiles/CMakeConfigureLog.yaml +0 -431
  105. package/android/src/main/build-arm64/CMakeFiles/CMakeDirectoryInformation.cmake +0 -16
  106. package/android/src/main/build-arm64/CMakeFiles/Makefile.cmake +0 -165
  107. package/android/src/main/build-arm64/CMakeFiles/Makefile2 +0 -297
  108. package/android/src/main/build-arm64/CMakeFiles/Progress/1 +0 -1
  109. package/android/src/main/build-arm64/CMakeFiles/Progress/2 +0 -1
  110. package/android/src/main/build-arm64/CMakeFiles/Progress/3 +0 -1
  111. package/android/src/main/build-arm64/CMakeFiles/Progress/4 +0 -1
  112. package/android/src/main/build-arm64/CMakeFiles/Progress/5 +0 -1
  113. package/android/src/main/build-arm64/CMakeFiles/Progress/6 +0 -1
  114. package/android/src/main/build-arm64/CMakeFiles/Progress/count.txt +0 -1
  115. package/android/src/main/build-arm64/CMakeFiles/TargetDirectories.txt +0 -8
  116. package/android/src/main/build-arm64/CMakeFiles/cmake.check_cache +0 -1
  117. package/android/src/main/build-arm64/CMakeFiles/progress.marks +0 -1
  118. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o +0 -0
  119. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o.d +0 -58
  120. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o +0 -0
  121. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o.d +0 -756
  122. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o +0 -0
  123. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o.d +0 -709
  124. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o +0 -0
  125. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o.d +0 -714
  126. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o +0 -0
  127. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o.d +0 -62
  128. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o +0 -0
  129. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o.d +0 -708
  130. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o +0 -0
  131. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o.d +0 -113
  132. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o +0 -0
  133. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o.d +0 -713
  134. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o +0 -0
  135. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o.d +0 -763
  136. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o +0 -0
  137. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o.d +0 -61
  138. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o +0 -0
  139. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o.d +0 -707
  140. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o +0 -0
  141. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o.d +0 -104
  142. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o +0 -0
  143. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o.d +0 -714
  144. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o +0 -0
  145. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o.d +0 -723
  146. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/DependInfo.cmake +0 -62
  147. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/build.make +0 -722
  148. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/cmake_clean.cmake +0 -89
  149. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.make +0 -2
  150. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.ts +0 -2
  151. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/depend.make +0 -2
  152. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/flags.make +0 -17
  153. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/progress.make +0 -41
  154. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/DependInfo.cmake +0 -62
  155. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/build.make +0 -722
  156. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/cmake_clean.cmake +0 -89
  157. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.make +0 -2
  158. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.ts +0 -2
  159. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/depend.make +0 -2
  160. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/flags.make +0 -17
  161. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/progress.make +0 -41
  162. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/DependInfo.cmake +0 -62
  163. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/build.make +0 -722
  164. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/cmake_clean.cmake +0 -89
  165. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.make +0 -2
  166. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.ts +0 -2
  167. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/depend.make +0 -2
  168. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/flags.make +0 -17
  169. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/progress.make +0 -41
  170. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/DependInfo.cmake +0 -62
  171. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/build.make +0 -722
  172. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/cmake_clean.cmake +0 -89
  173. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.make +0 -2
  174. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.ts +0 -2
  175. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/depend.make +0 -2
  176. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/flags.make +0 -17
  177. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/progress.make +0 -41
  178. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/DependInfo.cmake +0 -62
  179. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/build.make +0 -722
  180. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/cmake_clean.cmake +0 -89
  181. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.make +0 -2
  182. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.ts +0 -2
  183. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/depend.make +0 -2
  184. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/flags.make +0 -17
  185. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/progress.make +0 -41
  186. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/DependInfo.cmake +0 -62
  187. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/build.make +0 -722
  188. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/cmake_clean.cmake +0 -89
  189. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.make +0 -2
  190. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.ts +0 -2
  191. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/depend.make +0 -2
  192. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/flags.make +0 -17
  193. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/progress.make +0 -41
  194. package/android/src/main/build-arm64/Makefile +0 -1862
  195. package/android/src/main/build-arm64/cmake_install.cmake +0 -66
  196. package/cpp/chat.hpp +0 -55
  197. package/cpp/rn-llama.hpp +0 -913
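
The bulk of the chat.cpp rewrite below replaces the old chat.hpp interface with the new chat.h API: tool_choice changes from a raw string to the common_chat_tool_choice enum, template state moves behind the opaque common_chat_templates struct (created via common_chat_templates_init and returned as a common_chat_templates_ptr), and grammar triggers gain explicit COMMON_GRAMMAR_TRIGGER_TYPE_* kinds. A minimal sketch of the new call flow, using only declarations visible in the hunks below (the function name render_example_prompt and the "chatml" override are illustrative, not part of the package):

    #include "chat.h"

    std::string render_example_prompt() {
        common_chat_msg msg;
        msg.role    = "user";
        msg.content = "test";

        // nullptr model plus an explicit template override, mirroring
        // common_chat_verify_template() in the diff below; "chatml" falls
        // back to the built-in CHATML_TEMPLATE_SRC.
        auto tmpls = common_chat_templates_init(/* model= */ nullptr, "chatml");

        common_chat_templates_inputs inputs;
        inputs.messages  = {msg};
        inputs.use_jinja = true;

        // The rendered prompt is returned in the .prompt field.
        return common_chat_templates_apply(tmpls.get(), inputs).prompt;
    }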
package/cpp/chat.cpp CHANGED
@@ -1,9 +1,437 @@
- #include "chat.hpp"
- #include "chat-template.hpp"
+ #include "chat.h"
  #include "json-schema-to-grammar.h"
  #include "log.h"
+ #include "chat-template.hpp"
  #include "minja.hpp"

+ #include <optional>
+
+ typedef minja::chat_template common_chat_template;
+
+ struct common_chat_templates {
+     bool has_explicit_template; // Model had builtin template or template overridde was specified.
+     std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
+     std::unique_ptr<common_chat_template> template_tool_use;
+ };
+
+ struct templates_params {
+     json messages;
+     json tools;
+     common_chat_tool_choice tool_choice;
+     json json_schema;
+     bool parallel_tool_calls;
+     bool stream;
+     std::string grammar;
+     bool add_generation_prompt = true;
+     bool extract_reasoning = true;
+ };
+
+ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
+     if (tool_choice == "auto") {
+         return COMMON_CHAT_TOOL_CHOICE_AUTO;
+     }
+     if (tool_choice == "none") {
+         return COMMON_CHAT_TOOL_CHOICE_NONE;
+     }
+     if (tool_choice == "required") {
+         return COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+     }
+     throw std::runtime_error("Invalid tool_choice: " + tool_choice);
+ }
+
+ template <>
+ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
+     std::vector<common_chat_msg> msgs;
+
+     try {
+
+         if (!messages.is_array()) {
+             throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump());
+         }
+
+         for (const auto & message : messages) {
+             if (!message.is_object()) {
+                 throw std::runtime_error("Expected 'message' to be an object, got " + message.dump());
+             }
+
+             common_chat_msg msg;
+             if (!message.contains("role")) {
+                 throw std::runtime_error("Missing 'role' in message: " + message.dump());
+             }
+             msg.role = message.at("role");
+
+             auto has_content = message.contains("content");
+             auto has_tool_calls = message.contains("tool_calls");
+             if (has_content) {
+                 const auto & content = message.at("content");
+                 if (content.is_string()) {
+                     msg.content = content;
+                 } else if (content.is_array()) {
+                     for (const auto & part : content) {
+                         if (!part.contains("type")) {
+                             throw std::runtime_error("Missing content part type: " + part.dump());
+                         }
+                         const auto & type = part.at("type");
+                         if (type != "text") {
+                             throw std::runtime_error("Unsupported content part type: " + type.dump());
+                         }
+                         common_chat_msg_content_part msg_part;
+                         msg_part.type = type;
+                         msg_part.text = part.at("text");
+                         msg.content_parts.push_back(msg_part);
+                     }
+                 } else if (!content.is_null()) {
+                     throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
+                 }
+             }
+             if (has_tool_calls) {
+                 for (const auto & tool_call : message.at("tool_calls")) {
+                     common_chat_tool_call tc;
+                     if (!tool_call.contains("type")) {
+                         throw std::runtime_error("Missing tool call type: " + tool_call.dump());
+                     }
+                     const auto & type = tool_call.at("type");
+                     if (type != "function") {
+                         throw std::runtime_error("Unsupported tool call type: " + tool_call.dump());
+                     }
+                     if (!tool_call.contains("function")) {
+                         throw std::runtime_error("Missing tool call function: " + tool_call.dump());
+                     }
+                     const auto & fc = tool_call.at("function");
+                     if (!fc.contains("name")) {
+                         throw std::runtime_error("Missing tool call name: " + tool_call.dump());
+                     }
+                     tc.name = fc.at("name");
+                     tc.arguments = fc.at("arguments");
+                     if (tool_call.contains("id")) {
+                         tc.id = tool_call.at("id");
+                     }
+                     msg.tool_calls.push_back(tc);
+                 }
+             }
+             if (!has_content && !has_tool_calls) {
+                 throw std::runtime_error("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)");
+             }
+             if (message.contains("reasoning_content")) {
+                 msg.reasoning_content = message.at("reasoning_content");
+             }
+             if (message.contains("name")) {
+                 msg.tool_name = message.at("name");
+             }
+             if (message.contains("tool_call_id")) {
+                 msg.tool_call_id = message.at("tool_call_id");
+             }
+
+             msgs.push_back(msg);
+         }
+     } catch (const std::exception & e) {
+         throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2));
+     }
+
+     return msgs;
+ }
+
+ template <>
+ json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
+     json messages = json::array();
+     for (const auto & msg : msgs) {
+         if (!msg.content.empty() && !msg.content_parts.empty()) {
+             throw std::runtime_error("Cannot specify both content and content_parts");
+         }
+         json jmsg {
+             {"role", msg.role},
+         };
+         if (!msg.content.empty()) {
+             jmsg["content"] = msg.content;
+         } else if (!msg.content_parts.empty()) {
+             if (concat_typed_text) {
+                 std::string text;
+                 for (const auto & part : msg.content_parts) {
+                     if (part.type != "text") {
+                         LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
+                         continue;
+                     }
+                     if (!text.empty()) {
+                         text += '\n';
+                     }
+                     text += part.text;
+                 }
+                 jmsg["content"] = text;
+             } else {
+                 auto & parts = jmsg["content"] = json::array();
+                 for (const auto & part : msg.content_parts) {
+                     parts.push_back({
+                         {"type", part.type},
+                         {"text", part.text},
+                     });
+                 }
+             }
+         } else {
+             jmsg["content"] = json(); // null
+         }
+         if (!msg.reasoning_content.empty()) {
+             jmsg["reasoning_content"] = msg.reasoning_content;
+         }
+         if (!msg.tool_name.empty()) {
+             jmsg["name"] = msg.tool_name;
+         }
+         if (!msg.tool_call_id.empty()) {
+             jmsg["tool_call_id"] = msg.tool_call_id;
+         }
+         if (!msg.tool_calls.empty()) {
+             auto & tool_calls = jmsg["tool_calls"] = json::array();
+             for (const auto & tool_call : msg.tool_calls) {
+                 json tc {
+                     {"type", "function"},
+                     {"function", {
+                         {"name", tool_call.name},
+                         {"arguments", tool_call.arguments},
+                     }},
+                 };
+                 if (!tool_call.id.empty()) {
+                     tc["id"] = tool_call.id;
+                 }
+                 tool_calls.push_back(tc);
+             }
+         }
+         messages.push_back(jmsg);
+     }
+     return messages;
+ }
+
+ template <>
+ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const std::string & messages) {
+     return common_chat_msgs_parse_oaicompat(json::parse(messages));
+ }
+
+ template <>
+ std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
+     std::vector<common_chat_tool> result;
+
+     try {
+         if (!tools.is_null()) {
+             if (!tools.is_array()) {
+                 throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump());
+             }
+             for (const auto & tool : tools) {
+                 if (!tool.contains("type")) {
+                     throw std::runtime_error("Missing tool type: " + tool.dump());
+                 }
+                 const auto & type = tool.at("type");
+                 if (!type.is_string() || type != "function") {
+                     throw std::runtime_error("Unsupported tool type: " + tool.dump());
+                 }
+                 if (!tool.contains("function")) {
+                     throw std::runtime_error("Missing tool function: " + tool.dump());
+                 }
+
+                 const auto & function = tool.at("function");
+                 result.push_back({
+                     /* .name = */ function.at("name"),
+                     /* .description = */ function.at("description"),
+                     /* .parameters = */ function.at("parameters").dump(),
+                 });
+             }
+         }
+     } catch (const std::exception & e) {
+         throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2));
+     }
+
+     return result;
+ }
+
+ template <>
+ std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const std::string & tools) {
+     return common_chat_tools_parse_oaicompat(json::parse(tools));
+ }
+
+ template <>
+ json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
+     if (tools.empty()) {
+         return json();
+     }
+
+     auto result = json::array();
+     for (const auto & tool : tools) {
+         result.push_back({
+             {"type", "function"},
+             {"function", {
+                 {"name", tool.name},
+                 {"description", tool.description},
+                 {"parameters", json::parse(tool.parameters)},
+             }},
+         });
+     }
+     return result;
+ }
+
+ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
+     if (use_jinja) {
+         try {
+             common_chat_msg msg;
+             msg.role = "user";
+             msg.content = "test";
+
+             auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl);
+
+             common_chat_templates_inputs inputs;
+             inputs.messages = {msg};
+
+             common_chat_templates_apply(tmpls.get(), inputs);
+             return true;
+         } catch (const std::exception & e) {
+             LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
+             return false;
+         }
+     }
+     llama_chat_message chat[] = {{"user", "test"}};
+     const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
+     return res >= 0;
+ }
+
+ std::string common_chat_format_single(
+         const struct common_chat_templates * tmpls,
+         const std::vector<common_chat_msg> & past_msg,
+         const common_chat_msg & new_msg,
+         bool add_ass,
+         bool use_jinja) {
+
+     common_chat_templates_inputs inputs;
+     inputs.use_jinja = use_jinja;
+
+     std::string fmt_past_msg;
+     if (!past_msg.empty()) {
+         inputs.messages = past_msg;
+         inputs.add_generation_prompt = false;
+         fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt;
+     }
+     std::ostringstream ss;
+     // if the past_msg ends with a newline, we must preserve it in the formatted version
+     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
+         ss << "\n";
+     };
+     // format chat with new_msg
+     inputs.messages.push_back(new_msg);
+     inputs.add_generation_prompt = add_ass;
+     auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt;
+     // get the diff part
+     ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
+     return ss.str();
+ }
+
+ std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
+     common_chat_templates_inputs inputs;
+     inputs.use_jinja = use_jinja;
+     auto add_simple_msg = [&](auto role, auto content) {
+         common_chat_msg msg;
+         msg.role = role;
+         msg.content = content;
+         inputs.messages.push_back(msg);
+     };
+     add_simple_msg("system", "You are a helpful assistant");
+     add_simple_msg("user", "Hello");
+     add_simple_msg("assistant", "Hi there");
+     add_simple_msg("user", "How are you?");
+     return common_chat_templates_apply(tmpls, inputs).prompt;
+ }
+
+ #define CHATML_TEMPLATE_SRC \
+     "{%- for message in messages -%}\n" \
+     "  {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
+     "{%- endfor -%}\n" \
+     "{%- if add_generation_prompt -%}\n" \
+     "  {{- '<|im_start|>assistant\n' -}}\n" \
+     "{%- endif -%}"
+
+ void common_chat_templates_free(struct common_chat_templates * tmpls) {
+     delete tmpls;
+ }
+
+ bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls) {
+     return tmpls->has_explicit_template;
+ }
+
+ const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) {
+     if (variant != nullptr) {
+         if (strcmp(variant, "tool_use") == 0) {
+             if (tmpls->template_tool_use) {
+                 return tmpls->template_tool_use->source().c_str();
+             }
+             return nullptr;
+         } else {
+             LOG_DBG("%s: unknown template variant: %s\n", __func__, variant);
+         }
+     }
+     return tmpls->template_default->source().c_str();
+ }
+
+ common_chat_templates_ptr common_chat_templates_init(
+     const struct llama_model * model,
+     const std::string & chat_template_override,
+     const std::string & bos_token_override,
+     const std::string & eos_token_override)
+ {
+     std::string default_template_src;
+     std::string template_tool_use_src;
+
+     bool has_explicit_template = !chat_template_override.empty();
+     if (chat_template_override.empty()) {
+         LM_GGML_ASSERT(model != nullptr);
+         const auto * str = llama_model_chat_template(model, /* name */ nullptr);
+         if (str) {
+             default_template_src = str;
+             has_explicit_template = true;
+         }
+         str = llama_model_chat_template(model, /* name */ "tool_use");
+         if (str) {
+             template_tool_use_src = str;
+             has_explicit_template = true;
+         }
+     } else {
+         default_template_src = chat_template_override;
+     }
+     if (default_template_src.empty() || default_template_src == "chatml") {
+         if (!template_tool_use_src.empty()) {
+             default_template_src = template_tool_use_src;
+         } else {
+             default_template_src = CHATML_TEMPLATE_SRC;
+         }
+     }
+     std::string token_bos = bos_token_override;
+     std::string token_eos = eos_token_override;
+     if (model) {
+         const auto * vocab = llama_model_get_vocab(model);
+         const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
+             if (token == LLAMA_TOKEN_NULL) {
+                 if (default_template_src.find(jinja_variable_name) != std::string::npos
+                     || template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
+                     LOG_WRN("common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't work as intended.\n", name);
+                 }
+                 return std::string();
+             }
+             return common_token_to_piece(vocab, token, true);
+         };
+         token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
+         token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
+     }
+     common_chat_templates_ptr tmpls(new common_chat_templates());
+     tmpls->has_explicit_template = has_explicit_template;
+     try {
+         tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
+     } catch (const std::exception & e) {
+         LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what());
+         tmpls->template_default = std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
+     }
+     if (!template_tool_use_src.empty()) {
+         try {
+             tmpls->template_tool_use = std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos);
+         } catch (const std::exception & e) {
+             LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
+         }
+     }
+     return tmpls;
+ }
+
  std::string common_chat_format_name(common_chat_format format) {
      switch (format) {
          case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only";
@@ -17,6 +445,7 @@ std::string common_chat_format_name(common_chat_format format) {
          case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
          case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
          case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
+         case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING: return "Hermes 2 Pro (extract reasoning)";
          case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
          case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: return "Command R7B (extract reasoning)";
          default:
@@ -24,12 +453,6 @@ std::string common_chat_format_name(common_chat_format format) {
      }
  }

- const common_grammar_options grammar_options {
-     /* .dotall = */ false,
-     /* .compact_spaces = */ false,
-     // /* .compact_spaces = */ true,
- };
-
  static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
      // // https://json.nlohmann.me/features/parsing/sax_interface/
      struct json_error_locator : public nlohmann::json_sax<json> {
@@ -38,22 +461,22 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons

          json_error_locator() : position(0), found_error(false) {}

-         bool parse_error(std::size_t position, const std::string &, const json::exception &) override {
+         bool parse_error(std::size_t position, const std::string &, const json::exception &) override { // NOLINT
              this->position = position - 1;
              this->found_error = true;
              return false;
          }
-         bool null() override { return true; }
-         bool boolean(bool) override { return true; }
-         bool number_integer(number_integer_t) override { return true; }
-         bool number_unsigned(number_unsigned_t) override { return true; }
-         bool number_float(number_float_t, const string_t &) override { return true; }
-         bool string(string_t &) override { return true; }
-         bool binary(binary_t &) override { return true; }
-         bool start_object(std::size_t) override { return true; }
-         bool key(string_t &) override { return true; }
+         bool null() override { return true; } // NOLINT
+         bool boolean(bool) override { return true; } // NOLINT
+         bool number_integer(number_integer_t) override { return true; } // NOLINT
+         bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT
+         bool number_float(number_float_t, const string_t &) override { return true; } // NOLINT
+         bool string(string_t &) override { return true; } // NOLINT
+         bool binary(binary_t &) override { return true; } // NOLINT
+         bool start_object(std::size_t) override { return true; } // NOLINT
+         bool key(string_t &) override { return true; } // NOLINT
          bool end_object() override { return true; }
-         bool start_array(std::size_t) override { return true; }
+         bool start_array(std::size_t) override { return true; } // NOLINT
          bool end_array() override { return true; }
      };
      json_error_locator err_loc;
@@ -75,6 +498,34 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
      }
  }

+ static bool parse_literal(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
+     auto expected_it = expected.begin();
+     auto tmp_it = it;
+     while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) {
+         ++tmp_it;
+         ++expected_it;
+     }
+     if (expected_it == expected.end()) {
+         it = tmp_it;
+         return true;
+     }
+     return false;
+ }
+
+ static std::optional<std::smatch> parse_pattern(std::string::const_iterator & it, const std::string::const_iterator & end, const std::regex & expected) {
+     std::smatch match;
+     if (std::regex_match(it, end, match, expected)) {
+         it = match.suffix().first;
+         return match;
+     }
+     return std::nullopt;
+ }
+
+ static void consume_spaces(std::string::const_iterator & it, const std::string::const_iterator & end) {
+     while (it != end && std::isspace(*it)) {
+         ++it;
+     }
+ }

  /**
   * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
@@ -84,7 +535,8 @@ static common_chat_msg parse_json_tool_calls(
      const std::string& input,
      const std::optional<std::regex> & trigger_opt,
      const std::regex & function_regex,
-     const std::regex & close_regex) {
+     const std::regex & close_regex,
+     bool allow_raw_python = false) {
      std::smatch match;

      common_chat_msg result;
@@ -115,14 +567,19 @@
          it = rit->suffix().first;

          json arguments;
-         if (!parse_json(it, end, arguments)) {
+         if (parse_json(it, end, arguments)) {
+             if (!std::regex_search(it, end, match, close_regex)) {
+                 throw std::runtime_error("Malformed input, missing closing pattern: " + input);
+             }
+             it = match.suffix().first;
+             result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
+         } else {
+             if (allow_raw_python && name == "python") {
+                 result.tool_calls.push_back({name, json({{"code", std::string(it, end)}}).dump(), /* id= */ ""});
+                 break;
+             }
              throw std::runtime_error("Failed to parse json tool call arguments: " + input);
          }
-         if (!std::regex_search(it, end, match, close_regex)) {
-             throw std::runtime_error("Malformed input, missing closing pattern: " + input);
-         }
-         it = match.suffix().first;
-         result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
      }

      if (!result.tool_calls.empty()) {
@@ -134,29 +591,29 @@
      return result;
  }

+ static common_chat_tool_call process_tool_call(const json & tool_call) {
+     const auto & arguments = tool_call.at("arguments");
+     return {
+         /* .name = */ tool_call.at("name"),
+         /* .arguments = */ arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
+         /* .id = */ tool_call.contains("id") ? tool_call.at("id") : "",
+     };
+ }
  static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
      auto content_end = input.find(prefix);
      size_t tc_start = std::string::npos;

      common_chat_msg result;
      result.role = "assistant";
-     const auto process_tool_calls = [&](const json & tool_calls) {
-         for (const auto & tool_call : tool_calls) {
-             const auto & arguments = tool_call.at("arguments");
-             result.tool_calls.push_back({
-                 tool_call.at("name"),
-                 arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
-                 tool_call.contains("id") ? tool_call.at("id") : "",
-             });
-         }
-     };
      if (content_end == std::string::npos) {
          result.content = input;
      } else {
          tc_start = content_end + prefix.size() - rstrip_prefix;
          result.content = input.substr(0, content_end);
          auto tool_calls = json::parse(input.substr(tc_start));
-         process_tool_calls(tool_calls);
+         for (const auto & tool_call : tool_calls) {
+             result.tool_calls.emplace_back(process_tool_call(tool_call));
+         }
      }
      return result;
  }
@@ -187,13 +644,20 @@ static std::string apply(
      // tmpl_inputs.now = std::chrono::system_clock::now();

      minja::chat_template_options tmpl_opts;
-     tmpl_opts.use_bos_token = false;
-     tmpl_opts.use_eos_token = false;
-
-     return tmpl.apply(tmpl_inputs, tmpl_opts);
+     // To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens
+     // instead of using `chat_template_options.use_bos_token = false`, since these tokens
+     // may be needed inside the template / between messages too.
+     auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
+     if (string_starts_with(result, tmpl.bos_token())) {
+         result = result.substr(tmpl.bos_token().size());
+     }
+     if (string_ends_with(result, tmpl.eos_token())) {
+         result = result.substr(0, result.size() - tmpl.eos_token().size());
+     }
+     return result;
  }

- static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+ static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {
      common_chat_params data;

      auto tool_call_schemas = json::array();
@@ -247,7 +711,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
          {"required", json::array({"tool_call"})},
      };
      const auto schema =
-         inputs.tool_choice != "required"
+         inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED
          ? json {
              {"anyOf", json::array({
                  tool_call,
@@ -268,7 +732,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
      data.grammar_lazy = false;
      data.grammar = build_grammar([&](const common_grammar_builder & builder) {
          builder.add_schema("root", schema);
-     }, grammar_options);
+     });

      auto tweaked_messages = common_chat_template::add_system(
          inputs.messages,
@@ -303,9 +767,9 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) {
      return result;
  }

- static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) {
      common_chat_params data;
-     data.grammar_lazy = inputs.tool_choice != "required";
+     data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
      data.grammar = build_grammar([&](const common_grammar_builder & builder) {
          auto schemas = json::array();
          foreach_function(inputs.tools, [&](const json & tool) {
@@ -338,8 +802,11 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
              schema["maxItems"] = 1;
          }
          builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
-     }, grammar_options);
-     data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
+     });
+     data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"});
+     data.preserved_tokens = {
+         "[TOOL_CALLS]",
+     };
      data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
      data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
      return data;
@@ -348,9 +815,9 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input)
      return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
  }

- static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+ static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
      common_chat_params data;
-     data.grammar_lazy = inputs.tool_choice != "required";
+     data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
      data.grammar = build_grammar([&](const common_grammar_builder & builder) {
          auto schemas = json::array();
          foreach_function(inputs.tools, [&](const json & tool) {
@@ -381,14 +848,18 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
              schema["maxItems"] = 1;
          }
          builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\"");
-     }, grammar_options);
-     data.grammar_triggers.push_back({"<|START_ACTION|>", /* .at_start = */ false});
+     });
+     data.grammar_triggers.push_back({
+         COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
+         "<|START_ACTION|>",
+     });
      data.preserved_tokens = {
+         "<|START_ACTION|>",
+         "<|END_ACTION|>",
          "<|START_RESPONSE|>",
          "<|END_RESPONSE|>",
          "<|START_THINKING|>",
          "<|END_THINKING|>",
-         "<|END_ACTION|>",
      };
      auto adjusted_messages = json::array();
      for (const auto & msg : inputs.messages) {
@@ -408,9 +879,9 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
      return data;
  }
  static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) {
-     static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|>)([\\s\\S\\n\\r]*)");
-     static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
-     static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>");
+     static const std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S]*?)<\\|END_THINKING\\|>)([\\s\\S]*)");
+     static const std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S]*?)<\\|END_ACTION\\|>");
+     static const std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S]*?)<\\|END_RESPONSE\\|>");

      std::smatch match;

@@ -455,10 +926,10 @@ static void expect_tool_parameters(const std::string & name, const json & parame
      const auto & parameters_required = parameters.at("required");
      for (const auto & prop : expected_properties) {
          if (!parameters_properties.contains(prop)) {
-             throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop);
+             throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); // NOLINT
          }
          if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
-             throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop);
+             throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); // NOLINT
          }
      }
      if (parameters_properties.size() != expected_properties.size()) {
@@ -466,18 +937,16 @@
      }
  }

- static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) {
+ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
      auto builtin_tools = json::array();
      common_chat_params data;
-     data.grammar_lazy = inputs.tool_choice != "required";
+     data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
      data.grammar = build_grammar([&](const common_grammar_builder & builder) {
          std::vector<std::string> tool_rules;

          auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
-             if (name == "wolfram_alpha") {
+             if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
                  // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
-                 expect_tool_parameters(name, parameters, {"query"});
-             } else if (name == "web_search" || name == "brave_search") {
                  // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
                  expect_tool_parameters(name, parameters, {"query"});
              } else if (name == "python" || name == "code_interpreter") {
@@ -489,7 +958,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com

              std::vector<std::string> kvs;
              for (const auto & [key, value] : parameters.at("properties").items()) {
-                 kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value));
+                 kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
              }

              tool_rules.push_back(
@@ -515,23 +984,23 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
              builder.add_rule(
                  name + "-call",
                  "\"{\" space "
-                 "( \"\\\"type\\\":\" space \"\\\"function\\\",\" space )? "
-                 "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
-                 builder.add_schema(name + "-args", parameters) +
-                 " \"}\""));
-         data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true});
+                 "( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
+                 " \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
+                 " \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
+                 "\"}\" space"));
+         });
+         // Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
+         data.grammar_triggers.push_back({
+             COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
+             "\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
          });
-         data.grammar_triggers.push_back({"{\"name\":", /* .at_start = */ true});
-         data.grammar_triggers.push_back({"{\n  \"name\":", /* .at_start = */ true});
-         data.grammar_triggers.push_back({"{\n    \"name\":", /* .at_start = */ true});
-         data.grammar_triggers.push_back({"{\"type\": \"function\"", /* .at_start = */ true});
-         data.grammar_triggers.push_back({"{\n  \"type\": \"function\"", /* .at_start = */ true});
-         data.grammar_triggers.push_back({"{\n    \"type\": \"function\"", /* .at_start = */ true});
          if (!builtin_tools.empty()) {
-             data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
+             data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
+             data.preserved_tokens.push_back("<|python_tag|>");
          }
+         // Allow a few empty lines on top of the usual constrained json schema space rule.
          builder.add_rule("root", string_join(tool_rules, " | "));
-     }, grammar_options);
+     });
      data.additional_stops.push_back("<|eom_id|>");
      data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
          {"tools_in_user_message", false},
@@ -544,54 +1013,53 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
  }
  static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
      // TODO: tighten & simplify the parser, don't accept leading text context.
-     static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": ");
-     static std::regex close_regex("\\}");
-     static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)");
+     static const std::regex function_regex(
+         "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
+     static const std::regex close_regex("\\}\\s*");
+     static const std::regex builtin_call_regex("<\\|python_tag\\|>\\s*([^.(]+)\\s*\\.\\s*call\\s*\\(\\s*([\\w]+)\\s*=\\s*([\\s\\S]*?)\\)");

      if (with_builtin_tools) {
          std::smatch match;
          if (std::regex_match(input, match, builtin_call_regex)) {
-             auto name = match[1].str();
-             auto raw_args = match[2].str();
-
-             // TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing.
-             auto it_eq = raw_args.find('=');
-             auto arg_name = raw_args.substr(0, it_eq);
-             auto arg_value_str = raw_args.substr(it_eq + 1);
-             auto arg_value = json::parse(arg_value_str);
-
-             return {
-                 /* .role = */ "assistant",
-                 /* .content = */ match.prefix().str(),
-                 /* .tool_calls = */ {
-                     {
-                         /* .name = */ match[1],
-                         /* .arguments = */ (json {
-                             {arg_name, arg_value},
-                         }).dump(),
-                         /* .id = */ "",
-                     },
-                 },
-             };
+             try {
+                 auto name = match[1].str();
+                 auto arg_name = match[2].str();
+                 auto arg_value_str = match[3].str();
+                 auto arg_value = json::parse(arg_value_str);
+
+                 common_chat_msg msg;
+                 msg.role = "assistant";
+                 msg.tool_calls.push_back({
+                     /* .name = */ name,
+                     /* .arguments = */ (json {
+                         {arg_name, arg_value},
+                     }).dump(),
+                     /* .id = */ "",
+                 });
+                 return msg;
+             } catch (const std::exception & e) {
+                 LOG_WRN("Failed to parse builtin tool call arguments (%s): %s", e.what(), input.c_str());
+             }
          }
      }
      return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
  }

- static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
      common_chat_params data;
      if (inputs.tools.is_array() && !inputs.tools.empty()) {
-         data.grammar_lazy = inputs.tool_choice != "required" && inputs.json_schema.is_null();
+         data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
          data.grammar = build_grammar([&](const common_grammar_builder & builder) {
              std::vector<std::string> tool_rules;
              foreach_function(inputs.tools, [&](const json & tool) {
                  const auto & function = tool.at("function");
                  std::string name = function.at("name");
                  auto parameters = function.at("parameters");
-                 auto args_rule = builder.add_schema(name + "-args", parameters);
+                 builder.resolve_refs(parameters);
                  tool_rules.push_back(builder.add_rule(name + "-call",
                      "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
-                     "```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
+                     "```json\\n\" " + builder.add_schema(name + "-args", parameters) + " "
+                     "\"```<|tool▁call▁end|>\""));
              });
              // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
              // so we accept common variants (then it's all constrained)
@@ -600,18 +1068,20 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
                  "(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
                  "\"<|tool▁calls▁end|>\""
                  " space");
-             data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
-             data.grammar_triggers.push_back({"<|tool_calls_begin|>", /* .at_start = */ false});
-             data.grammar_triggers.push_back({"<|tool calls begin|>", /* .at_start = */ false});
-             data.grammar_triggers.push_back({"<|tool\\_calls\\_begin|>", /* .at_start = */ false});
+             data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool▁calls▁begin|>"});
+             data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_calls_begin|>"});
+             data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool calls begin|>"});
+             data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool\\_calls\\_begin|>"});
              data.preserved_tokens = {
                  "<think>",
                  "</think>",
+                 "<|tool▁calls▁begin|>",
+                 "<|tool▁call▁begin|>",
                  "<|tool▁sep|>",
-                 "<|tool▁calls▁end|",
                  "<|tool▁call▁end|>",
+                 "<|tool▁calls▁end|",
              };
-         }, grammar_options);
+         });
      }
      auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);

@@ -636,45 +1106,53 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
  data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING : COMMON_CHAT_FORMAT_DEEPSEEK_R1;
  return data;
  }
- static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) {
- static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
- static std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
- static std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)");
- static std::regex tool_calls_regex("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)([\\s\\S\\r\\n]*?)<|tool▁calls▁end|>");
- common_chat_msg msg;
- msg.role = "assistant";
+ static common_chat_msg handle_think_tag_prelude(const std::string & input, bool extract_reasoning, const std::function<common_chat_msg(const std::string &)> & rest_parser) {
  std::smatch match;
+ static const std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)");
  if (std::regex_match(input, match, reasoning_content_regex)) {
- std::string rest;
+ auto rest = match[3].str();
+ auto msg = rest_parser(rest);
+ auto reasoning_content = string_strip(match[2].str());
  if (extract_reasoning) {
- msg.reasoning_content = string_strip(match[2].str());
- } else {
- msg.content = match[1].str();
+ msg.reasoning_content = reasoning_content;
+ } else if (!reasoning_content.empty()) {
+ std::ostringstream content;
+ content << "<think>" << reasoning_content << "</think>" << msg.content;
+ msg.content = content.str();
  }
- rest = match[3].str();
+ return msg;
+ }
+ return rest_parser(input);
+ }
+ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) {
+ return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) {
+ static const std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
+ static const std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
+ static const std::regex tool_calls_regex("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)([\\s\\S\\r\\n]*?)<|tool▁calls▁end|>");

- if (std::regex_search(rest, match, tool_calls_regex)) {
+ common_chat_msg msg;
+ msg.role = "assistant";
+ std::smatch match;
+ if (std::regex_search(input, match, tool_calls_regex)) {
  auto tool_calls = match[1].str();
  auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex);
  msg.tool_calls = std::move(msg2.tool_calls);
  } else {
- msg.content += std::string(rest.begin() + rest.find_first_not_of(" \r\n"), rest.end());
+ msg.content = input;
  }
- } else {
- msg.content = input;
- }
- return msg;
+ return msg;
+ });
  }

- static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
- fprintf(stderr, "%s\n", __func__);
+ static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
+ LOG_DBG("%s\n", __func__);
  common_chat_params data;
  data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
  {"datetime", "Jan 29 2025 13:00:00 GMT"},
  {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
  });
  if (inputs.tools.is_array() && !inputs.tools.empty()) {
- data.grammar_lazy = inputs.tool_choice != "required";
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
  data.grammar = build_grammar([&](const common_grammar_builder & builder) {
  auto schemas = json::array();
  foreach_function(inputs.tools, [&](const json & tool) {
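handle_think_tag_prelude, introduced in the hunk above, factors the <think>...</think> handling out of the DeepSeek R1 parser so it can be reused (the reworked Hermes 2 Pro parser further down calls it too). A self-contained demonstration of the split, using the exact regex from the diff:

#include <cstdio>
#include <regex>
#include <string>

int main() {
    // Group 2 captures the reasoning, group 3 whatever follows the prelude.
    static const std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)");
    const std::string input = "<think>need the weather tool</think>{\"name\": \"get_weather\"}";
    std::smatch match;
    if (std::regex_match(input, match, reasoning_content_regex)) {
        printf("reasoning: %s\n", match[2].str().c_str()); // need the weather tool
        printf("rest     : %s\n", match[3].str().c_str()); // {"name": "get_weather"}
    }
    return 0;
}

Depending on extract_reasoning, the helper either moves group 2 into msg.reasoning_content or re-wraps it in literal <think> tags in front of whatever content the rest-parser produced.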
@@ -700,8 +1178,11 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
  schema["maxItems"] = 1;
  }
  builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
- }, grammar_options);
- data.grammar_triggers.push_back({" functools[", /* .at_start = */ false});
+ });
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, " functools["});
+ data.preserved_tokens = {
+ " functools[",
+ };
  data.format = COMMON_CHAT_FORMAT_FIREFUNCTION_V2;
  } else {
  data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
@@ -712,14 +1193,14 @@ static common_chat_msg common_chat_parse_firefunction_v2(const std::string & inp
  return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
  }

- static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+ static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) {
  // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
  // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
  common_chat_params data;
  data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
  data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
  if (inputs.tools.is_array() && !inputs.tools.empty()) {
- data.grammar_lazy = inputs.tool_choice != "required";
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
  data.grammar = build_grammar([&](const common_grammar_builder & builder) {
  std::vector<std::string> first_tool_rules;
  std::vector<std::string> subsequent_tool_rules;
@@ -727,12 +1208,30 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
  const auto & function = tool.at("function");
  std::string name = function.at("name");
  auto parameters = function.at("parameters");
+ builder.resolve_refs(parameters);
  auto args_rule = builder.add_schema(name + "-args", parameters);
- first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
+ first_tool_rules.push_back(builder.add_rule(name + "-call", "( \"assistant<|end_header_id|>\\n\" )? \"" + name + "\\n\" " + args_rule));
  subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
- data.grammar_triggers.push_back({name, /* .at_start = */ true});
- data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false});
+ data.grammar_triggers.push_back({
+ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
+ regex_escape(name + "\n"),
+ });
+ data.grammar_triggers.push_back({
+ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
+ regex_escape("assistant<|end_header_id|>\n" + name + "\n"),
+ });
+ data.grammar_triggers.push_back({
+ COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
+ regex_escape(">>>" + name + "\n"),
+ });
+ data.grammar_triggers.push_back({
+ COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
+ ">>>assistant<|end_header_id|>\n" + name,
+ });
  });
+ data.preserved_tokens = {
+ "<|end_header_id|>",
+ };
  auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
  if (inputs.parallel_tool_calls) {
  auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
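The rewritten triggers above pass function names through regex_escape before embedding them in PATTERN_START patterns, so a name containing regex metacharacters cannot corrupt the trigger. regex_escape is defined elsewhere in this package's common code; a plausible minimal version, shown for illustration only (hypothetical name):

#include <string>

// Hypothetical stand-in for the package's regex_escape: prefix every
// character that is special in ECMAScript regexes with a backslash.
static std::string regex_escape_sketch(const std::string & s) {
    static const std::string special = "\\.^$|()[]{}*+?";
    std::string out;
    for (const char c : s) {
        if (special.find(c) != std::string::npos) {
            out += '\\';
        }
        out += c;
    }
    return out; // regex_escape_sketch("get.weather") -> "get\\.weather"
}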
@@ -741,34 +1240,20 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
  builder.add_rule("root", first_rule);
  }

- }, grammar_options);
+ });
  }
  return data;
  }

- static bool consume(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
- auto expected_it = expected.begin();
- auto tmp_it = it;
- while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) {
- ++tmp_it;
- ++expected_it;
- }
- if (expected_it == expected.end()) {
- it = tmp_it;
- return true;
- }
- return false;
- }
-
  static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) {
- static std::regex function_regex(R"((?:>>>)?(\w+)\n)");
- static std::regex close_regex(R"($|(?=>>>))");
+ static const std::regex function_regex(R"((?:>>>)?(?:assistant<|end_header_id|>\n)?(\w+)\n)");
+ static const std::regex close_regex(R"($|(?=>>>))");

  std::string content;
  auto it = input.begin();
  const auto end = input.end();

- if (consume(it, end, "all\n")) {
+ if (parse_literal(it, end, "all\n")) {
  std::smatch match;
  if (std::regex_search(it, end, match, function_regex)) {
  auto fun_it = match.prefix().second;
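The consume() helper deleted above is superseded by parse_literal, which the new Hermes 2 Pro parser further down also uses. The deleted body makes the contract clear: match `expected` at the iterator and advance only on a complete match. A minimal equivalent for reference (the package's actual parse_literal is defined elsewhere in chat.cpp):

#include <string>

// Same semantics as the removed consume(): advance `it` past `expected` and
// return true on a full match; leave `it` untouched otherwise.
static bool parse_literal_sketch(std::string::const_iterator & it,
                                 const std::string::const_iterator & end,
                                 const std::string & expected) {
    auto tmp = it;
    for (const char c : expected) {
        if (tmp == end || *tmp != c) {
            return false; // no partial consumption on failure
        }
        ++tmp;
    }
    it = tmp;
    return true;
}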
@@ -783,7 +1268,7 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
  }
  // TODO: tighten & simplify.
  try {
- auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex);
+ auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex, /* allow_raw_python= */ true);
  res.content = content + res.content;
  return res;
  } catch (const std::exception & e) {
@@ -795,14 +1280,14 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
  }
  }

- static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
  // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
  common_chat_params data;
  json tools = inputs.tools.is_null() ? inputs.tools : json::array();
  std::string python_code_argument_name;
  auto has_raw_python = false;

- data.grammar_lazy = inputs.tool_choice != "required";
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
  data.grammar = build_grammar([&](const common_grammar_builder & builder) {
  std::vector<std::string> tool_rules;
  foreach_function(inputs.tools, [&](const json & tool) {
@@ -814,7 +1299,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
  throw std::runtime_error("Missing type in python tool");
  }
  has_raw_python = true;
- auto type = parameters.at("type");
+ const auto & type = parameters.at("type");
  if (type == "object") {
  auto properties = parameters.at("properties");
  for (auto it = properties.begin(); it != properties.end(); ++it) {
@@ -836,12 +1321,13 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
  });
  if (has_raw_python) {
  tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
- data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false});
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
+ data.preserved_tokens.push_back("<|python_tag|>");
  }
  auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
  builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
- data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
- }, grammar_options);
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
+ });

  data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
  // TODO: if (has_raw_python)
@@ -850,34 +1336,33 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
  }
  static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
  // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
- static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
+ static const std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
  std::smatch match;
  if (std::regex_search(input, match, python_tag_regex)) {
  auto code = match[1].str();
- return {
- /* .role = */ "assistant",
- /* .content = */ match.prefix().str(),
- /* .tool_calls = */ {
- {
- /* .name = */ "python",
- /* .arguments = */ (json {{"code", code}}).dump(),
- /* .id = */ "",
- },
- }
- };
+ common_chat_msg msg;
+ msg.role = "assistant";
+ msg.content = match.prefix().str();
+ msg.tool_calls.push_back({
+ /* .name = */ "python",
+ /* .arguments = */ (json {{"code", code}}).dump(),
+ /* .id = */ "",
+ });
+ return msg;
  }
- static std::regex function_regex(R"(<function=(\w+)>)");
- static std::regex close_regex(R"(</function>)");
+ static const std::regex function_regex(R"(<function=(\w+)>)");
+ static const std::regex close_regex(R"(</function>)");
  // TODO: tighten & simplify.
  return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
  }

- static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
  common_chat_params data;
  // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
- data.grammar_lazy = inputs.tool_choice != "required";
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
  data.grammar = build_grammar([&](const common_grammar_builder & builder) {
  std::vector<std::string> tool_rules;
+ std::vector<std::string> tool_call_alts;
  foreach_function(inputs.tools, [&](const json & tool) {
  const auto & function = tool.at("function");
  std::string name = function.at("name");
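The rewritten python-tag branch above keeps the same arguments payload, (json {{"code", code}}).dump(), but builds the message field-by-field instead of relying on aggregate initialization order. The expression standalone, assuming the bundled cpp/json.hpp and an ordered_json alias like the one chat.cpp appears to use:

#include <cstdio>
#include <string>
#include "json.hpp" // nlohmann::json, shipped in this package under cpp/

using json = nlohmann::ordered_json;

int main() {
    const std::string code = "print(1 + 1)";
    // Wrap the raw python source in a {"code": ...} object, serialized to text.
    const std::string arguments = (json {{"code", code}}).dump();
    printf("%s\n", arguments.c_str()); // {"code":"print(1 + 1)"}
    return 0;
}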
@@ -891,73 +1376,190 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
  }},
  {"required", json::array({"name", "arguments"})},
  }));
+ tool_call_alts.push_back(builder.add_rule(
+ name + "-function-tag",
+ "\"<function\" ( \"=" + name + "\" | \" name=\\\"" + name + "\\\"\" ) \">\" space " +
+ builder.add_schema(name + "-args", parameters) + " "
+ "\"</function>\" space"));
+
+ data.grammar_triggers.push_back({
+ COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
+ "<function=" + name + ">",
+ });
+ auto escaped_name = regex_escape(name);
+ data.grammar_triggers.push_back({
+ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
+ "<function\\s+name\\s*=\\s*\"" + escaped_name + "\"",
+ });
  });
- auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
+ auto any_tool_call = builder.add_rule("any_tool_call", "( " + string_join(tool_rules, " | ") + " ) space");
+ std::vector<std::string> alt_tags {
+ any_tool_call,
+ "\"<tool_call>\" space " + any_tool_call + " \"</tool_call>\"",
+ // The rest is just to accommodate common "good bad" outputs.
+ "\"<function_call>\" space " + any_tool_call + " \"</function_call>\"",
+ "\"<response>\" space " + any_tool_call + " \"</response>\"",
+ "\"<tools>\" space " + any_tool_call + " \"</tools>\"",
+ "\"<json>\" space " + any_tool_call + " \"</json>\"",
+ "\"<xml>\" space " + any_tool_call + " \"</xml>\"",
+ "\"<JSON>\" space " + any_tool_call + " \"</JSON>\"",
+ };
+ auto wrappable_tool_call = builder.add_rule("wrappable_tool_call", "( " + string_join(alt_tags, " | ") + " ) space");
+ tool_call_alts.push_back(wrappable_tool_call);
+ tool_call_alts.push_back(
+ "( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space ");
+ auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | "));
  builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
- data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
- data.preserved_tokens = { "</tool_call>" };
- }, grammar_options);
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<tool_call>"});
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function"});
+ // Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives)
+ data.grammar_triggers.push_back({
+ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
+ "(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?\\s*\\{\\s*\"", //name\"\\s*:\\s*\"" + escaped_name + "\"",
+ });
+ data.preserved_tokens = {
+ "<think>",
+ "</think>",
+ "<tool_call>",
+ "</tool_call>",
+ "<function",
+ "<tools>",
+ "</tools>",
+ "<response>",
+ "</response>",
+ "<function_call>",
+ "</function_call>",
+ "<json>",
+ "</json>",
+ "<JSON>",
+ "</JSON>",
+ "```",
+ "```json",
+ "```xml",
+ };
+ });

  data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
- data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
+ data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING : COMMON_CHAT_FORMAT_HERMES_2_PRO;
  return data;
  }
- static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) {
- try {
- std::regex start_pattern(R"([\n\s]*<tool_call>)");
- std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
- std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
-
- auto end = input.end();
- std::sregex_iterator rend;
- std::sregex_iterator rit(input.begin(), end, start_pattern);
- if (rit == rend) {
- return {
- /* .role = */ "assistant",
- /* .content = */ input,
- /* .tool_calls = */ {},
- };
- }
-
- common_chat_msg result;
- result.role = "assistant";
- result.content = rit->prefix();
-
- auto it = rit->suffix().first;
- while (it != end) {
- json call;
- if (!parse_json(it, end, call)) {
- throw std::runtime_error("Failed to parse json tool call");
- }
- const auto & arguments = call.at("arguments");
- result.tool_calls.push_back({
- call.at("name"),
- arguments.dump(),
- // arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
- /* id= */ "",
- });
- rit = {it, end, middle_pattern};
- if (rit != rend) {
- it = rit->suffix().first;
- } else {
- rit = {it, end, end_pattern};
- if (rit == rend) {
- throw std::runtime_error("Malformed input, missing </tool_call>");
+ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input, bool extract_reasoning) {
+ return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) {
+ static const std::regex open_regex(
+ "(?:"
+ "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
+ "(<tool_call>" // match 2 (open_tag)
+ "|<function_call>"
+ "|<tool>"
+ "|<tools>"
+ "|<response>"
+ "|<json>"
+ "|<xml>"
+ "|<JSON>"
+ ")?"
+ "(\\s*\\{\\s*\"name\"\\s*:[\\s\\S]*)" // match 3 (named tool call + rest)
+ ")"
+ "|"
+ "(?:<function=([^>]+)>" // match 4 (function name)
+ "|<function name=\"([^\"]+)\">)" // match 5 (function name again)
+ "([\\s\\S]*)" // match 6 (function arguments + rest)})"
+ );
+
+ try {
+ common_chat_msg msg;
+ msg.role = "assistant";
+
+ std::string::const_iterator it = input.begin();
+ const std::string::const_iterator end = input.end();
+ std::smatch match;
+
+ while (it != end) {
+ if (std::regex_search(it, end, match, open_regex)) {
+ // Add content before the match
+ msg.content += std::string(it, match[0].first);
+
+ auto block_start = match[1].str();
+ std::string block_end = block_start.empty() ? "" : "```";
+
+ auto open_tag = match[2].str();
+ std::string close_tag;
+
+ if (match[3].matched) {
+ close_tag = open_tag.empty() ? "" : "</" + open_tag.substr(1);
+ auto json_it = match[3].first;
+ json tool_call;
+ if (parse_json(json_it, end, tool_call) && tool_call.contains("name") && tool_call.contains("arguments")) {
+
+ msg.tool_calls.emplace_back(process_tool_call(tool_call));
+ it = json_it; // Move iterator past parsed JSON
+
+ // Handle close tags
+ consume_spaces(it, end);
+ if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
+ throw std::runtime_error("Failed to parse closing tag");
+ }
+ consume_spaces(it, end);
+ if (!block_end.empty() && !parse_literal(it, end, block_end)) {
+ throw std::runtime_error("Failed to parse block end");
+ }
+ consume_spaces(it, end);
+ } else {
+ // Not a valid tool call, treat as content
+ msg.content += std::string(match[0].first, match[0].second);
+ it = match[0].second;
+ }
+ } else {
+ auto function_name = match[4].str();
+ if (function_name.empty()) {
+ function_name = match[5].str();
+ }
+ LM_GGML_ASSERT(!function_name.empty());
+
+ close_tag = "</function>";
+ // Start parsing from after the opening tags
+ auto json_it = match[6].first;
+ json arguments;
+ if (parse_json(json_it, end, arguments)) {
+ msg.tool_calls.emplace_back(process_tool_call({
+ {"name", function_name},
+ {"arguments", arguments},
+ }));
+ it = json_it; // Move iterator past parsed JSON
+
+ // Handle close tags
+ consume_spaces(it, end);
+ if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
+ throw std::runtime_error("Failed to parse closing tag");
+ }
+ consume_spaces(it, end);
+ if (!block_end.empty() && !parse_literal(it, end, block_end)) {
+ throw std::runtime_error("Failed to parse block end");
+ }
+ consume_spaces(it, end);
+ } else {
+ // Not a valid tool call, treat as content
+ msg.content += std::string(match[0].first, match[0].second);
+ it = match[0].second;
+ }
+ }
+ } else {
+ // Add remaining content
+ msg.content += std::string(it, end);
+ break;
  }
- break;
  }
+ return msg;
+ } catch (const std::exception & e) {
+ LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
+ common_chat_msg msg;
+ msg.role = "assistant";
+ msg.content = input;
+ return msg;
  }
- return result;
- } catch (const std::exception & e) {
- return {
- /* .role = */ "assistant",
- /* .content = */ input,
- /* .tool_calls = */ {},
- };
- }
+ });
  }

- static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+ static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
  common_chat_params data;
  data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
  data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
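The parser rewrite above replaces the old three-regex sregex_iterator walk with a single open_regex whose capture groups classify each opening: groups 1-3 cover (optionally fenced and tagged) JSON tool calls, groups 4-6 the <function=...> and <function name="..."> forms. A reduced, runnable version of the same dispatch idea (simplified pattern, not the full one from the diff):

#include <cstdio>
#include <regex>
#include <string>

int main() {
    static const std::regex open_regex(
        "(?:(<tool_call>|<function_call>)?(\\s*\\{\\s*\"name\"\\s*:[\\s\\S]*))"
        "|(?:<function=([^>]+)>([\\s\\S]*))");
    const std::string a = "<tool_call>{\"name\": \"f\", \"arguments\": {}}";
    const std::string b = "<function=f>{\"x\": 1}</function>";
    std::smatch m;
    if (std::regex_search(a, m, open_regex)) {
        printf("tag=%s json=%s...\n", m[1].str().c_str(), m[2].str().substr(0, 8).c_str());
    }
    if (std::regex_search(b, m, open_regex)) {
        printf("function name=%s args+rest=%s\n", m[3].str().c_str(), m[4].str().c_str());
    }
    return 0;
}

In the real parser, parse_json then consumes exactly the JSON prefix of group 3 or 6, and the iterator-based parse_literal/consume_spaces calls verify the matching close tag and code-fence end.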
@@ -973,12 +1575,35 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
  return data;
  }

- common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+ static common_chat_params common_chat_templates_apply_jinja(
+ const struct common_chat_templates * tmpls,
+ const struct common_chat_templates_inputs & inputs)
+ {
+ templates_params params;
+ params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
+ const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
+ ? *tmpls->template_tool_use
+ : *tmpls->template_default;
  const auto & src = tmpl.source();
  const auto & caps = tmpl.original_caps();
+ params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
+ params.add_generation_prompt = inputs.add_generation_prompt;
+ params.extract_reasoning = inputs.extract_reasoning;
+ params.tool_choice = inputs.tool_choice;
+ params.grammar = inputs.grammar;
+ if (!inputs.json_schema.empty()) {
+ params.json_schema = json::parse(inputs.json_schema);
+ }
+
+ if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
+ LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
+ params.parallel_tool_calls = false;
+ } else {
+ params.parallel_tool_calls = inputs.parallel_tool_calls;
+ }

- if (inputs.tools.is_array()) {
- if (inputs.tool_choice != "none" && !inputs.grammar.empty()) {
+ if (params.tools.is_array()) {
+ if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
  throw std::runtime_error("Cannot specify grammar with tools");
  }
  if (caps.supports_tool_calls && !caps.supports_tools) {
@@ -987,68 +1612,135 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
  }

  // DeepSeek R1: use handler in all cases except json schema (thinking / tools).
- if (src.find("<|tool▁calls▁begin|>") != std::string::npos && inputs.json_schema.is_null()) {
- return common_chat_params_init_deepseek_r1(tmpl, inputs);
+ if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) {
+ return common_chat_params_init_deepseek_r1(tmpl, params);
  }

  // Command R7B: : use handler in all cases except json schema (thinking / tools).
- if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) {
- return common_chat_params_init_command_r7b(tmpl, inputs);
+ if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) {
+ return common_chat_params_init_command_r7b(tmpl, params);
+ }
+
+ // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
+ if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
+ return common_chat_params_init_hermes_2_pro(tmpl, params);
  }

  // Use generic handler when mixing tools + JSON schema.
  // TODO: support that mix in handlers below.
- if ((!inputs.tools.is_array() && inputs.json_schema.is_object())) {
- return common_chat_params_init_generic(tmpl, inputs);
+ if ((params.tools.is_array() && params.json_schema.is_object())) {
+ return common_chat_params_init_generic(tmpl, params);
  }

  // Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases.
  if (src.find(">>>all") != std::string::npos) {
- return common_chat_params_init_functionary_v3_2(tmpl, inputs);
+ return common_chat_params_init_functionary_v3_2(tmpl, params);
  }

  // Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases.
  if (src.find(" functools[") != std::string::npos) {
- return common_chat_params_init_firefunction_v2(tmpl, inputs);
+ return common_chat_params_init_firefunction_v2(tmpl, params);
  }

  // Plain handler (no tools)
- if (inputs.tools.is_null() || inputs.tool_choice == "none") {
- return common_chat_params_init_without_tools(tmpl, inputs);
- }
-
- // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
- if (src.find("<tool_call>") != std::string::npos) {
- return common_chat_params_init_hermes_2_pro(tmpl, inputs);
+ if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
+ return common_chat_params_init_without_tools(tmpl, params);
  }

  // Functionary v3.1 (w/ tools)
  if (src.find("<|start_header_id|>") != std::string::npos
  && src.find("<function=") != std::string::npos) {
- return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, inputs);
+ return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
  }

  // Llama 3.1, 3.2, 3.3 (w/ tools)
  if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
  auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
- return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
+ return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
  }

  // Mistral Nemo (w/ tools)
  if (src.find("[TOOL_CALLS]") != std::string::npos) {
- return common_chat_params_init_mistral_nemo(tmpl, inputs);
+ return common_chat_params_init_mistral_nemo(tmpl, params);
  }

  // Generic fallback
- return common_chat_params_init_generic(tmpl, inputs);
+ return common_chat_params_init_generic(tmpl, params);
+ }
+
+ // Legacy template route (adhoc C++ implementation of known templates), forward to llama_chat_apply_template.
+ static common_chat_params common_chat_templates_apply_legacy(
+ const struct common_chat_templates * tmpls,
+ const struct common_chat_templates_inputs & inputs)
+ {
+ int alloc_size = 0;
+ std::vector<llama_chat_message> chat;
+ std::vector<std::string> contents;
+ for (const auto & msg : inputs.messages) {
+ auto content = msg.content;
+ for (const auto & part : msg.content_parts) {
+ if (part.type != "text") {
+ LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
+ continue;
+ }
+ if (!content.empty()) {
+ content += "\n";;
+ }
+ content += part.text;
+ }
+ contents.emplace_back(std::move(content));
+ }
+ for (size_t i = 0; i < contents.size(); ++i) {
+ const auto & msg = inputs.messages[i];
+ const auto & content = contents[i];
+ chat.push_back({msg.role.c_str(), content.c_str()});
+ alloc_size += (msg.role.size() + content.size()) * 1.25;
+ }
+
+ std::vector<char> buf(alloc_size);
+
+ // run the first time to get the total output length
+ const auto & src = tmpls->template_default->source();
+ int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
+
+ // error: chat template is not supported
+ if (res < 0) {
+ // if the custom "tmpl" is not supported, we throw an error
+ // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
+ throw std::runtime_error("this custom template is not supported");
+ }
+
+ // if it turns out that our buffer is too small, we resize it
+ if ((size_t) res > buf.size()) {
+ buf.resize(res);
+ res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
+ }
+
+ common_chat_params params;
+ params.prompt = std::string(buf.data(), res);
+ if (!inputs.json_schema.empty()) {
+ params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema));
+ } else {
+ params.grammar = inputs.grammar;
+ }
+ return params;
+ }
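One note on the legacy path above: llama_chat_apply_template returns the required output length, so the function sizes its buffer optimistically (1.25x the accumulated input) and needs at most one resize-and-retry. The same call-twice pattern in isolation, against a hypothetical formatter stand-in with that contract:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for llama_chat_apply_template's contract: write up to
// `size` bytes and return the total required length (negative on error).
static int32_t format_chat(char * buf, size_t size) {
    static const std::string out = "<|user|>hello<|assistant|>";
    if (buf != nullptr) {
        memcpy(buf, out.data(), std::min(size, out.size()));
    }
    return (int32_t) out.size();
}

int main() {
    std::vector<char> buf(8); // deliberately too small for the first pass
    int32_t res = format_chat(buf.data(), buf.size());
    if (res < 0) {
        throw std::runtime_error("template not supported");
    }
    if ((size_t) res > buf.size()) { // grow to the reported length, run again
        buf.resize(res);
        res = format_chat(buf.data(), buf.size());
    }
    const std::string prompt(buf.data(), res);
    printf("%s\n", prompt.c_str()); // <|user|>hello<|assistant|>
    return 0;
}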
+
+ common_chat_params common_chat_templates_apply(
+ const struct common_chat_templates * tmpls,
+ const struct common_chat_templates_inputs & inputs)
+ {
+ LM_GGML_ASSERT(tmpls != nullptr);
+ return inputs.use_jinja
+ ? common_chat_templates_apply_jinja(tmpls, inputs)
+ : common_chat_templates_apply_legacy(tmpls, inputs);
  }

  static common_chat_msg common_chat_parse_content_only(const std::string & input) {
- return {
- /* .role = */ "assistant",
- /* .content = */ input,
- /* .tool_calls = */ {},
- };
+ common_chat_msg msg;
+ msg.role = "assistant";
+ msg.content = input;
+ return msg;
  }

  common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {
@@ -1072,7 +1764,9 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
  case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
  return common_chat_parse_functionary_v3_1_llama_3_1(input);
  case COMMON_CHAT_FORMAT_HERMES_2_PRO:
- return common_chat_parse_hermes_2_pro(input);
+ return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ false);
+ case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING:
+ return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ true);
  case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
  return common_chat_parse_firefunction_v2(input);
  case COMMON_CHAT_FORMAT_COMMAND_R7B: