cui-llama.rn 1.4.4 → 1.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216)
  1. package/android/src/main/CMakeLists.txt +9 -2
  2. package/android/src/main/jni.cpp +54 -34
  3. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  11. package/cpp/binary-ops.cpp +158 -0
  12. package/cpp/binary-ops.h +16 -0
  13. package/cpp/chat.cpp +1769 -1085
  14. package/cpp/chat.h +143 -0
  15. package/cpp/common.cpp +1562 -1996
  16. package/cpp/common.h +677 -744
  17. package/cpp/cpu-common.h +72 -0
  18. package/cpp/ggml-alloc.c +1039 -1030
  19. package/cpp/ggml-alloc.h +1 -1
  20. package/cpp/ggml-backend-impl.h +255 -255
  21. package/cpp/ggml-backend-reg.cpp +586 -582
  22. package/cpp/ggml-backend.cpp +2004 -2002
  23. package/cpp/ggml-backend.h +354 -354
  24. package/cpp/ggml-common.h +1857 -1851
  25. package/cpp/ggml-cpp.h +39 -39
  26. package/cpp/ggml-cpu-aarch64.cpp +5725 -4247
  27. package/cpp/ggml-cpu-aarch64.h +8 -8
  28. package/cpp/ggml-cpu-impl.h +512 -380
  29. package/cpp/ggml-cpu-quants.c +13026 -11517
  30. package/cpp/ggml-cpu-traits.cpp +36 -36
  31. package/cpp/ggml-cpu-traits.h +38 -38
  32. package/cpp/ggml-cpu.c +3438 -14485
  33. package/cpp/ggml-cpu.cpp +655 -633
  34. package/cpp/ggml-cpu.h +138 -135
  35. package/cpp/ggml-impl.h +594 -567
  36. package/cpp/ggml-metal-impl.h +312 -3
  37. package/cpp/ggml-metal.h +66 -66
  38. package/cpp/ggml-metal.m +5360 -5002
  39. package/cpp/ggml-opt.cpp +854 -854
  40. package/cpp/ggml-opt.h +216 -216
  41. package/cpp/ggml-quants.c +5238 -5238
  42. package/cpp/ggml-threading.h +14 -14
  43. package/cpp/ggml.c +6618 -6524
  44. package/cpp/ggml.h +2222 -2194
  45. package/cpp/gguf.cpp +1330 -1329
  46. package/cpp/gguf.h +202 -202
  47. package/cpp/json-schema-to-grammar.cpp +1024 -1025
  48. package/cpp/json-schema-to-grammar.h +21 -22
  49. package/cpp/json.hpp +24766 -24766
  50. package/cpp/llama-adapter.cpp +382 -347
  51. package/cpp/llama-adapter.h +76 -74
  52. package/cpp/llama-arch.cpp +1714 -1492
  53. package/cpp/llama-arch.h +428 -402
  54. package/cpp/llama-batch.cpp +368 -368
  55. package/cpp/llama-batch.h +88 -88
  56. package/cpp/llama-chat.cpp +640 -587
  57. package/cpp/llama-chat.h +56 -53
  58. package/cpp/llama-context.cpp +2831 -1775
  59. package/cpp/llama-context.h +265 -128
  60. package/cpp/llama-cparams.cpp +1 -1
  61. package/cpp/llama-cparams.h +38 -37
  62. package/cpp/llama-cpp.h +30 -30
  63. package/cpp/llama-grammar.cpp +1219 -1219
  64. package/cpp/llama-grammar.h +173 -164
  65. package/cpp/llama-graph.cpp +1695 -0
  66. package/cpp/llama-graph.h +592 -0
  67. package/cpp/llama-hparams.cpp +79 -71
  68. package/cpp/llama-hparams.h +156 -139
  69. package/cpp/llama-impl.cpp +167 -167
  70. package/cpp/llama-impl.h +61 -61
  71. package/cpp/llama-io.cpp +15 -0
  72. package/cpp/llama-io.h +35 -0
  73. package/cpp/llama-kv-cache.cpp +1380 -718
  74. package/cpp/llama-kv-cache.h +213 -218
  75. package/cpp/llama-memory.cpp +1 -0
  76. package/cpp/llama-memory.h +21 -0
  77. package/cpp/llama-mmap.cpp +600 -590
  78. package/cpp/llama-mmap.h +68 -68
  79. package/cpp/llama-model-loader.cpp +1129 -1124
  80. package/cpp/llama-model-loader.h +169 -167
  81. package/cpp/llama-model.cpp +13080 -4023
  82. package/cpp/llama-model.h +409 -370
  83. package/cpp/llama-sampling.cpp +2563 -2525
  84. package/cpp/llama-sampling.h +32 -32
  85. package/cpp/llama-vocab.cpp +3295 -3252
  86. package/cpp/llama-vocab.h +125 -125
  87. package/cpp/llama.cpp +351 -10137
  88. package/cpp/llama.h +1434 -1340
  89. package/cpp/log.cpp +427 -423
  90. package/cpp/log.h +132 -132
  91. package/cpp/{chat-template.hpp → minja/chat-template.hpp} +537 -529
  92. package/cpp/{minja.hpp → minja/minja.hpp} +2941 -2883
  93. package/cpp/ops.cpp +8723 -0
  94. package/cpp/ops.h +128 -0
  95. package/cpp/rn-llama.cpp +45 -71
  96. package/cpp/rn-llama.h +3 -3
  97. package/cpp/sampling.cpp +573 -532
  98. package/cpp/sgemm.cpp +3043 -2598
  99. package/cpp/sgemm.h +14 -14
  100. package/cpp/simd-mappings.h +888 -0
  101. package/cpp/speculative.cpp +278 -277
  102. package/cpp/speculative.h +28 -28
  103. package/cpp/unary-ops.cpp +186 -0
  104. package/cpp/unary-ops.h +28 -0
  105. package/cpp/vec.cpp +258 -0
  106. package/cpp/vec.h +802 -0
  107. package/ios/CMakeLists.txt +5 -2
  108. package/ios/RNLlama.mm +2 -2
  109. package/ios/RNLlamaContext.mm +40 -24
  110. package/package.json +1 -1
  111. package/src/NativeRNLlama.ts +6 -4
  112. package/src/index.ts +3 -1
  113. package/android/src/main/build-arm64/CMakeCache.txt +0 -429
  114. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCCompiler.cmake +0 -81
  115. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCXXCompiler.cmake +0 -101
  116. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_C.bin +0 -0
  117. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_CXX.bin +0 -0
  118. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeSystem.cmake +0 -15
  119. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.c +0 -904
  120. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.o +0 -0
  121. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.cpp +0 -919
  122. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.o +0 -0
  123. package/android/src/main/build-arm64/CMakeFiles/CMakeConfigureLog.yaml +0 -431
  124. package/android/src/main/build-arm64/CMakeFiles/CMakeDirectoryInformation.cmake +0 -16
  125. package/android/src/main/build-arm64/CMakeFiles/Makefile.cmake +0 -165
  126. package/android/src/main/build-arm64/CMakeFiles/Makefile2 +0 -297
  127. package/android/src/main/build-arm64/CMakeFiles/Progress/1 +0 -1
  128. package/android/src/main/build-arm64/CMakeFiles/Progress/2 +0 -1
  129. package/android/src/main/build-arm64/CMakeFiles/Progress/3 +0 -1
  130. package/android/src/main/build-arm64/CMakeFiles/Progress/4 +0 -1
  131. package/android/src/main/build-arm64/CMakeFiles/Progress/5 +0 -1
  132. package/android/src/main/build-arm64/CMakeFiles/Progress/6 +0 -1
  133. package/android/src/main/build-arm64/CMakeFiles/Progress/count.txt +0 -1
  134. package/android/src/main/build-arm64/CMakeFiles/TargetDirectories.txt +0 -8
  135. package/android/src/main/build-arm64/CMakeFiles/cmake.check_cache +0 -1
  136. package/android/src/main/build-arm64/CMakeFiles/progress.marks +0 -1
  137. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o +0 -0
  138. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o.d +0 -58
  139. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o +0 -0
  140. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o.d +0 -756
  141. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o +0 -0
  142. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o.d +0 -709
  143. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o +0 -0
  144. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o.d +0 -714
  145. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o +0 -0
  146. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o.d +0 -62
  147. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o +0 -0
  148. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o.d +0 -708
  149. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o +0 -0
  150. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o.d +0 -113
  151. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o +0 -0
  152. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o.d +0 -713
  153. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o +0 -0
  154. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o.d +0 -763
  155. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o +0 -0
  156. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o.d +0 -61
  157. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o +0 -0
  158. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o.d +0 -707
  159. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o +0 -0
  160. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o.d +0 -104
  161. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o +0 -0
  162. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o.d +0 -714
  163. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o +0 -0
  164. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o.d +0 -723
  165. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/DependInfo.cmake +0 -62
  166. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/build.make +0 -722
  167. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/cmake_clean.cmake +0 -89
  168. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.make +0 -2
  169. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.ts +0 -2
  170. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/depend.make +0 -2
  171. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/flags.make +0 -17
  172. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/progress.make +0 -41
  173. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/DependInfo.cmake +0 -62
  174. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/build.make +0 -722
  175. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/cmake_clean.cmake +0 -89
  176. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.make +0 -2
  177. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.ts +0 -2
  178. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/depend.make +0 -2
  179. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/flags.make +0 -17
  180. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/progress.make +0 -41
  181. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/DependInfo.cmake +0 -62
  182. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/build.make +0 -722
  183. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/cmake_clean.cmake +0 -89
  184. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.make +0 -2
  185. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.ts +0 -2
  186. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/depend.make +0 -2
  187. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/flags.make +0 -17
  188. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/progress.make +0 -41
  189. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/DependInfo.cmake +0 -62
  190. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/build.make +0 -722
  191. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/cmake_clean.cmake +0 -89
  192. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.make +0 -2
  193. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.ts +0 -2
  194. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/depend.make +0 -2
  195. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/flags.make +0 -17
  196. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/progress.make +0 -41
  197. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/DependInfo.cmake +0 -62
  198. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/build.make +0 -722
  199. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/cmake_clean.cmake +0 -89
  200. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.make +0 -2
  201. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.ts +0 -2
  202. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/depend.make +0 -2
  203. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/flags.make +0 -17
  204. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/progress.make +0 -41
  205. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/DependInfo.cmake +0 -62
  206. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/build.make +0 -722
  207. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/cmake_clean.cmake +0 -89
  208. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.make +0 -2
  209. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.ts +0 -2
  210. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/depend.make +0 -2
  211. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/flags.make +0 -17
  212. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/progress.make +0 -41
  213. package/android/src/main/build-arm64/Makefile +0 -1862
  214. package/android/src/main/build-arm64/cmake_install.cmake +0 -66
  215. package/cpp/chat.hpp +0 -55
  216. package/cpp/rn-llama.hpp +0 -913
@@ -1,1025 +1,1024 @@
1
- #include "json-schema-to-grammar.h"
2
- #include "common.h"
3
-
4
- #include <algorithm>
5
- #include <fstream>
6
- #include <map>
7
- #include <regex>
8
- #include <sstream>
9
- #include <string>
10
- #include <unordered_map>
11
- #include <unordered_set>
12
- #include <vector>
13
-
14
- using json = nlohmann::ordered_json;
15
-
16
// Builds a grammar expression that repeats `item_rule` between `min_items` and
// `max_items` times (inclusive).
//
//   item_rule       expression to repeat
//   min_items       minimum number of repetitions
//   max_items       maximum number of repetitions; std::numeric_limits<int>::max()
//                   means "no upper bound"
//   separator_rule  optional expression placed between consecutive items
//
// Returns the expression text. With a separator and max_items <= 0 the result
// is empty (zero items allowed), instead of the invalid "{0,-1}" quantifier the
// recursion would otherwise produce.
static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
    const bool has_max = max_items != std::numeric_limits<int>::max();

    if (min_items == 0 && max_items == 1) {
        return item_rule + "?";
    }

    if (separator_rule.empty()) {
        if (min_items == 1 && !has_max) {
            return item_rule + "+";
        }
        if (min_items == 0 && !has_max) {
            return item_rule + "*";
        }
        return item_rule + "{" + std::to_string(min_items) + "," + (has_max ? std::to_string(max_items) : "") + "}";
    }

    // Separated form is "item (sep item){min-1,max-1}"; it needs at least one
    // item slot, so a non-positive maximum has no valid expansion.
    if (max_items <= 0) {
        return "";
    }

    const int rest_min = min_items == 0 ? 0 : min_items - 1;
    const int rest_max = has_max ? max_items - 1 : max_items;
    std::string result = item_rule + " " + build_repetition("(" + separator_rule + " " + item_rule + ")", rest_min, rest_max);
    if (min_items == 0) {
        // Zero items is allowed, so the whole "first item + rest" group is optional.
        result = "(" + result + ")?";
    }
    return result;
}
39
-
40
/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards.
 *
 * The view is a [start, end) window over a std::string held by reference; it
 * does not own the characters, so the referenced string must outlive the view.
 */
class string_view {
    const std::string & _str;
    const size_t _start;
    const size_t _end;
public:
    // `end == std::string::npos` means "up to the end of `str`".
    string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {}

    size_t size() const {
        return _end - _start;
    }

    size_t length() const {
        return size();
    }

    operator std::string() const {
        return str();
    }

    // Copies the viewed characters into a new std::string.
    std::string str() const {
        return _str.substr(_start, _end - _start);
    }

    // Sub-view starting `pos` characters into this view and spanning at most
    // `len` characters. Both values are clamped to this view's bounds so the
    // result can never extend past the parent view.
    string_view substr(size_t pos, size_t len = std::string::npos) const {
        const size_t begin = _start + std::min(pos, size());
        const size_t avail = _end - begin;
        return string_view(_str, begin, begin + std::min(len, avail));
    }

    // Bounds-checked element access; throws std::out_of_range past the end.
    char operator[](size_t pos) const {
        const size_t index = _start + pos;
        if (index >= _end) {
            throw std::out_of_range("string_view index out of range");
        }
        return _str[index];
    }

    // Content comparison done in place, without materialising temporary strings.
    bool operator==(const string_view & other) const {
        return _str.compare(_start, size(), other._str, other._start, other.size()) == 0;
    }
};
82
-
83
// Writes to `out` a grammar expression for the decimal text of integers between
// min_value and max_value.
//   - min_value == std::numeric_limits<int>::min() means "no lower bound";
//     max_value == std::numeric_limits<int>::max() means "no upper bound".
//   - decimals_left is passed to more_digits() as the maximum count of extra
//     [0-9] digits in the open-ended branches.
//   - top_level selects whether the first digit of a multi-digit number starts
//     at '1' (true) or '0' (false); recursive calls for digit tails pass false.
// Throws std::runtime_error when neither bound is set.
- static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
84
- auto has_min = min_value != std::numeric_limits<int>::min();
85
- auto has_max = max_value != std::numeric_limits<int>::max();
86
-
87
// Emits a one-character class: "[c]" when from == to, otherwise "[from-to]".
- auto digit_range = [&](char from, char to) {
88
- out << "[";
89
- if (from == to) {
90
- out << from;
91
- } else {
92
- out << from << "-" << to;
93
- }
94
- out << "]";
95
- };
96
// Emits "[0-9]" followed by a "{min,max}" quantifier. The quantifier is omitted
// for exactly one digit, and the upper bound is left open when max_digits is
// std::numeric_limits<int>::max().
- auto more_digits = [&](int min_digits, int max_digits) {
97
- out << "[0-9]";
98
- if (min_digits == max_digits && min_digits == 1) {
99
- return;
100
- }
101
- out << "{";
102
- out << min_digits;
103
- if (max_digits != min_digits) {
104
- out << ",";
105
- if (max_digits != std::numeric_limits<int>::max()) {
106
- out << max_digits;
107
- }
108
- }
109
- out << "}";
110
- };
111
// Emits alternatives covering the digit strings between `from` and `to`.
// NOTE(review): appears to assume both strings have the same length and
// from <= to; every call site in this function passes equal-length strings.
- std::function<void(const string_view &, const string_view &)> uniform_range =
112
- [&](const string_view & from, const string_view & to) {
113
// i = length of the common prefix, which is emitted once as a quoted literal.
- size_t i = 0;
114
- while (i < from.length() && i < to.length() && from[i] == to[i]) {
115
- i++;
116
- }
117
- if (i > 0) {
118
- out << "\"" << from.substr(0, i).str() << "\"";
119
- }
120
- if (i < from.length() && i < to.length()) {
121
- if (i > 0) {
122
- out << " ";
123
- }
124
// sub_len = number of digits that follow the first differing position.
- auto sub_len = from.length() - i - 1;
125
- if (sub_len > 0) {
126
- auto from_sub = from.substr(i + 1);
127
- auto to_sub = to.substr(i + 1);
128
- auto sub_zeros = string_repeat("0", sub_len);
129
- auto sub_nines = string_repeat("9", sub_len);
130
-
131
// to_reached records that the alternatives already emitted include the
// branch whose leading digit is to[i].
- auto to_reached = false;
132
- out << "(";
133
- if (from_sub == sub_zeros) {
134
- digit_range(from[i], to[i] - 1);
135
- out << " ";
136
- more_digits(sub_len, sub_len);
137
- } else {
138
- out << "[" << from[i] << "] ";
139
- out << "(";
140
- uniform_range(from_sub, sub_nines);
141
- out << ")";
142
- if (from[i] < to[i] - 1) {
143
- out << " | ";
144
- if (to_sub == sub_nines) {
145
- digit_range(from[i] + 1, to[i]);
146
- to_reached = true;
147
- } else {
148
- digit_range(from[i] + 1, to[i] - 1);
149
- }
150
- out << " ";
151
- more_digits(sub_len, sub_len);
152
- }
153
- }
154
- if (!to_reached) {
155
- out << " | ";
156
- digit_range(to[i], to[i]);
157
- out << " ";
158
- uniform_range(sub_zeros, to_sub);
159
- }
160
- out << ")";
161
- } else {
162
- out << "[" << from[i] << "-" << to[i] << "]";
163
- }
164
- }
165
- };
166
-
167
// Both bounds set: negative parts are emitted as "-" followed by the mirrored
// non-negative range, then the non-negative part is split by digit count.
- if (has_min && has_max) {
168
- if (min_value < 0 && max_value < 0) {
169
- out << "\"-\" (";
170
- _build_min_max_int(-max_value, -min_value, out, decimals_left, /* top_level= */ true);
171
- out << ")";
172
- return;
173
- }
174
-
175
- if (min_value < 0) {
176
- out << "\"-\" (";
177
- _build_min_max_int(0, -min_value, out, decimals_left, /* top_level= */ true);
178
- out << ") | ";
179
- min_value = 0;
180
- }
181
-
182
- auto min_s = std::to_string(min_value);
183
- auto max_s = std::to_string(max_value);
184
- auto min_digits = min_s.length();
185
- auto max_digits = max_s.length();
186
-
187
// For each digit count shorter than max_s: emit min_s up to all nines, then
// restart min_s at "1" followed by `digits` zeros (the next digit count).
- for (auto digits = min_digits; digits < max_digits; digits++) {
188
- uniform_range(min_s, string_repeat("9", digits));
189
- min_s = "1" + string_repeat("0", digits);
190
- out << " | ";
191
- }
192
- uniform_range(min_s, max_s);
193
- return;
194
- }
195
-
196
// Digit budget handed to more_digits() and to recursive calls; never below 1.
- auto less_decimals = std::max(decimals_left - 1, 1);
197
-
198
// Only a lower bound is set.
- if (has_min) {
199
- if (min_value < 0) {
200
- out << "\"-\" (";
201
- _build_min_max_int(std::numeric_limits<int>::min(), -min_value, out, decimals_left, /* top_level= */ false);
202
- out << ") | [0] | [1-9] ";
203
- more_digits(0, decimals_left - 1);
204
- } else if (min_value == 0) {
205
- if (top_level) {
206
- out << "[0] | [1-9] ";
207
- more_digits(0, less_decimals);
208
- } else {
209
- more_digits(1, decimals_left);
210
- }
211
- } else if (min_value <= 9) {
212
- char c = '0' + min_value;
213
- auto range_start = top_level ? '1' : '0';
214
- if (c > range_start) {
215
- digit_range(range_start, c - 1);
216
- out << " ";
217
- more_digits(1, less_decimals);
218
- out << " | ";
219
- }
220
- digit_range(c, '9');
221
- out << " ";
222
- more_digits(0, less_decimals);
223
- } else {
224
- auto min_s = std::to_string(min_value);
225
- auto len = min_s.length();
226
- auto c = min_s[0];
227
-
228
- if (c > '1') {
229
- digit_range(top_level ? '1' : '0', c - 1);
230
- out << " ";
231
- more_digits(len, less_decimals);
232
- out << " | ";
233
- }
234
- digit_range(c, c);
235
- out << " (";
// NOTE(review): std::stoi drops leading zeros of the remaining digits (e.g. a
// minimum of 105 recurses with 5), so the length of the tail is not carried
// into the recursive call - confirm such bounds produce the intended grammar.
- _build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits<int>::max(), out, less_decimals, /* top_level= */ false);
236
- out << ")";
237
- if (c < '9') {
238
- out << " | ";
239
- digit_range(c + 1, '9');
240
- out << " ";
241
- more_digits(len - 1, less_decimals);
242
- }
243
- }
244
- return;
245
- }
246
-
247
// Only an upper bound is set.
- if (has_max) {
248
- if (max_value >= 0) {
249
- if (top_level) {
250
- out << "\"-\" [1-9] ";
251
- more_digits(0, less_decimals);
252
- out << " | ";
253
- }
254
- _build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true);
255
- } else {
256
- out << "\"-\" (";
257
- _build_min_max_int(-max_value, std::numeric_limits<int>::max(), out, decimals_left, /* top_level= */ false);
258
- out << ")";
259
- }
260
- return;
261
- }
262
-
263
// Reached only when neither bound is set.
- throw std::runtime_error("At least one of min_value or max_value must be set");
264
- }
266
-
267
// Grammar body with three alternatives: nothing, a single space, or a newline
// followed by up to 20 spaces/tabs.
// NOTE(review): presumably the body of the `space` rule referenced by the
// built-in rules - its use is outside this view.
- const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}";
268
-
269
// One predefined grammar rule: `content` is the rule body text and `deps`
// lists the names of other rules that the body refers to.
- struct BuiltinRule {
270
- std::string content;
271
- std::vector<std::string> deps;
272
- };
273
-
274
// Built-in rules for JSON primitives and containers, keyed by rule name. Each
// entry pairs the rule body with the names of the rules that body references.
// Numeric parts are length-limited here: integral-part allows at most 16 digits
// and decimal-part 1 to 16 digits.
- std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
275
- {"boolean", {"(\"true\" | \"false\") space", {}}},
276
- {"decimal-part", {"[0-9]{1,16}", {}}},
277
- {"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}},
278
- {"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)? space", {"integral-part", "decimal-part"}}},
279
- {"integer", {"(\"-\"? integral-part) space", {"integral-part"}}},
280
- {"value", {"object | array | string | number | boolean | null", {"object", "array", "string", "number", "boolean", "null"}}},
281
- {"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space", {"string", "value"}}},
282
- {"array", {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}},
283
- {"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space", {}}},
284
- {"char", {"[^\"\\\\\\x7F\\x00-\\x1F] | [\\\\] ([\"\\\\bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}},
285
- {"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}},
286
- {"null", {"\"null\" space", {}}},
287
- };
288
-
289
// Built-in rules for date/time text, keyed by rule name. `date`, `time` and
// `date-time` are the bare forms; the `*-string` entries wrap the bare form in
// double quotes and append `space`.
- std::unordered_map<std::string, BuiltinRule> STRING_FORMAT_RULES = {
290
- {"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}},
291
- {"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}},
292
- {"date-time", {"date \"T\" time", {"date", "time"}}},
293
- {"date-string", {"\"\\\"\" date \"\\\"\" space", {"date"}}},
294
- {"time-string", {"\"\\\"\" time \"\\\"\" space", {"time"}}},
295
- {"date-time-string", {"\"\\\"\" date-time \"\\\"\" space", {"date-time"}}}
296
- };
297
-
298
- static bool is_reserved_name(const std::string & name) {
299
- static std::unordered_set<std::string> RESERVED_NAMES;
300
- if (RESERVED_NAMES.empty()) {
301
- RESERVED_NAMES.insert("root");
302
- for (const auto &p : PRIMITIVE_RULES) RESERVED_NAMES.insert(p.first);
303
- for (const auto &p : STRING_FORMAT_RULES) RESERVED_NAMES.insert(p.first);
304
- }
305
- return RESERVED_NAMES.find(name) != RESERVED_NAMES.end();
306
- }
307
-
308
// Matches runs of characters other than ASCII letters, digits and '-'.
std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");
// Characters that must be escaped inside a quoted grammar literal. Backslash is
// included: left unescaped, it would be read as the start of an escape sequence
// and corrupt the literal.
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]");
// Characters that must be escaped inside a character range: the literal set
// plus ']' and '-'.
std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");
// Escaped spelling for every character matched by either regex above, so that
// an `.at()` lookup on a matched character cannot throw.
std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
    {'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}, {'\\', "\\\\"}
};

// Pattern metacharacters.
// NOTE(review): judging by the name, these end a run of literal characters while
// a pattern is scanned - the full usage is outside this view.
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
// Characters that regular expressions escape with a backslash but that need no
// escaping inside a quoted grammar literal (per the name; usage not in view).
std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};
317
-
318
// Returns a copy of `input` in which every match of `regex` is replaced by the
// string that `replacement` returns for that match. Text between matches is
// copied unchanged.
//
// A zero-length match (possible with patterns such as "x*") still gets its
// replacement; the scan then copies one input character and moves on. Without
// that step the search position would never advance and the loop would not end.
static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) {
    std::smatch match;
    std::string result;

    std::string::const_iterator cursor = input.cbegin();
    const std::string::const_iterator end = input.cend();

    while (std::regex_search(cursor, end, match, regex)) {
        result.append(cursor, match[0].first);   // text before the match
        result.append(replacement(match));
        cursor = match[0].second;                 // resume right after the match

        if (match.length(0) == 0) {
            if (cursor == end) {
                break;
            }
            result.push_back(*cursor);
            ++cursor;
        }
    }

    result.append(cursor, end);

    return result;
}
335
-
336
- static std::string format_literal(const std::string & literal) {
337
- std::string escaped = replacePattern(literal, GRAMMAR_LITERAL_ESCAPE_RE, [&](const std::smatch & match) {
338
- char c = match.str()[0];
339
- return GRAMMAR_LITERAL_ESCAPES.at(c);
340
- });
341
- return "\"" + escaped + "\"";
342
- }
343
-
344
- class SchemaConverter {
345
- private:
346
- friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
347
- std::function<json(const std::string &)> _fetch_json;
348
- bool _dotall;
349
- std::map<std::string, std::string> _rules;
350
- std::unordered_map<std::string, json> _refs;
351
- std::unordered_set<std::string> _refs_being_resolved;
352
- std::vector<std::string> _errors;
353
- std::vector<std::string> _warnings;
354
-
355
- std::string _add_rule(const std::string & name, const std::string & rule) {
356
- std::string esc_name = regex_replace(name, INVALID_RULE_CHARS_RE, "-");
357
- if (_rules.find(esc_name) == _rules.end() || _rules[esc_name] == rule) {
358
- _rules[esc_name] = rule;
359
- return esc_name;
360
- } else {
361
- int i = 0;
362
- while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) {
363
- i++;
364
- }
365
- std::string key = esc_name + std::to_string(i);
366
- _rules[key] = rule;
367
- return key;
368
- }
369
- }
370
-
371
- std::string _generate_union_rule(const std::string & name, const std::vector<json> & alt_schemas) {
372
- std::vector<std::string> rules;
373
- for (size_t i = 0; i < alt_schemas.size(); i++) {
374
- rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i)));
375
- }
376
- return string_join(rules, " | ");
377
- }
378
-
379
- std::string _visit_pattern(const std::string & pattern, const std::string & name) {
380
- if (!(pattern.front() == '^' && pattern.back() == '$')) {
381
- _errors.push_back("Pattern must start with '^' and end with '$'");
382
- return "";
383
- }
384
- std::string sub_pattern = pattern.substr(1, pattern.length() - 2);
385
- std::unordered_map<std::string, std::string> sub_rule_ids;
386
-
387
- size_t i = 0;
388
- size_t length = sub_pattern.length();
389
-
390
- using literal_or_rule = std::pair<std::string, bool>;
391
- auto to_rule = [&](const literal_or_rule & ls) {
392
- auto is_literal = ls.second;
393
- auto s = ls.first;
394
- return is_literal ? "\"" + s + "\"" : s;
395
- };
396
- std::function<literal_or_rule()> transform = [&]() -> literal_or_rule {
397
- size_t start = i;
398
- std::vector<literal_or_rule> seq;
399
-
400
- auto get_dot = [&]() {
401
- std::string rule;
402
- if (_dotall) {
403
- rule = "[\\U00000000-\\U0010FFFF]";
404
- } else {
405
- rule = "[^\\x0A\\x0D]";
406
- }
407
- return _add_rule("dot", rule);
408
- };
409
-
410
- // Joins the sequence, merging consecutive literals together.
411
- auto join_seq = [&]() {
412
- std::vector<literal_or_rule> ret;
413
-
414
- std::string literal;
415
- auto flush_literal = [&]() {
416
- if (literal.empty()) {
417
- return false;
418
- }
419
- ret.emplace_back(literal, true);
420
- literal.clear();
421
- return true;
422
- };
423
-
424
- for (const auto & item : seq) {
425
- auto is_literal = item.second;
426
- if (is_literal) {
427
- literal += item.first;
428
- } else {
429
- flush_literal();
430
- ret.push_back(item);
431
- }
432
- }
433
- flush_literal();
434
-
435
- std::vector<std::string> results;
436
- for (const auto & item : ret) {
437
- results.push_back(to_rule(item));
438
- }
439
- return std::make_pair(string_join(results, " "), false);
440
- };
441
-
442
- while (i < length) {
443
- char c = sub_pattern[i];
444
- if (c == '.') {
445
- seq.emplace_back(get_dot(), false);
446
- i++;
447
- } else if (c == '(') {
448
- i++;
449
- if (i < length) {
450
- if (sub_pattern[i] == '?') {
451
- _warnings.push_back("Unsupported pattern syntax");
452
- }
453
- }
454
- seq.emplace_back("(" + to_rule(transform()) + ")", false);
455
- } else if (c == ')') {
456
- i++;
457
- if (start > 0 && sub_pattern[start - 1] != '(') {
458
- _errors.push_back("Unbalanced parentheses");
459
- }
460
- return join_seq();
461
- } else if (c == '[') {
462
- std::string square_brackets = std::string(1, c);
463
- i++;
464
- while (i < length && sub_pattern[i] != ']') {
465
- if (sub_pattern[i] == '\\') {
466
- square_brackets += sub_pattern.substr(i, 2);
467
- i += 2;
468
- } else {
469
- square_brackets += sub_pattern[i];
470
- i++;
471
- }
472
- }
473
- if (i >= length) {
474
- _errors.push_back("Unbalanced square brackets");
475
- }
476
- square_brackets += ']';
477
- i++;
478
- seq.emplace_back(square_brackets, false);
479
- } else if (c == '|') {
480
- seq.emplace_back("|", false);
481
- i++;
482
- } else if (c == '*' || c == '+' || c == '?') {
483
- seq.back() = std::make_pair(to_rule(seq.back()) + c, false);
484
- i++;
485
- } else if (c == '{') {
486
- std::string curly_brackets = std::string(1, c);
487
- i++;
488
- while (i < length && sub_pattern[i] != '}') {
489
- curly_brackets += sub_pattern[i];
490
- i++;
491
- }
492
- if (i >= length) {
493
- _errors.push_back("Unbalanced curly brackets");
494
- }
495
- curly_brackets += '}';
496
- i++;
497
- auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
498
- int min_times = 0;
499
- int max_times = std::numeric_limits<int>::max();
500
- try {
501
- if (nums.size() == 1) {
502
- min_times = max_times = std::stoi(nums[0]);
503
- } else if (nums.size() != 2) {
504
- _errors.push_back("Wrong number of values in curly brackets");
505
- } else {
506
- if (!nums[0].empty()) {
507
- min_times = std::stoi(nums[0]);
508
- }
509
- if (!nums[1].empty()) {
510
- max_times = std::stoi(nums[1]);
511
- }
512
- }
513
- } catch (const std::invalid_argument & e) {
514
- _errors.push_back("Invalid number in curly brackets");
515
- return std::make_pair("", false);
516
- }
517
- auto &last = seq.back();
518
- auto &sub = last.first;
519
- auto sub_is_literal = last.second;
520
-
521
- if (!sub_is_literal) {
522
- std::string & sub_id = sub_rule_ids[sub];
523
- if (sub_id.empty()) {
524
- sub_id = _add_rule(name + "-" + std::to_string(sub_rule_ids.size()), sub);
525
- }
526
- sub = sub_id;
527
- }
528
- seq.back().first = build_repetition(
529
- sub_is_literal ? "\"" + sub + "\"" : sub,
530
- min_times,
531
- max_times,
532
- ""
533
- );
534
- seq.back().second = false;
535
- } else {
536
- std::string literal;
537
- auto is_non_literal = [&](char c) {
538
- return NON_LITERAL_SET.find(c) != NON_LITERAL_SET.end();
539
- };
540
- while (i < length) {
541
- if (sub_pattern[i] == '\\' && i < length - 1) {
542
- char next = sub_pattern[i + 1];
543
- if (ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.find(next) != ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.end()) {
544
- i++;
545
- literal += sub_pattern[i];
546
- i++;
547
- } else {
548
- literal += sub_pattern.substr(i, 2);
549
- i += 2;
550
- }
551
- } else if (sub_pattern[i] == '"') {
552
- literal += "\\\"";
553
- i++;
554
- } else if (!is_non_literal(sub_pattern[i]) &&
555
- (i == length - 1 || literal.empty() || sub_pattern[i + 1] == '.' || !is_non_literal(sub_pattern[i + 1]))) {
556
- literal += sub_pattern[i];
557
- i++;
558
- } else {
559
- break;
560
- }
561
- }
562
- if (!literal.empty()) {
563
- seq.emplace_back(literal, true);
564
- }
565
- }
566
- }
567
- return join_seq();
568
- };
569
- return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space");
570
- }
571
-
572
- /*
573
- Returns a rule that matches a JSON string that is none of the provided strings
574
-
575
- not_strings({"a"})
576
- -> ["] ( [a] char+ | [^"a] char* )? ["] space
577
- not_strings({"and", "also"})
578
- -> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space
579
- */
580
- std::string _not_strings(const std::vector<std::string> & strings) {
581
-
582
- struct TrieNode {
583
- std::map<char, TrieNode> children;
584
- bool is_end_of_string;
585
-
586
- TrieNode() : is_end_of_string(false) {}
587
-
588
- void insert(const std::string & string) {
589
- auto node = this;
590
- for (char c : string) {
591
- node = &node->children[c];
592
- }
593
- node->is_end_of_string = true;
594
- }
595
- };
596
-
597
- TrieNode trie;
598
- for (const auto & s : strings) {
599
- trie.insert(s);
600
- }
601
-
602
- std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
603
- std::ostringstream out;
604
- out << "[\"] ( ";
605
- std::function<void(const TrieNode &)> visit = [&](const TrieNode & node) {
606
- std::ostringstream rejects;
607
- auto first = true;
608
- for (const auto & kv : node.children) {
609
- rejects << kv.first;
610
- if (first) {
611
- first = false;
612
- } else {
613
- out << " | ";
614
- }
615
- out << "[" << kv.first << "]";
616
- if (!kv.second.children.empty()) {
617
- out << " (";
618
- visit(kv.second);
619
- out << ")";
620
- } else if (kv.second.is_end_of_string) {
621
- out << " " << char_rule << "+";
622
- }
623
- }
624
- if (!node.children.empty()) {
625
- if (!first) {
626
- out << " | ";
627
- }
628
- out << "[^\"" << rejects.str() << "] " << char_rule << "*";
629
- }
630
- };
631
- visit(trie);
632
-
633
- out << " )";
634
- if (!trie.is_end_of_string) {
635
- out << "?";
636
- }
637
- out << " [\"] space";
638
- return out.str();
639
- }
640
-
641
- std::string _resolve_ref(const std::string & ref) {
642
- std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
643
- if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
644
- _refs_being_resolved.insert(ref);
645
- json resolved = _refs[ref];
646
- ref_name = visit(resolved, ref_name);
647
- _refs_being_resolved.erase(ref);
648
- }
649
- return ref_name;
650
- }
651
-
652
- std::string _build_object_rule(
653
- const std::vector<std::pair<std::string, json>> & properties,
654
- const std::unordered_set<std::string> & required,
655
- const std::string & name,
656
- const json & additional_properties)
657
- {
658
- std::vector<std::string> required_props;
659
- std::vector<std::string> optional_props;
660
- std::unordered_map<std::string, std::string> prop_kv_rule_names;
661
- std::vector<std::string> prop_names;
662
- for (const auto & kv : properties) {
663
- const auto &prop_name = kv.first;
664
- const auto &prop_schema = kv.second;
665
-
666
- std::string prop_rule_name = visit(prop_schema, name + (name.empty() ? "" : "-") + prop_name);
667
- prop_kv_rule_names[prop_name] = _add_rule(
668
- name + (name.empty() ? "" : "-") + prop_name + "-kv",
669
- format_literal(json(prop_name).dump()) + " space \":\" space " + prop_rule_name
670
- );
671
- if (required.find(prop_name) != required.end()) {
672
- required_props.push_back(prop_name);
673
- } else {
674
- optional_props.push_back(prop_name);
675
- }
676
- prop_names.push_back(prop_name);
677
- }
678
- if ((additional_properties.is_boolean() && additional_properties.get<bool>()) || additional_properties.is_object()) {
679
- std::string sub_name = name + (name.empty() ? "" : "-") + "additional";
680
- std::string value_rule =
681
- additional_properties.is_object() ? visit(additional_properties, sub_name + "-value")
682
- : _add_primitive("value", PRIMITIVE_RULES.at("value"));
683
-
684
- auto key_rule =
685
- prop_names.empty() ? _add_primitive("string", PRIMITIVE_RULES.at("string"))
686
- : _add_rule(sub_name + "-k", _not_strings(prop_names));
687
- std::string kv_rule = _add_rule(sub_name + "-kv", key_rule + " \":\" space " + value_rule);
688
- prop_kv_rule_names["*"] = kv_rule;
689
- optional_props.push_back("*");
690
- }
691
-
692
- std::string rule = "\"{\" space ";
693
- for (size_t i = 0; i < required_props.size(); i++) {
694
- if (i > 0) {
695
- rule += " \",\" space ";
696
- }
697
- rule += prop_kv_rule_names[required_props[i]];
698
- }
699
-
700
- if (!optional_props.empty()) {
701
- rule += " (";
702
- if (!required_props.empty()) {
703
- rule += " \",\" space ( ";
704
- }
705
-
706
- std::function<std::string(const std::vector<std::string> &, bool)> get_recursive_refs = [&](const std::vector<std::string> & ks, bool first_is_optional) {
707
- std::string res;
708
- if (ks.empty()) {
709
- return res;
710
- }
711
- std::string k = ks[0];
712
- std::string kv_rule_name = prop_kv_rule_names[k];
713
- std::string comma_ref = "( \",\" space " + kv_rule_name + " )";
714
- if (first_is_optional) {
715
- res = comma_ref + (k == "*" ? "*" : "?");
716
- } else {
717
- res = kv_rule_name + (k == "*" ? " " + comma_ref + "*" : "");
718
- }
719
- if (ks.size() > 1) {
720
- res += " " + _add_rule(
721
- name + (name.empty() ? "" : "-") + k + "-rest",
722
- get_recursive_refs(std::vector<std::string>(ks.begin() + 1, ks.end()), true)
723
- );
724
- }
725
- return res;
726
- };
727
-
728
- for (size_t i = 0; i < optional_props.size(); i++) {
729
- if (i > 0) {
730
- rule += " | ";
731
- }
732
- rule += get_recursive_refs(std::vector<std::string>(optional_props.begin() + i, optional_props.end()), false);
733
- }
734
- if (!required_props.empty()) {
735
- rule += " )";
736
- }
737
- rule += " )?";
738
- }
739
-
740
- rule += " \"}\" space";
741
-
742
- return rule;
743
- }
744
-
745
- std::string _add_primitive(const std::string & name, const BuiltinRule & rule) {
746
- auto n = _add_rule(name, rule.content);
747
- for (const auto & dep : rule.deps) {
748
- BuiltinRule dep_rule;
749
- auto it = PRIMITIVE_RULES.find(dep);
750
- if (it == PRIMITIVE_RULES.end()) {
751
- it = STRING_FORMAT_RULES.find(dep);
752
- if (it == STRING_FORMAT_RULES.end()) {
753
- _errors.push_back("Rule " + dep + " not known");
754
- continue;
755
- }
756
- }
757
- if (_rules.find(dep) == _rules.end()) {
758
- _add_primitive(dep, it->second);
759
- }
760
- }
761
- return n;
762
- }
763
-
764
- public:
765
- SchemaConverter(
766
- const std::function<json(const std::string &)> & fetch_json,
767
- bool dotall,
768
- bool compact_spaces)
769
- : _fetch_json(fetch_json), _dotall(dotall)
770
- {
771
- _rules["space"] = compact_spaces ? "\" \"?" : SPACE_RULE;
772
- }
773
-
774
- void resolve_refs(json & schema, const std::string & url) {
775
- /*
776
- * Resolves all $ref fields in the given schema, fetching any remote schemas,
777
- * replacing each $ref with absolute reference URL and populates _refs with the
778
- * respective referenced (sub)schema dictionaries.
779
- */
780
- std::function<void(json &)> visit_refs = [&](json & n) {
781
- if (n.is_array()) {
782
- for (auto & x : n) {
783
- visit_refs(x);
784
- }
785
- } else if (n.is_object()) {
786
- if (n.contains("$ref")) {
787
- std::string ref = n["$ref"];
788
- if (_refs.find(ref) == _refs.end()) {
789
- json target;
790
- if (ref.find("https://") == 0) {
791
- std::string base_url = ref.substr(0, ref.find('#'));
792
- auto it = _refs.find(base_url);
793
- if (it != _refs.end()) {
794
- target = it->second;
795
- } else {
796
- // Fetch the referenced schema and resolve its refs
797
- auto referenced = _fetch_json(ref);
798
- resolve_refs(referenced, base_url);
799
- _refs[base_url] = referenced;
800
- }
801
- if (ref.find('#') == std::string::npos || ref.substr(ref.find('#') + 1).empty()) {
802
- return;
803
- }
804
- } else if (ref.find("#/") == 0) {
805
- target = schema;
806
- n["$ref"] = url + ref;
807
- ref = url + ref;
808
- } else {
809
- _errors.push_back("Unsupported ref: " + ref);
810
- return;
811
- }
812
- std::string pointer = ref.substr(ref.find('#') + 1);
813
- std::vector<std::string> tokens = string_split(pointer, "/");
814
- for (size_t i = 1; i < tokens.size(); ++i) {
815
- std::string sel = tokens[i];
816
- if (target.is_null() || !target.contains(sel)) {
817
- _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
818
- return;
819
- }
820
- target = target[sel];
821
- }
822
- _refs[ref] = target;
823
- }
824
- } else {
825
- for (auto & kv : n.items()) {
826
- visit_refs(kv.value());
827
- }
828
- }
829
- }
830
- };
831
-
832
- visit_refs(schema);
833
- }
834
-
835
- std::string _generate_constant_rule(const json & value) {
836
- return format_literal(value.dump());
837
- }
838
-
839
- std::string visit(const json & schema, const std::string & name) {
840
- json schema_type = schema.contains("type") ? schema["type"] : json();
841
- std::string schema_format = schema.contains("format") ? schema["format"].get<std::string>() : "";
842
- std::string rule_name = is_reserved_name(name) ? name + "-" : name.empty() ? "root" : name;
843
-
844
- if (schema.contains("$ref")) {
845
- return _add_rule(rule_name, _resolve_ref(schema["$ref"]));
846
- } else if (schema.contains("oneOf") || schema.contains("anyOf")) {
847
- std::vector<json> alt_schemas = schema.contains("oneOf") ? schema["oneOf"].get<std::vector<json>>() : schema["anyOf"].get<std::vector<json>>();
848
- return _add_rule(rule_name, _generate_union_rule(name, alt_schemas));
849
- } else if (schema_type.is_array()) {
850
- std::vector<json> schema_types;
851
- for (const auto & t : schema_type) {
852
- json schema_copy(schema);
853
- schema_copy["type"] = t;
854
- schema_types.push_back(schema_copy);
855
- }
856
- return _add_rule(rule_name, _generate_union_rule(name, schema_types));
857
- } else if (schema.contains("const")) {
858
- return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space");
859
- } else if (schema.contains("enum")) {
860
- std::vector<std::string> enum_values;
861
- for (const auto & v : schema["enum"]) {
862
- enum_values.push_back(_generate_constant_rule(v));
863
- }
864
- return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space");
865
- } else if ((schema_type.is_null() || schema_type == "object")
866
- && (schema.contains("properties") ||
867
- (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
868
- std::unordered_set<std::string> required;
869
- if (schema.contains("required") && schema["required"].is_array()) {
870
- for (const auto & item : schema["required"]) {
871
- if (item.is_string()) {
872
- required.insert(item.get<std::string>());
873
- }
874
- }
875
- }
876
- std::vector<std::pair<std::string, json>> properties;
877
- if (schema.contains("properties")) {
878
- for (const auto & prop : schema["properties"].items()) {
879
- properties.emplace_back(prop.key(), prop.value());
880
- }
881
- }
882
- return _add_rule(rule_name,
883
- _build_object_rule(
884
- properties, required, name,
885
- schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
886
- } else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) {
887
- std::unordered_set<std::string> required;
888
- std::vector<std::pair<std::string, json>> properties;
889
- std::string hybrid_name = name;
890
- std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
891
- if (comp_schema.contains("$ref")) {
892
- add_component(_refs[comp_schema["$ref"]], is_required);
893
- } else if (comp_schema.contains("properties")) {
894
- for (const auto & prop : comp_schema["properties"].items()) {
895
- properties.emplace_back(prop.key(), prop.value());
896
- if (is_required) {
897
- required.insert(prop.key());
898
- }
899
- }
900
- } else {
901
- // todo warning
902
- }
903
- };
904
- for (auto & t : schema["allOf"]) {
905
- if (t.contains("anyOf")) {
906
- for (auto & tt : t["anyOf"]) {
907
- add_component(tt, false);
908
- }
909
- } else {
910
- add_component(t, true);
911
- }
912
- }
913
- return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
914
- } else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
915
- json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
916
- if (items.is_array()) {
917
- std::string rule = "\"[\" space ";
918
- for (size_t i = 0; i < items.size(); i++) {
919
- if (i > 0) {
920
- rule += " \",\" space ";
921
- }
922
- rule += visit(items[i], name + (name.empty() ? "" : "-") + "tuple-" + std::to_string(i));
923
- }
924
- rule += " \"]\" space";
925
- return _add_rule(rule_name, rule);
926
- } else {
927
- std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item");
928
- int min_items = schema.contains("minItems") ? schema["minItems"].get<int>() : 0;
929
- json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
930
- int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : std::numeric_limits<int>::max();
931
-
932
- return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space");
933
- }
934
- } else if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
935
- return _visit_pattern(schema["pattern"], rule_name);
936
- } else if ((schema_type.is_null() || schema_type == "string") && std::regex_match(schema_format, std::regex("^uuid[1-5]?$"))) {
937
- return _add_primitive(rule_name == "root" ? "root" : schema_format, PRIMITIVE_RULES.at("uuid"));
938
- } else if ((schema_type.is_null() || schema_type == "string") && STRING_FORMAT_RULES.find(schema_format + "-string") != STRING_FORMAT_RULES.end()) {
939
- auto prim_name = schema_format + "-string";
940
- return _add_rule(rule_name, _add_primitive(prim_name, STRING_FORMAT_RULES.at(prim_name)));
941
- } else if (schema_type == "string" && (schema.contains("minLength") || schema.contains("maxLength"))) {
942
- std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
943
- int min_len = schema.contains("minLength") ? schema["minLength"].get<int>() : 0;
944
- int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
945
- return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
946
- } else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
947
- int min_value = std::numeric_limits<int>::min();
948
- int max_value = std::numeric_limits<int>::max();
949
- if (schema.contains("minimum")) {
950
- min_value = schema["minimum"].get<int>();
951
- } else if (schema.contains("exclusiveMinimum")) {
952
- min_value = schema["exclusiveMinimum"].get<int>() + 1;
953
- }
954
- if (schema.contains("maximum")) {
955
- max_value = schema["maximum"].get<int>();
956
- } else if (schema.contains("exclusiveMaximum")) {
957
- max_value = schema["exclusiveMaximum"].get<int>() - 1;
958
- }
959
- std::stringstream out;
960
- out << "(";
961
- _build_min_max_int(min_value, max_value, out);
962
- out << ") space";
963
- return _add_rule(rule_name, out.str());
964
- } else if (schema.empty() || schema_type == "object") {
965
- return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object")));
966
- } else {
967
- if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get<std::string>()) == PRIMITIVE_RULES.end()) {
968
- _errors.push_back("Unrecognized schema: " + schema.dump());
969
- return "";
970
- }
971
- // TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
972
- return _add_primitive(rule_name == "root" ? "root" : schema_type.get<std::string>(), PRIMITIVE_RULES.at(schema_type.get<std::string>()));
973
- }
974
- }
975
-
976
- void check_errors() {
977
- if (!_errors.empty()) {
978
- throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
979
- }
980
- if (!_warnings.empty()) {
981
- fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());
982
- }
983
- }
984
-
985
- std::string format_grammar() {
986
- std::stringstream ss;
987
- for (const auto & kv : _rules) {
988
- ss << kv.first << " ::= " << kv.second << std::endl;
989
- }
990
- return ss.str();
991
- }
992
- };
993
-
994
- std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
995
- #ifdef LLAMA_USE_LLGUIDANCE
996
- if (!force_gbnf) {
997
- return "%llguidance {}\nstart: %json " + schema.dump();
998
- }
999
- #else
1000
- (void)force_gbnf;
1001
- #endif // LLAMA_USE_LLGUIDANCE
1002
- return build_grammar([&](const common_grammar_builder & callbacks) {
1003
- auto copy = schema;
1004
- callbacks.resolve_refs(copy);
1005
- callbacks.add_schema("", copy);
1006
- });
1007
- }
1008
-
1009
- std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
1010
- SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall, options.compact_spaces);
1011
- common_grammar_builder builder {
1012
- /* .add_rule = */ [&](const std::string & name, const std::string & rule) {
1013
- return converter._add_rule(name, rule);
1014
- },
1015
- /* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) {
1016
- return converter.visit(schema, name == "root" ? "" : name);
1017
- },
1018
- /* .resolve_refs = */ [&](nlohmann::ordered_json & schema) {
1019
- converter.resolve_refs(schema, "");
1020
- }
1021
- };
1022
- cb(builder);
1023
- converter.check_errors();
1024
- return converter.format_grammar();
1025
- }
1
+ #include "json-schema-to-grammar.h"
2
+ #include "common.h"
3
+
4
+ #include <algorithm>
5
+ #include <fstream>
6
+ #include <map>
7
+ #include <regex>
8
+ #include <sstream>
9
+ #include <string>
10
+ #include <unordered_map>
11
+ #include <unordered_set>
12
+ #include <vector>
13
+
14
+ using json = nlohmann::ordered_json;
15
+
16
// Builds a grammar expression repeating `item_rule` between `min_items` and
// `max_items` times (max_items == INT_MAX means unbounded). When `separator_rule`
// is non-empty it is placed between consecutive items.
static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
    auto has_max = max_items != std::numeric_limits<int>::max();

    if (min_items == 0 && max_items == 1) {
        return item_rule + "?";
    }

    if (separator_rule.empty()) {
        if (min_items == 1 && !has_max) {
            return item_rule + "+";
        } else if (min_items == 0 && !has_max) {
            return item_rule + "*";
        } else {
            return item_rule + "{" + std::to_string(min_items) + "," + (has_max ? std::to_string(max_items) : "") + "}";
        }
    }

    if (max_items == 0) {
        // Zero items allowed: the separated form below would emit one item followed
        // by a "{0,-1}" repetition, so return the empty sequence instead.
        return "";
    }

    // First item, then (separator item) repeated for the remaining count.
    auto result = item_rule + " " + build_repetition("(" + separator_rule + " " + item_rule + ")", min_items == 0 ? 0 : min_items - 1, has_max ? max_items - 1 : max_items);
    if (min_items == 0) {
        result = "(" + result + ")?";
    }
    return result;
}
39
+
40
/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */
// Non-owning window [_start, _end) over a std::string. The referenced string
// must outlive the view.
class string_view {
    const std::string & _str;
    const size_t _start;
    const size_t _end;
public:
    // `end == npos` means "up to the end of `str`".
    string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {}

    size_t size() const {
        return _end - _start;
    }

    size_t length() const {
        return size();
    }

    operator std::string() const {
        return str();
    }

    // Copies the viewed characters into a new string.
    std::string str() const {
        return _str.substr(_start, _end - _start);
    }

    // Sub-view of at most `len` characters starting at `pos`. Like
    // std::string::substr, `len` is clamped to what is available and a `pos`
    // past the end throws std::out_of_range, so size() can never exceed the view.
    string_view substr(size_t pos, size_t len = std::string::npos) const {
        if (pos > size()) {
            throw std::out_of_range("string_view substr position out of range");
        }
        const size_t begin = _start + pos;
        const size_t available = size() - pos;
        return string_view(_str, begin, len >= available ? _end : begin + len);
    }

    // Bounds-checked element access; throws std::out_of_range past the end of the view.
    char operator[](size_t pos) const {
        auto index = _start + pos;
        if (index >= _end) {
            throw std::out_of_range("string_view index out of range");
        }
        return _str[index];
    }

    // Compares the viewed characters in place, without copying either side.
    bool operator==(const string_view & other) const {
        return size() == other.size()
            && _str.compare(_start, size(), other._str, other._start, other.size()) == 0;
    }
};
82
+
83
+ static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
84
+ auto has_min = min_value != std::numeric_limits<int>::min();
85
+ auto has_max = max_value != std::numeric_limits<int>::max();
86
+
87
+ auto digit_range = [&](char from, char to) {
88
+ out << "[";
89
+ if (from == to) {
90
+ out << from;
91
+ } else {
92
+ out << from << "-" << to;
93
+ }
94
+ out << "]";
95
+ };
96
+ auto more_digits = [&](int min_digits, int max_digits) {
97
+ out << "[0-9]";
98
+ if (min_digits == max_digits && min_digits == 1) {
99
+ return;
100
+ }
101
+ out << "{";
102
+ out << min_digits;
103
+ if (max_digits != min_digits) {
104
+ out << ",";
105
+ if (max_digits != std::numeric_limits<int>::max()) {
106
+ out << max_digits;
107
+ }
108
+ }
109
+ out << "}";
110
+ };
111
+ std::function<void(const string_view &, const string_view &)> uniform_range =
112
+ [&](const string_view & from, const string_view & to) {
113
+ size_t i = 0;
114
+ while (i < from.length() && i < to.length() && from[i] == to[i]) {
115
+ i++;
116
+ }
117
+ if (i > 0) {
118
+ out << "\"" << from.substr(0, i).str() << "\"";
119
+ }
120
+ if (i < from.length() && i < to.length()) {
121
+ if (i > 0) {
122
+ out << " ";
123
+ }
124
+ auto sub_len = from.length() - i - 1;
125
+ if (sub_len > 0) {
126
+ auto from_sub = from.substr(i + 1);
127
+ auto to_sub = to.substr(i + 1);
128
+ auto sub_zeros = string_repeat("0", sub_len);
129
+ auto sub_nines = string_repeat("9", sub_len);
130
+
131
+ auto to_reached = false;
132
+ out << "(";
133
+ if (from_sub == sub_zeros) {
134
+ digit_range(from[i], to[i] - 1);
135
+ out << " ";
136
+ more_digits(sub_len, sub_len);
137
+ } else {
138
+ out << "[" << from[i] << "] ";
139
+ out << "(";
140
+ uniform_range(from_sub, sub_nines);
141
+ out << ")";
142
+ if (from[i] < to[i] - 1) {
143
+ out << " | ";
144
+ if (to_sub == sub_nines) {
145
+ digit_range(from[i] + 1, to[i]);
146
+ to_reached = true;
147
+ } else {
148
+ digit_range(from[i] + 1, to[i] - 1);
149
+ }
150
+ out << " ";
151
+ more_digits(sub_len, sub_len);
152
+ }
153
+ }
154
+ if (!to_reached) {
155
+ out << " | ";
156
+ digit_range(to[i], to[i]);
157
+ out << " ";
158
+ uniform_range(sub_zeros, to_sub);
159
+ }
160
+ out << ")";
161
+ } else {
162
+ out << "[" << from[i] << "-" << to[i] << "]";
163
+ }
164
+ }
165
+ };
166
+
167
+ if (has_min && has_max) {
168
+ if (min_value < 0 && max_value < 0) {
169
+ out << "\"-\" (";
170
+ _build_min_max_int(-max_value, -min_value, out, decimals_left, /* top_level= */ true);
171
+ out << ")";
172
+ return;
173
+ }
174
+
175
+ if (min_value < 0) {
176
+ out << "\"-\" (";
177
+ _build_min_max_int(0, -min_value, out, decimals_left, /* top_level= */ true);
178
+ out << ") | ";
179
+ min_value = 0;
180
+ }
181
+
182
+ auto min_s = std::to_string(min_value);
183
+ auto max_s = std::to_string(max_value);
184
+ auto min_digits = min_s.length();
185
+ auto max_digits = max_s.length();
186
+
187
+ for (auto digits = min_digits; digits < max_digits; digits++) {
188
+ uniform_range(min_s, string_repeat("9", digits));
189
+ min_s = "1" + string_repeat("0", digits);
190
+ out << " | ";
191
+ }
192
+ uniform_range(min_s, max_s);
193
+ return;
194
+ }
195
+
196
+ auto less_decimals = std::max(decimals_left - 1, 1);
197
+
198
+ if (has_min) {
199
+ if (min_value < 0) {
200
+ out << "\"-\" (";
201
+ _build_min_max_int(std::numeric_limits<int>::min(), -min_value, out, decimals_left, /* top_level= */ false);
202
+ out << ") | [0] | [1-9] ";
203
+ more_digits(0, decimals_left - 1);
204
+ } else if (min_value == 0) {
205
+ if (top_level) {
206
+ out << "[0] | [1-9] ";
207
+ more_digits(0, less_decimals);
208
+ } else {
209
+ more_digits(1, decimals_left);
210
+ }
211
+ } else if (min_value <= 9) {
212
+ char c = '0' + min_value;
213
+ auto range_start = top_level ? '1' : '0';
214
+ if (c > range_start) {
215
+ digit_range(range_start, c - 1);
216
+ out << " ";
217
+ more_digits(1, less_decimals);
218
+ out << " | ";
219
+ }
220
+ digit_range(c, '9');
221
+ out << " ";
222
+ more_digits(0, less_decimals);
223
+ } else {
224
+ auto min_s = std::to_string(min_value);
225
+ auto len = min_s.length();
226
+ auto c = min_s[0];
227
+
228
+ if (c > '1') {
229
+ digit_range(top_level ? '1' : '0', c - 1);
230
+ out << " ";
231
+ more_digits(len, less_decimals);
232
+ out << " | ";
233
+ }
234
+ digit_range(c, c);
235
+ out << " (";
236
+ _build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits<int>::max(), out, less_decimals, /* top_level= */ false);
237
+ out << ")";
238
+ if (c < '9') {
239
+ out << " | ";
240
+ digit_range(c + 1, '9');
241
+ out << " ";
242
+ more_digits(len - 1, less_decimals);
243
+ }
244
+ }
245
+ return;
246
+ }
247
+
248
+ if (has_max) {
249
+ if (max_value >= 0) {
250
+ if (top_level) {
251
+ out << "\"-\" [1-9] ";
252
+ more_digits(0, less_decimals);
253
+ out << " | ";
254
+ }
255
+ _build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true);
256
+ } else {
257
+ out << "\"-\" (";
258
+ _build_min_max_int(-max_value, std::numeric_limits<int>::max(), out, decimals_left, /* top_level= */ false);
259
+ out << ")";
260
+ }
261
+ return;
262
+ }
263
+
264
+ throw std::runtime_error("At least one of min_value or max_value must be set");
265
+ }
266
+
267
// Default body of the "space" rule (used unless compact spaces are requested).
const std::string SPACE_RULE = "| \" \" | \"\\n\"{1,2} [ \\t]{0,20}";

// A predefined grammar rule: its body plus the names of the rules it references.
struct BuiltinRule {
    std::string content;
    std::vector<std::string> deps;
};

// Rules for the JSON primitive types and their building blocks, keyed by rule name.
std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
    {"boolean", {
        "(\"true\" | \"false\") space",
        {}}},
    {"decimal-part", {
        "[0-9]{1,16}",
        {}}},
    {"integral-part", {
        "[0] | [1-9] [0-9]{0,15}",
        {}}},
    {"number", {
        "(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)? space",
        {"integral-part", "decimal-part"}}},
    {"integer", {
        "(\"-\"? integral-part) space",
        {"integral-part"}}},
    {"value", {
        "object | array | string | number | boolean | null",
        {"object", "array", "string", "number", "boolean", "null"}}},
    {"object", {
        "\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space",
        {"string", "value"}}},
    {"array", {
        "\"[\" space ( value (\",\" space value)* )? \"]\" space",
        {"value"}}},
    {"uuid", {
        "\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space",
        {}}},
    {"char", {
        "[^\"\\\\\\x7F\\x00-\\x1F] | [\\\\] ([\"\\\\bfnrt] | \"u\" [0-9a-fA-F]{4})",
        {}}},
    {"string", {
        "\"\\\"\" char* \"\\\"\" space",
        {"char"}}},
    {"null", {
        "\"null\" space",
        {}}},
};

// Rules for string "format" values; the "-string" entries wrap the bare format in quotes.
std::unordered_map<std::string, BuiltinRule> STRING_FORMAT_RULES = {
    {"date", {
        "[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )",
        {}}},
    {"time", {
        "([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )",
        {}}},
    {"date-time", {
        "date \"T\" time",
        {"date", "time"}}},
    {"date-string", {
        "\"\\\"\" date \"\\\"\" space",
        {"date"}}},
    {"time-string", {
        "\"\\\"\" time \"\\\"\" space",
        {"time"}}},
    {"date-time-string", {
        "\"\\\"\" date-time \"\\\"\" space",
        {"date-time"}}}
};
297
+
298
+ static bool is_reserved_name(const std::string & name) {
299
+ static std::unordered_set<std::string> RESERVED_NAMES;
300
+ if (RESERVED_NAMES.empty()) {
301
+ RESERVED_NAMES.insert("root");
302
+ for (const auto &p : PRIMITIVE_RULES) RESERVED_NAMES.insert(p.first);
303
+ for (const auto &p : STRING_FORMAT_RULES) RESERVED_NAMES.insert(p.first);
304
+ }
305
+ return RESERVED_NAMES.find(name) != RESERVED_NAMES.end();
306
+ }
307
+
308
// Runs of characters other than ASCII letters, digits and '-'.
std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");

// Characters that need escaping inside a quoted grammar literal.
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]");

// Same set plus ']', '-' and '\\'.
std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");

// Replacement text for each escapable character.
std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
    {'\r', "\\r"},
    {'\n', "\\n"},
    {'"',  "\\\""},
    {'-',  "\\-"},
    {']',  "\\]"}
};

// Regex metacharacters: a pattern character in this set is not treated as literal text.
std::unordered_set<char> NON_LITERAL_SET = {
    '|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'
};

// Characters whose regex escape ("\\x") is emitted without the backslash in a grammar literal.
std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {
    '^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'
};
317
+
318
// Returns `input` with every match of `regex` replaced by `replacement(match)`.
// Text between matches is copied unchanged.
//
// Matches are enumerated with std::sregex_iterator rather than by re-running
// regex_search on the remaining suffix: the iterator always makes progress on
// empty matches, and anchors such as '^' are evaluated against the whole input
// instead of against each restart position.
static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) {
    std::string result;

    // End of the previous match, i.e. start of the text not yet copied.
    std::string::const_iterator copied_up_to = input.cbegin();

    const std::sregex_iterator no_more_matches;
    for (std::sregex_iterator it(input.cbegin(), input.cend(), regex); it != no_more_matches; ++it) {
        const std::smatch & match = *it;
        result.append(copied_up_to, match[0].first);
        result.append(replacement(match));
        copied_up_to = match[0].second;
    }

    result.append(copied_up_to, input.cend());

    return result;
}
335
+
336
+ static std::string format_literal(const std::string & literal) {
337
+ std::string escaped = replacePattern(literal, GRAMMAR_LITERAL_ESCAPE_RE, [&](const std::smatch & match) {
338
+ char c = match.str()[0];
339
+ return GRAMMAR_LITERAL_ESCAPES.at(c);
340
+ });
341
+ return "\"" + escaped + "\"";
342
+ }
343
+
344
+ class SchemaConverter {
345
private:
    // build_grammar() needs access to the private helpers (it calls _add_rule directly).
    friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
    // Callback invoked by resolve_refs() to load the schema behind an "https://" $ref.
    std::function<json(const std::string &)> _fetch_json;
    // Selects which character class '.' expands to in _visit_pattern():
    // any code point when true, everything except \n and \r when false.
    bool _dotall;
    // Rule name -> rule body. format_grammar() emits the entries in this map's key order.
    std::map<std::string, std::string> _rules;
    // $ref string -> referenced (sub)schema, populated by resolve_refs().
    std::unordered_map<std::string, json> _refs;
    // Refs currently being visited; _resolve_ref() consults it so a ref that is
    // reached again while it is still being resolved is not visited a second time.
    std::unordered_set<std::string> _refs_being_resolved;
    // Problems collected during conversion; check_errors() throws if any exist.
    std::vector<std::string> _errors;
    // Non-fatal notes collected during conversion; check_errors() prints them to stderr.
    std::vector<std::string> _warnings;
354
+
355
+ std::string _add_rule(const std::string & name, const std::string & rule) {
356
+ std::string esc_name = regex_replace(name, INVALID_RULE_CHARS_RE, "-");
357
+ if (_rules.find(esc_name) == _rules.end() || _rules[esc_name] == rule) {
358
+ _rules[esc_name] = rule;
359
+ return esc_name;
360
+ } else {
361
+ int i = 0;
362
+ while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) {
363
+ i++;
364
+ }
365
+ std::string key = esc_name + std::to_string(i);
366
+ _rules[key] = rule;
367
+ return key;
368
+ }
369
+ }
370
+
371
+ std::string _generate_union_rule(const std::string & name, const std::vector<json> & alt_schemas) {
372
+ std::vector<std::string> rules;
373
+ for (size_t i = 0; i < alt_schemas.size(); i++) {
374
+ rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i)));
375
+ }
376
+ return string_join(rules, " | ");
377
+ }
378
+
379
+ std::string _visit_pattern(const std::string & pattern, const std::string & name) {
380
+ if (!(pattern.front() == '^' && pattern.back() == '$')) {
381
+ _errors.push_back("Pattern must start with '^' and end with '$'");
382
+ return "";
383
+ }
384
+ std::string sub_pattern = pattern.substr(1, pattern.length() - 2);
385
+ std::unordered_map<std::string, std::string> sub_rule_ids;
386
+
387
+ size_t i = 0;
388
+ size_t length = sub_pattern.length();
389
+
390
+ using literal_or_rule = std::pair<std::string, bool>;
391
+ auto to_rule = [&](const literal_or_rule & ls) {
392
+ auto is_literal = ls.second;
393
+ auto s = ls.first;
394
+ return is_literal ? "\"" + s + "\"" : s;
395
+ };
396
+ std::function<literal_or_rule()> transform = [&]() -> literal_or_rule {
397
+ size_t start = i;
398
+ std::vector<literal_or_rule> seq;
399
+
400
+ auto get_dot = [&]() {
401
+ std::string rule;
402
+ if (_dotall) {
403
+ rule = "[\\U00000000-\\U0010FFFF]";
404
+ } else {
405
+ rule = "[^\\x0A\\x0D]";
406
+ }
407
+ return _add_rule("dot", rule);
408
+ };
409
+
410
+ // Joins the sequence, merging consecutive literals together.
411
+ auto join_seq = [&]() {
412
+ std::vector<literal_or_rule> ret;
413
+
414
+ std::string literal;
415
+ auto flush_literal = [&]() {
416
+ if (literal.empty()) {
417
+ return false;
418
+ }
419
+ ret.emplace_back(literal, true);
420
+ literal.clear();
421
+ return true;
422
+ };
423
+
424
+ for (const auto & item : seq) {
425
+ auto is_literal = item.second;
426
+ if (is_literal) {
427
+ literal += item.first;
428
+ } else {
429
+ flush_literal();
430
+ ret.push_back(item);
431
+ }
432
+ }
433
+ flush_literal();
434
+
435
+ std::vector<std::string> results;
436
+ for (const auto & item : ret) {
437
+ results.push_back(to_rule(item));
438
+ }
439
+ return std::make_pair(string_join(results, " "), false);
440
+ };
441
+
442
+ while (i < length) {
443
+ char c = sub_pattern[i];
444
+ if (c == '.') {
445
+ seq.emplace_back(get_dot(), false);
446
+ i++;
447
+ } else if (c == '(') {
448
+ i++;
449
+ if (i < length) {
450
+ if (sub_pattern[i] == '?') {
451
+ _warnings.push_back("Unsupported pattern syntax");
452
+ }
453
+ }
454
+ seq.emplace_back("(" + to_rule(transform()) + ")", false);
455
+ } else if (c == ')') {
456
+ i++;
457
+ if (start > 0 && sub_pattern[start - 1] != '(') {
458
+ _errors.push_back("Unbalanced parentheses");
459
+ }
460
+ return join_seq();
461
+ } else if (c == '[') {
462
+ std::string square_brackets = std::string(1, c);
463
+ i++;
464
+ while (i < length && sub_pattern[i] != ']') {
465
+ if (sub_pattern[i] == '\\') {
466
+ square_brackets += sub_pattern.substr(i, 2);
467
+ i += 2;
468
+ } else {
469
+ square_brackets += sub_pattern[i];
470
+ i++;
471
+ }
472
+ }
473
+ if (i >= length) {
474
+ _errors.push_back("Unbalanced square brackets");
475
+ }
476
+ square_brackets += ']';
477
+ i++;
478
+ seq.emplace_back(square_brackets, false);
479
+ } else if (c == '|') {
480
+ seq.emplace_back("|", false);
481
+ i++;
482
+ } else if (c == '*' || c == '+' || c == '?') {
483
+ seq.back() = std::make_pair(to_rule(seq.back()) + c, false);
484
+ i++;
485
+ } else if (c == '{') {
486
+ std::string curly_brackets = std::string(1, c);
487
+ i++;
488
+ while (i < length && sub_pattern[i] != '}') {
489
+ curly_brackets += sub_pattern[i];
490
+ i++;
491
+ }
492
+ if (i >= length) {
493
+ _errors.push_back("Unbalanced curly brackets");
494
+ }
495
+ curly_brackets += '}';
496
+ i++;
497
+ auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
498
+ int min_times = 0;
499
+ int max_times = std::numeric_limits<int>::max();
500
+ try {
501
+ if (nums.size() == 1) {
502
+ min_times = max_times = std::stoi(nums[0]);
503
+ } else if (nums.size() != 2) {
504
+ _errors.push_back("Wrong number of values in curly brackets");
505
+ } else {
506
+ if (!nums[0].empty()) {
507
+ min_times = std::stoi(nums[0]);
508
+ }
509
+ if (!nums[1].empty()) {
510
+ max_times = std::stoi(nums[1]);
511
+ }
512
+ }
513
+ } catch (const std::invalid_argument & e) {
514
+ _errors.push_back("Invalid number in curly brackets");
515
+ return std::make_pair("", false);
516
+ }
517
+ auto &last = seq.back();
518
+ auto &sub = last.first;
519
+ auto sub_is_literal = last.second;
520
+
521
+ if (!sub_is_literal) {
522
+ std::string & sub_id = sub_rule_ids[sub];
523
+ if (sub_id.empty()) {
524
+ sub_id = _add_rule(name + "-" + std::to_string(sub_rule_ids.size()), sub);
525
+ }
526
+ sub = sub_id;
527
+ }
528
+ seq.back().first = build_repetition(
529
+ sub_is_literal ? "\"" + sub + "\"" : sub,
530
+ min_times,
531
+ max_times,
532
+ ""
533
+ );
534
+ seq.back().second = false;
535
+ } else {
536
+ std::string literal;
537
+ auto is_non_literal = [&](char c) {
538
+ return NON_LITERAL_SET.find(c) != NON_LITERAL_SET.end();
539
+ };
540
+ while (i < length) {
541
+ if (sub_pattern[i] == '\\' && i < length - 1) {
542
+ char next = sub_pattern[i + 1];
543
+ if (ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.find(next) != ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.end()) {
544
+ i++;
545
+ literal += sub_pattern[i];
546
+ i++;
547
+ } else {
548
+ literal += sub_pattern.substr(i, 2);
549
+ i += 2;
550
+ }
551
+ } else if (sub_pattern[i] == '"') {
552
+ literal += "\\\"";
553
+ i++;
554
+ } else if (!is_non_literal(sub_pattern[i]) &&
555
+ (i == length - 1 || literal.empty() || sub_pattern[i + 1] == '.' || !is_non_literal(sub_pattern[i + 1]))) {
556
+ literal += sub_pattern[i];
557
+ i++;
558
+ } else {
559
+ break;
560
+ }
561
+ }
562
+ if (!literal.empty()) {
563
+ seq.emplace_back(literal, true);
564
+ }
565
+ }
566
+ }
567
+ return join_seq();
568
+ };
569
+ return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space");
570
+ }
571
+
572
+ /*
573
+ Returns a rule that matches a JSON string that is none of the provided strings
574
+
575
+ not_strings({"a"})
576
+ -> ["] ( [a] char+ | [^"a] char* )? ["] space
577
+ not_strings({"and", "also"})
578
+ -> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space
579
+ */
580
+ std::string _not_strings(const std::vector<std::string> & strings) {
581
+
582
+ struct TrieNode {
583
+ std::map<char, TrieNode> children;
584
+ bool is_end_of_string;
585
+
586
+ TrieNode() : is_end_of_string(false) {}
587
+
588
+ void insert(const std::string & string) {
589
+ auto node = this;
590
+ for (char c : string) {
591
+ node = &node->children[c];
592
+ }
593
+ node->is_end_of_string = true;
594
+ }
595
+ };
596
+
597
+ TrieNode trie;
598
+ for (const auto & s : strings) {
599
+ trie.insert(s);
600
+ }
601
+
602
+ std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
603
+ std::ostringstream out;
604
+ out << "[\"] ( ";
605
+ std::function<void(const TrieNode &)> visit = [&](const TrieNode & node) {
606
+ std::ostringstream rejects;
607
+ auto first = true;
608
+ for (const auto & kv : node.children) {
609
+ rejects << kv.first;
610
+ if (first) {
611
+ first = false;
612
+ } else {
613
+ out << " | ";
614
+ }
615
+ out << "[" << kv.first << "]";
616
+ if (!kv.second.children.empty()) {
617
+ out << " (";
618
+ visit(kv.second);
619
+ out << ")";
620
+ } else if (kv.second.is_end_of_string) {
621
+ out << " " << char_rule << "+";
622
+ }
623
+ }
624
+ if (!node.children.empty()) {
625
+ if (!first) {
626
+ out << " | ";
627
+ }
628
+ out << "[^\"" << rejects.str() << "] " << char_rule << "*";
629
+ }
630
+ };
631
+ visit(trie);
632
+
633
+ out << " )";
634
+ if (!trie.is_end_of_string) {
635
+ out << "?";
636
+ }
637
+ out << " [\"] space";
638
+ return out.str();
639
+ }
640
+
641
+ std::string _resolve_ref(const std::string & ref) {
642
+ std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
643
+ if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
644
+ _refs_being_resolved.insert(ref);
645
+ json resolved = _refs[ref];
646
+ ref_name = visit(resolved, ref_name);
647
+ _refs_being_resolved.erase(ref);
648
+ }
649
+ return ref_name;
650
+ }
651
+
652
+ std::string _build_object_rule(
653
+ const std::vector<std::pair<std::string, json>> & properties,
654
+ const std::unordered_set<std::string> & required,
655
+ const std::string & name,
656
+ const json & additional_properties)
657
+ {
658
+ std::vector<std::string> required_props;
659
+ std::vector<std::string> optional_props;
660
+ std::unordered_map<std::string, std::string> prop_kv_rule_names;
661
+ std::vector<std::string> prop_names;
662
+ for (const auto & kv : properties) {
663
+ const auto &prop_name = kv.first;
664
+ const auto &prop_schema = kv.second;
665
+
666
+ std::string prop_rule_name = visit(prop_schema, name + (name.empty() ? "" : "-") + prop_name);
667
+ prop_kv_rule_names[prop_name] = _add_rule(
668
+ name + (name.empty() ? "" : "-") + prop_name + "-kv",
669
+ format_literal(json(prop_name).dump()) + " space \":\" space " + prop_rule_name
670
+ );
671
+ if (required.find(prop_name) != required.end()) {
672
+ required_props.push_back(prop_name);
673
+ } else {
674
+ optional_props.push_back(prop_name);
675
+ }
676
+ prop_names.push_back(prop_name);
677
+ }
678
+ if ((additional_properties.is_boolean() && additional_properties.get<bool>()) || additional_properties.is_object()) {
679
+ std::string sub_name = name + (name.empty() ? "" : "-") + "additional";
680
+ std::string value_rule =
681
+ additional_properties.is_object() ? visit(additional_properties, sub_name + "-value")
682
+ : _add_primitive("value", PRIMITIVE_RULES.at("value"));
683
+
684
+ auto key_rule =
685
+ prop_names.empty() ? _add_primitive("string", PRIMITIVE_RULES.at("string"))
686
+ : _add_rule(sub_name + "-k", _not_strings(prop_names));
687
+ std::string kv_rule = _add_rule(sub_name + "-kv", key_rule + " \":\" space " + value_rule);
688
+ prop_kv_rule_names["*"] = kv_rule;
689
+ optional_props.push_back("*");
690
+ }
691
+
692
+ std::string rule = "\"{\" space ";
693
+ for (size_t i = 0; i < required_props.size(); i++) {
694
+ if (i > 0) {
695
+ rule += " \",\" space ";
696
+ }
697
+ rule += prop_kv_rule_names[required_props[i]];
698
+ }
699
+
700
+ if (!optional_props.empty()) {
701
+ rule += " (";
702
+ if (!required_props.empty()) {
703
+ rule += " \",\" space ( ";
704
+ }
705
+
706
+ std::function<std::string(const std::vector<std::string> &, bool)> get_recursive_refs = [&](const std::vector<std::string> & ks, bool first_is_optional) {
707
+ std::string res;
708
+ if (ks.empty()) {
709
+ return res;
710
+ }
711
+ std::string k = ks[0];
712
+ std::string kv_rule_name = prop_kv_rule_names[k];
713
+ std::string comma_ref = "( \",\" space " + kv_rule_name + " )";
714
+ if (first_is_optional) {
715
+ res = comma_ref + (k == "*" ? "*" : "?");
716
+ } else {
717
+ res = kv_rule_name + (k == "*" ? " " + comma_ref + "*" : "");
718
+ }
719
+ if (ks.size() > 1) {
720
+ res += " " + _add_rule(
721
+ name + (name.empty() ? "" : "-") + k + "-rest",
722
+ get_recursive_refs(std::vector<std::string>(ks.begin() + 1, ks.end()), true)
723
+ );
724
+ }
725
+ return res;
726
+ };
727
+
728
+ for (size_t i = 0; i < optional_props.size(); i++) {
729
+ if (i > 0) {
730
+ rule += " | ";
731
+ }
732
+ rule += get_recursive_refs(std::vector<std::string>(optional_props.begin() + i, optional_props.end()), false);
733
+ }
734
+ if (!required_props.empty()) {
735
+ rule += " )";
736
+ }
737
+ rule += " )?";
738
+ }
739
+
740
+ rule += " \"}\" space";
741
+
742
+ return rule;
743
+ }
744
+
745
+ std::string _add_primitive(const std::string & name, const BuiltinRule & rule) {
746
+ auto n = _add_rule(name, rule.content);
747
+ for (const auto & dep : rule.deps) {
748
+ BuiltinRule dep_rule;
749
+ auto it = PRIMITIVE_RULES.find(dep);
750
+ if (it == PRIMITIVE_RULES.end()) {
751
+ it = STRING_FORMAT_RULES.find(dep);
752
+ if (it == STRING_FORMAT_RULES.end()) {
753
+ _errors.push_back("Rule " + dep + " not known");
754
+ continue;
755
+ }
756
+ }
757
+ if (_rules.find(dep) == _rules.end()) {
758
+ _add_primitive(dep, it->second);
759
+ }
760
+ }
761
+ return n;
762
+ }
763
+
764
+ public:
765
+ SchemaConverter(
766
+ const std::function<json(const std::string &)> & fetch_json,
767
+ bool dotall)
768
+ : _fetch_json(fetch_json), _dotall(dotall)
769
+ {
770
+ _rules["space"] = SPACE_RULE;
771
+ }
772
+
773
+ void resolve_refs(json & schema, const std::string & url) {
774
+ /*
775
+ * Resolves all $ref fields in the given schema, fetching any remote schemas,
776
+ * replacing each $ref with absolute reference URL and populates _refs with the
777
+ * respective referenced (sub)schema dictionaries.
778
+ */
779
+ std::function<void(json &)> visit_refs = [&](json & n) {
780
+ if (n.is_array()) {
781
+ for (auto & x : n) {
782
+ visit_refs(x);
783
+ }
784
+ } else if (n.is_object()) {
785
+ if (n.contains("$ref")) {
786
+ std::string ref = n["$ref"];
787
+ if (_refs.find(ref) == _refs.end()) {
788
+ json target;
789
+ if (ref.find("https://") == 0) {
790
+ std::string base_url = ref.substr(0, ref.find('#'));
791
+ auto it = _refs.find(base_url);
792
+ if (it != _refs.end()) {
793
+ target = it->second;
794
+ } else {
795
+ // Fetch the referenced schema and resolve its refs
796
+ auto referenced = _fetch_json(ref);
797
+ resolve_refs(referenced, base_url);
798
+ _refs[base_url] = referenced;
799
+ }
800
+ if (ref.find('#') == std::string::npos || ref.substr(ref.find('#') + 1).empty()) {
801
+ return;
802
+ }
803
+ } else if (ref.find("#/") == 0) {
804
+ target = schema;
805
+ n["$ref"] = url + ref;
806
+ ref = url + ref;
807
+ } else {
808
+ _errors.push_back("Unsupported ref: " + ref);
809
+ return;
810
+ }
811
+ std::string pointer = ref.substr(ref.find('#') + 1);
812
+ std::vector<std::string> tokens = string_split(pointer, "/");
813
+ for (size_t i = 1; i < tokens.size(); ++i) {
814
+ std::string sel = tokens[i];
815
+ if (target.is_null() || !target.contains(sel)) {
816
+ _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
817
+ return;
818
+ }
819
+ target = target[sel];
820
+ }
821
+ _refs[ref] = target;
822
+ }
823
+ } else {
824
+ for (auto & kv : n.items()) {
825
+ visit_refs(kv.value());
826
+ }
827
+ }
828
+ }
829
+ };
830
+
831
+ visit_refs(schema);
832
+ }
833
+
834
+ std::string _generate_constant_rule(const json & value) {
835
+ return format_literal(value.dump());
836
+ }
837
+
838
+ std::string visit(const json & schema, const std::string & name) {
839
+ json schema_type = schema.contains("type") ? schema["type"] : json();
840
+ std::string schema_format = schema.contains("format") ? schema["format"].get<std::string>() : "";
841
+ std::string rule_name = is_reserved_name(name) ? name + "-" : name.empty() ? "root" : name;
842
+
843
+ if (schema.contains("$ref")) {
844
+ return _add_rule(rule_name, _resolve_ref(schema["$ref"]));
845
+ } else if (schema.contains("oneOf") || schema.contains("anyOf")) {
846
+ std::vector<json> alt_schemas = schema.contains("oneOf") ? schema["oneOf"].get<std::vector<json>>() : schema["anyOf"].get<std::vector<json>>();
847
+ return _add_rule(rule_name, _generate_union_rule(name, alt_schemas));
848
+ } else if (schema_type.is_array()) {
849
+ std::vector<json> schema_types;
850
+ for (const auto & t : schema_type) {
851
+ json schema_copy(schema);
852
+ schema_copy["type"] = t;
853
+ schema_types.push_back(schema_copy);
854
+ }
855
+ return _add_rule(rule_name, _generate_union_rule(name, schema_types));
856
+ } else if (schema.contains("const")) {
857
+ return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space");
858
+ } else if (schema.contains("enum")) {
859
+ std::vector<std::string> enum_values;
860
+ for (const auto & v : schema["enum"]) {
861
+ enum_values.push_back(_generate_constant_rule(v));
862
+ }
863
+ return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space");
864
+ } else if ((schema_type.is_null() || schema_type == "object")
865
+ && (schema.contains("properties") ||
866
+ (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
867
+ std::unordered_set<std::string> required;
868
+ if (schema.contains("required") && schema["required"].is_array()) {
869
+ for (const auto & item : schema["required"]) {
870
+ if (item.is_string()) {
871
+ required.insert(item.get<std::string>());
872
+ }
873
+ }
874
+ }
875
+ std::vector<std::pair<std::string, json>> properties;
876
+ if (schema.contains("properties")) {
877
+ for (const auto & prop : schema["properties"].items()) {
878
+ properties.emplace_back(prop.key(), prop.value());
879
+ }
880
+ }
881
+ return _add_rule(rule_name,
882
+ _build_object_rule(
883
+ properties, required, name,
884
+ schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
885
+ } else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) {
886
+ std::unordered_set<std::string> required;
887
+ std::vector<std::pair<std::string, json>> properties;
888
+ std::string hybrid_name = name;
889
+ std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
890
+ if (comp_schema.contains("$ref")) {
891
+ add_component(_refs[comp_schema["$ref"]], is_required);
892
+ } else if (comp_schema.contains("properties")) {
893
+ for (const auto & prop : comp_schema["properties"].items()) {
894
+ properties.emplace_back(prop.key(), prop.value());
895
+ if (is_required) {
896
+ required.insert(prop.key());
897
+ }
898
+ }
899
+ } else {
900
+ // todo warning
901
+ }
902
+ };
903
+ for (auto & t : schema["allOf"]) {
904
+ if (t.contains("anyOf")) {
905
+ for (auto & tt : t["anyOf"]) {
906
+ add_component(tt, false);
907
+ }
908
+ } else {
909
+ add_component(t, true);
910
+ }
911
+ }
912
+ return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
913
+ } else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
914
+ json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
915
+ if (items.is_array()) {
916
+ std::string rule = "\"[\" space ";
917
+ for (size_t i = 0; i < items.size(); i++) {
918
+ if (i > 0) {
919
+ rule += " \",\" space ";
920
+ }
921
+ rule += visit(items[i], name + (name.empty() ? "" : "-") + "tuple-" + std::to_string(i));
922
+ }
923
+ rule += " \"]\" space";
924
+ return _add_rule(rule_name, rule);
925
+ } else {
926
+ std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item");
927
+ int min_items = schema.contains("minItems") ? schema["minItems"].get<int>() : 0;
928
+ json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
929
+ int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : std::numeric_limits<int>::max();
930
+
931
+ return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space");
932
+ }
933
+ } else if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
934
+ return _visit_pattern(schema["pattern"], rule_name);
935
+ } else if ((schema_type.is_null() || schema_type == "string") && std::regex_match(schema_format, std::regex("^uuid[1-5]?$"))) {
936
+ return _add_primitive(rule_name == "root" ? "root" : schema_format, PRIMITIVE_RULES.at("uuid"));
937
+ } else if ((schema_type.is_null() || schema_type == "string") && STRING_FORMAT_RULES.find(schema_format + "-string") != STRING_FORMAT_RULES.end()) {
938
+ auto prim_name = schema_format + "-string";
939
+ return _add_rule(rule_name, _add_primitive(prim_name, STRING_FORMAT_RULES.at(prim_name)));
940
+ } else if (schema_type == "string" && (schema.contains("minLength") || schema.contains("maxLength"))) {
941
+ std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
942
+ int min_len = schema.contains("minLength") ? schema["minLength"].get<int>() : 0;
943
+ int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
944
+ return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
945
+ } else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
946
+ int min_value = std::numeric_limits<int>::min();
947
+ int max_value = std::numeric_limits<int>::max();
948
+ if (schema.contains("minimum")) {
949
+ min_value = schema["minimum"].get<int>();
950
+ } else if (schema.contains("exclusiveMinimum")) {
951
+ min_value = schema["exclusiveMinimum"].get<int>() + 1;
952
+ }
953
+ if (schema.contains("maximum")) {
954
+ max_value = schema["maximum"].get<int>();
955
+ } else if (schema.contains("exclusiveMaximum")) {
956
+ max_value = schema["exclusiveMaximum"].get<int>() - 1;
957
+ }
958
+ std::stringstream out;
959
+ out << "(";
960
+ _build_min_max_int(min_value, max_value, out);
961
+ out << ") space";
962
+ return _add_rule(rule_name, out.str());
963
+ } else if (schema.empty() || schema_type == "object") {
964
+ return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object")));
965
+ } else {
966
+ if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get<std::string>()) == PRIMITIVE_RULES.end()) {
967
+ _errors.push_back("Unrecognized schema: " + schema.dump());
968
+ return "";
969
+ }
970
+ // TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
971
+ return _add_primitive(rule_name == "root" ? "root" : schema_type.get<std::string>(), PRIMITIVE_RULES.at(schema_type.get<std::string>()));
972
+ }
973
+ }
974
+
975
+ void check_errors() {
976
+ if (!_errors.empty()) {
977
+ throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
978
+ }
979
+ if (!_warnings.empty()) {
980
+ fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());
981
+ }
982
+ }
983
+
984
+ std::string format_grammar() {
985
+ std::stringstream ss;
986
+ for (const auto & kv : _rules) {
987
+ ss << kv.first << " ::= " << kv.second << std::endl;
988
+ }
989
+ return ss.str();
990
+ }
991
+ };
992
+
993
+ std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
994
+ #ifdef LLAMA_USE_LLGUIDANCE
995
+ if (!force_gbnf) {
996
+ return "%llguidance {}\nstart: %json " + schema.dump();
997
+ }
998
+ #else
999
+ (void)force_gbnf;
1000
+ #endif // LLAMA_USE_LLGUIDANCE
1001
+ return build_grammar([&](const common_grammar_builder & callbacks) {
1002
+ auto copy = schema;
1003
+ callbacks.resolve_refs(copy);
1004
+ callbacks.add_schema("", copy);
1005
+ });
1006
+ }
1007
+
1008
+ std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
1009
+ SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
1010
+ common_grammar_builder builder {
1011
+ /* .add_rule = */ [&](const std::string & name, const std::string & rule) {
1012
+ return converter._add_rule(name, rule);
1013
+ },
1014
+ /* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) {
1015
+ return converter.visit(schema, name == "root" ? "" : name);
1016
+ },
1017
+ /* .resolve_refs = */ [&](nlohmann::ordered_json & schema) {
1018
+ converter.resolve_refs(schema, "");
1019
+ }
1020
+ };
1021
+ cb(builder);
1022
+ converter.check_errors();
1023
+ return converter.format_grammar();
1024
+ }