llama-cpp-python-win 0.3.16__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. bin/convert_hf_to_gguf.py +8751 -0
  2. bin/ggml-base.dll +0 -0
  3. bin/ggml-cpu.dll +0 -0
  4. bin/ggml.dll +0 -0
  5. bin/llama-mtmd-cli.exe +0 -0
  6. bin/llama.dll +0 -0
  7. bin/mtmd.dll +0 -0
  8. include/ggml-alloc.h +76 -0
  9. include/ggml-backend.h +354 -0
  10. include/ggml-blas.h +25 -0
  11. include/ggml-cann.h +123 -0
  12. include/ggml-cpp.h +39 -0
  13. include/ggml-cpu.h +145 -0
  14. include/ggml-cuda.h +47 -0
  15. include/ggml-metal.h +66 -0
  16. include/ggml-opt.h +256 -0
  17. include/ggml-rpc.h +33 -0
  18. include/ggml-sycl.h +49 -0
  19. include/ggml-vulkan.h +29 -0
  20. include/ggml-webgpu.h +19 -0
  21. include/ggml.h +2467 -0
  22. include/gguf.h +202 -0
  23. include/llama-cpp.h +30 -0
  24. include/llama.h +1482 -0
  25. include/mtmd-helper.h +91 -0
  26. include/mtmd.h +298 -0
  27. lib/cmake/ggml/ggml-config.cmake +328 -0
  28. lib/cmake/ggml/ggml-version.cmake +65 -0
  29. lib/cmake/llama/llama-config.cmake +54 -0
  30. lib/cmake/llama/llama-version.cmake +65 -0
  31. lib/ggml-base.lib +0 -0
  32. lib/ggml-cpu.lib +0 -0
  33. lib/ggml.lib +0 -0
  34. lib/llama.lib +0 -0
  35. lib/mtmd.lib +0 -0
  36. lib/pkgconfig/llama.pc +10 -0
  37. llama_cpp/__init__.py +4 -0
  38. llama_cpp/_ctypes_extensions.py +131 -0
  39. llama_cpp/_ggml.py +12 -0
  40. llama_cpp/_internals.py +856 -0
  41. llama_cpp/_logger.py +47 -0
  42. llama_cpp/_utils.py +78 -0
  43. llama_cpp/lib/ggml-base.dll +0 -0
  44. llama_cpp/lib/ggml-base.lib +0 -0
  45. llama_cpp/lib/ggml-cpu.dll +0 -0
  46. llama_cpp/lib/ggml-cpu.lib +0 -0
  47. llama_cpp/lib/ggml.dll +0 -0
  48. llama_cpp/lib/ggml.lib +0 -0
  49. llama_cpp/lib/llama.dll +0 -0
  50. llama_cpp/lib/llama.lib +0 -0
  51. llama_cpp/lib/mtmd.dll +0 -0
  52. llama_cpp/lib/mtmd.lib +0 -0
  53. llama_cpp/llama.py +2422 -0
  54. llama_cpp/llama_cache.py +155 -0
  55. llama_cpp/llama_chat_format.py +3962 -0
  56. llama_cpp/llama_cpp.py +4374 -0
  57. llama_cpp/llama_grammar.py +953 -0
  58. llama_cpp/llama_speculative.py +64 -0
  59. llama_cpp/llama_tokenizer.py +120 -0
  60. llama_cpp/llama_types.py +316 -0
  61. llama_cpp/llava_cpp.py +158 -0
  62. llama_cpp/mtmd_cpp.py +280 -0
  63. llama_cpp/py.typed +0 -0
  64. llama_cpp/server/__init__.py +0 -0
  65. llama_cpp/server/__main__.py +100 -0
  66. llama_cpp/server/app.py +597 -0
  67. llama_cpp/server/cli.py +97 -0
  68. llama_cpp/server/errors.py +212 -0
  69. llama_cpp/server/model.py +312 -0
  70. llama_cpp/server/settings.py +240 -0
  71. llama_cpp/server/types.py +316 -0
  72. llama_cpp_python_win-0.3.16.dist-info/METADATA +856 -0
  73. llama_cpp_python_win-0.3.16.dist-info/RECORD +75 -0
  74. llama_cpp_python_win-0.3.16.dist-info/WHEEL +5 -0
  75. llama_cpp_python_win-0.3.16.dist-info/licenses/LICENSE.md +9 -0
@@ -0,0 +1,953 @@
1
+ """Python implementation of llama grammar parser directly translated from C++ source file in vendor/llama.cpp/common/grammar-parser.cpp."""
2
+
3
+ # flake8: noqa
4
+ from pathlib import Path
5
+
6
+ from itertools import groupby
7
+ from typing import (
8
+ Any,
9
+ Set,
10
+ List,
11
+ Optional,
12
+ Tuple,
13
+ Union,
14
+ )
15
+
16
LLAMA_GRAMMAR_DEFAULT_ROOT = "root"


class LlamaGrammar:
    """A GBNF grammar kept as source text.

    Instances are normally created through the ``from_string``,
    ``from_file`` or ``from_json_schema`` alternate constructors.
    """

    def __init__(self, *args, _grammar: str, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored.
        self._grammar = _grammar
        self._root = LLAMA_GRAMMAR_DEFAULT_ROOT

    @classmethod
    def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
        """Create a grammar from GBNF source text.

        ``verbose`` is accepted for API compatibility and is not used here.
        """
        return cls(_grammar=grammar)

    @classmethod
    def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar":
        """Read a GBNF grammar from *file*, decoding it as UTF-8.

        Raises:
            Exception: if the file cannot be opened or decoded (the original
                error is attached as ``__cause__``).
            ValueError: if the file is empty.
        """
        try:
            # GBNF files are UTF-8. Without an explicit encoding, open() uses the
            # platform default (a legacy code page on Windows), which garbles or
            # rejects non-ASCII grammars.
            with open(file, encoding="utf-8") as f:
                grammar = f.read()
        except Exception as err:
            raise Exception(
                f"{cls.from_file.__name__}: error reading grammar file: {err}"
            ) from err

        if grammar:
            return cls.from_string(grammar, verbose=verbose)

        raise ValueError(
            f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty"
        )

    @classmethod
    def from_json_schema(cls, json_schema: str, verbose: bool = True) -> "LlamaGrammar":
        """Create a grammar from a JSON schema given as JSON text."""
        return cls.from_string(json_schema_to_gbnf(json_schema), verbose=verbose)
48
+
49
+
50
+ """llama.cpp gbnf rules from vendor/llama.cpp/grammars"""
51
+
52
# Arithmetic: one or more lines of the form `expr = term`, built from
# lowercase identifiers, unsigned integers, parentheses and + - * /.
ARITHMETIC_GBNF = r"""
root ::= (expr "=" ws term "\n")+
expr ::= term ([-+*/] term)*
term ::= ident | num | "(" ws expr ")" ws
ident ::= [a-z] [a-z0-9_]* ws
num ::= [0-9]+ ws
ws ::= [ \t\n]*
"""

# A small C-like language: zero or more int/float/char function definitions
# (at most one parameter each) whose bodies contain declarations, assignments,
# calls, return/while/for/if statements and // or /* */ comments.
C_GBNF = r"""
root ::= (declaration)*

declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}"

dataType ::= "int" ws | "float" ws | "char" ws
identifier ::= [a-zA-Z_] [a-zA-Z_0-9]*

parameter ::= dataType identifier

statement ::=
( dataType identifier ws "=" ws expression ";" ) |
( identifier ws "=" ws expression ";" ) |
( identifier ws "(" argList? ")" ";" ) |
( "return" ws expression ";" ) |
( "while" "(" condition ")" "{" statement* "}" ) |
( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) |
( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) |
( singleLineComment ) |
( multiLineComment )

forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression
forUpdate ::= identifier ws "=" ws expression

condition ::= expression relationOperator expression
relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">")

expression ::= term (("+" | "-") term)*
term ::= factor(("*" | "/") factor)*

factor ::= identifier | number | unaryTerm | funcCall | parenExpression
unaryTerm ::= "-" factor
funcCall ::= identifier "(" argList? ")"
parenExpression ::= "(" ws expression ws ")"

argList ::= expression ("," ws expression)*

number ::= [0-9]+

singleLineComment ::= "//" [^\n]* "\n"
multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/"

ws ::= ([ \t\n]+)
"""
105
+
106
# Chess moves as a numbered list in algebraic notation (PGN conventions),
# mirroring llama.cpp grammars/chess.gbnf. This previously held a copy of the
# JSON grammar.
CHESS_GBNF = r"""
# Specifies chess moves as a list in algebraic notation, using PGN conventions

# Force first move to "1. ", then any 1-2 digit number after, relying on model to follow the pattern
root    ::= "1. " move " " move "\n" ([1-9] [0-9]? ". " move " " move "\n")+
move    ::= (pawn | nonpawn | castle) [+#]?

# piece type, from and to squares, and optional capture
nonpawn ::= [NBKQR] [a-h]? [1-8]? "x"? [a-h] [1-8]

# pawn move, optional capture, optional promotion
pawn    ::= ([a-h] "x")? [a-h] [1-8] ("=" [NBKQR])?

castle  ::= "O-O" "-O"?
"""

# Japanese text (hiragana, katakana, CJK punctuation and ideographs separated
# by whitespace), mirroring llama.cpp grammars/japanese.gbnf. This previously
# held a copy of the JSON grammar.
JAPANESE_GBNF = r"""
# A probably incorrect grammar for Japanese
root        ::= jp-char+ ([ \t\n] jp-char+)*
jp-char     ::= hiragana | katakana | punctuation | cjk
hiragana    ::= [ぁ-ゟ]
katakana    ::= [ァ-ヿ]
punctuation ::= [、-〾]
cjk         ::= [一-鿿]
"""
161
+
162
# JSON grammar whose root is an array: a newline follows the opening "[" and
# each top-level ",", and nothing may follow the closing "]".
JSON_ARR_GBNF = r"""
# This is the same as json.gbnf but we restrict whitespaces at the end of the root array
# Useful for generating JSON arrays

root ::= arr
value ::= object | array | string | number | ("true" | "false" | "null") ws

arr ::=
"[\n" ws (
value
(",\n" ws value)*
)? "]"

object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws

array ::=
"[" ws (
value
("," ws value)*
)? "]" ws

string ::=
"\"" (
[^"\\\x7F\x00-\x1F] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
"""


# JSON grammar whose root is an object. The integer part is capped at 16
# digits and whitespace is limited to nothing, one space, or a newline plus
# up to 20 spaces/tabs.
# NOTE(review): the escape set here omits "\/" (JSON_ARR_GBNF allows it), and
# the exponent rule `[0-9] [1-9]{0,15}` rejects exponents such as "10"; confirm
# whether this is intentional.
JSON_GBNF = r"""
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws

array ::=
"[" ws (
value
("," ws value)*
)? "]" ws

string ::=
"\"" (
[^"\\\x7F\x00-\x1F] |
"\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) # escapes
)* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [0-9] [1-9]{0,15})? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= | " " | "\n" [ \t]{0,20}
"""

# Bulleted list: one or more lines of the form "- <text>\n", where the text
# contains no line-break characters.
LIST_GBNF = r"""
root ::= item+

# Excludes various line break characters
item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n"
"""
234
+
235
+ """llama.cpp json-schema to grammar converter from vendor/llama.cpp/examples/json-schema-to-grammar.py"""
236
+ import json
237
+ import re
238
+ from typing import List, Optional
239
+
240
# whitespace is constrained to a single space char to prevent model "running away" in
# whitespace. Also maybe improves generation quality?
# (Previously this constant and its comment were defined twice.)
SPACE_RULE = '" "?'

# Runs of characters not allowed in a GBNF rule name; _add_rule replaces each
# run with "-".
INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+")
# Characters escaped inside a GBNF string literal, and their escape sequences.
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'}
252
+
253
+
254
+ def _build_repetition(
255
+ item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False
256
+ ):
257
+ if not separator_rule:
258
+ if min_items == 0 and max_items == 1:
259
+ return f"{item_rule}?"
260
+ elif min_items == 1 and max_items is None:
261
+ return f"{item_rule}+"
262
+
263
+ result = ""
264
+
265
+ if min_items > 0:
266
+ if item_rule_is_literal and separator_rule is None:
267
+ result = '"' + (item_rule[1:-1] * min_items) + '"'
268
+ else:
269
+ result = (f" {separator_rule} " if separator_rule else " ").join(
270
+ [item_rule] * min_items
271
+ )
272
+
273
+ def opt_repetitions(up_to_n, prefix_with_sep=False):
274
+ """
275
+ - n=4, no sep: '(a (a (a (a)?)?)?)?'
276
+ - n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?'
277
+ - n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?'
278
+ """
279
+
280
+ content = (
281
+ f"{separator_rule} {item_rule}"
282
+ if prefix_with_sep and separator_rule
283
+ else item_rule
284
+ )
285
+ if up_to_n == 0:
286
+ return ""
287
+ elif up_to_n == 1:
288
+ return f"({content})?"
289
+ elif separator_rule and not prefix_with_sep:
290
+ return f"({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?"
291
+ else:
292
+ return (f"({content} " * up_to_n).rstrip() + (")?" * up_to_n)
293
+
294
+ if min_items > 0 and max_items != min_items:
295
+ result += " "
296
+
297
+ if max_items is not None:
298
+ result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0)
299
+ else:
300
+ item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})'
301
+
302
+ if min_items == 0 and separator_rule:
303
+ result = f"({item_rule} {item_operator}*)?"
304
+ else:
305
+ result += f"{item_operator}*"
306
+
307
+ return result
308
+
309
+
310
class BuiltinRule:
    """A predefined GBNF rule together with the builtin rules it references.

    Attributes:
        content: the GBNF right-hand side of the rule.
        deps: names of other builtin rules (looked up in PRIMITIVE_RULES or
            STRING_FORMAT_RULES by ``SchemaConverter._add_primitive``) that must
            be emitted alongside this one.
    """

    def __init__(self, content: str, deps: Optional[List[str]] = None):
        self.content = content
        # Copy so later mutation of the caller's list cannot change this rule.
        self.deps = list(deps) if deps else []

    def __repr__(self) -> str:
        return f"{type(self).__name__}(content={self.content!r}, deps={self.deps!r})"
314
+
315
+
316
# "[0-9]" repeated 0 to 15 times; caps the digit count of generated numbers.
_up_to_15_digits = _build_repetition("[0-9]", 0, 15)

# Builtin GBNF rules for JSON primitives, keyed by rule name. Each entry lists
# the other builtin rules its content refers to.
PRIMITIVE_RULES = {
    "boolean": BuiltinRule('("true" | "false") space', []),
    "decimal-part": BuiltinRule("[0-9] " + _up_to_15_digits, []),
    "integral-part": BuiltinRule("[0-9] | [1-9] " + _up_to_15_digits, []),
    "number": BuiltinRule(
        '("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space',
        ["integral-part", "decimal-part"],
    ),
    "integer": BuiltinRule('("-"? integral-part) space', ["integral-part"]),
    "value": BuiltinRule(
        "object | array | string | number | boolean | null",
        ["object", "array", "string", "number", "boolean", "null"],
    ),
    "object": BuiltinRule(
        '"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space',
        ["string", "value"],
    ),
    "array": BuiltinRule(
        '"[" space ( value ("," space value)* )? "]" space', ["value"]
    ),
    # Quoted 8-4-4-4-12 hex-digit UUID.
    "uuid": BuiltinRule(
        r'"\"" '
        + ' "-" '.join("[0-9a-fA-F]" * n for n in [8, 4, 4, 4, 12])
        + r' "\"" space',
        [],
    ),
    # One JSON string character: anything except '"', '\' and the control
    # characters U+0000-U+001F (which JSON requires to be escaped), nor U+007F
    # (excluded as in JSON_GBNF), or an escape sequence. Previously raw control
    # characters were accepted, which let the model emit invalid JSON.
    "char": BuiltinRule(
        r'[^"\\\x7F\x00-\x1F] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])',
        [],
    ),
    "string": BuiltinRule(r'"\"" char* "\"" space', ["char"]),
    "null": BuiltinRule('"null" space', []),
}
351
+
352
# TODO: support "uri", "email" string formats
# GBNF rules for JSON-schema string "format"s. The bare "date"/"time"/
# "date-time" rules match the unquoted text; the "*-string" variants wrap it in
# JSON quotes followed by `space`. SchemaConverter.visit looks up
# f"{format}-string" here.
STRING_FORMAT_RULES = {
    # YYYY-MM-DD; months 01-12, days 01-31 (not validated per month).
    "date": BuiltinRule(
        '[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( "0" [1-9] | [1-2] [0-9] | "3" [0-1] )',
        [],
    ),
    # HH:MM:SS with optional 3-digit fraction, then "Z" or a +/-HH:MM offset.
    "time": BuiltinRule(
        '([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )',
        [],
    ),
    "date-time": BuiltinRule('date "T" time', ["date", "time"]),
    "date-string": BuiltinRule('"\\"" date "\\"" space', ["date"]),
    "time-string": BuiltinRule('"\\"" time "\\"" space', ["time"]),
    "date-time-string": BuiltinRule('"\\"" date-time "\\"" space', ["date-time"]),
}
367
+
368
# GBNF character classes used for "." in schema patterns: DOTALL accepts any
# code point; DOT accepts anything except line feed (0x0A) and carriage
# return (0x0D).
DOTALL = r"[\U00000000-\U0010FFFF]"
DOT = r"[^\x0A\x0D]"
370
+
371
# Names of builtin rules. A schema-derived rule name equal to one of these is
# given a "-" suffix by SchemaConverter.visit to avoid clobbering the builtin.
RESERVED_NAMES = {"root", "dot"} | set(PRIMITIVE_RULES) | set(STRING_FORMAT_RULES)
374
+
375
+
376
# Regex metacharacters that are not copied into a literal run while
# translating a pattern.
NON_LITERAL_SET = {*"|.()[]{}*+?"}
# Characters that a pattern escapes with "\" but that may appear as-is inside
# a GBNF literal (the backslash is dropped when translating).
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {*"[]()|{}*+?"}
378
+
379
+
380
class SchemaConverter:
    """Translate a JSON schema (already parsed into Python objects) into GBNF.

    Rules are accumulated in ``self._rules`` (rule name -> GBNF right-hand side)
    while ``visit`` walks the schema; ``format_grammar`` renders them.
    """

    def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
        """
        Args:
            prop_order: mapping of property name -> sort position; listed
                properties are emitted before the others, in that order.
            allow_fetch: allow downloading ``https://`` ``$ref`` targets.
            dotall: if true, ``.`` in a ``pattern`` uses DOTALL (any character);
                otherwise DOT (any character except line feed / carriage return).
            raw_pattern: if true, ``pattern`` rules are emitted without the
                surrounding JSON string quotes and ``"`` is not escaped.
        """
        self._prop_order = prop_order
        self._allow_fetch = allow_fetch
        self._dotall = dotall
        self._raw_pattern = raw_pattern
        # Rule name -> GBNF right-hand side; "space" is referenced by most rules.
        self._rules = {
            "space": SPACE_RULE,
        }
        # Absolute $ref URL (or fetched base URL) -> referenced (sub)schema.
        self._refs = {}
        # Refs whose target is currently being visited (guards recursion).
        self._refs_being_resolved = set()

    def _format_literal(self, literal):
        """Return *literal* as a quoted GBNF string literal.

        Backslashes are doubled first, then carriage returns, newlines and
        double quotes are escaped using GRAMMAR_LITERAL_ESCAPES.
        """
        # GBNF treats "\" inside a literal as the start of an escape sequence.
        # Unescaped backslashes (e.g. those json.dumps emits for \" or \n) were
        # previously mis-parsed, ending the literal early or turning "\n" into a
        # raw newline, which is invalid inside a JSON string.
        escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
            lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)),
            literal.replace("\\", "\\\\"),
        )
        return f'"{escaped}"'

    def not_literal(
        self, literal: str, dotall: bool = True, maybe_escaped_underscores=False
    ) -> str:
        """
        not_literal('a') -> '[^a]'
        not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?'
        """
        assert len(literal) > 0, "Empty literal not supported"

        def recurse(i: int):
            c = literal[i]
            if maybe_escaped_underscores and c == "_":
                yield f"[^{c}\\\\]"
                yield " | "
                yield f'"\\\\"? "{c}"'
            else:
                yield f"[^{c}]"
            if i < len(literal) - 1:
                yield " | "
                yield self._format_literal(c)
                yield " ("
                yield from recurse(i + 1)
                yield ")?"

        return "".join(("(", *recurse(0), ")"))

    def _add_rule(self, name, rule):
        """Register *rule* under a GBNF-safe form of *name*; return the key used.

        Runs of characters outside ``[a-zA-Z0-9-]`` become ``-``. If that name
        already holds a different rule, the first suffixed name (``name0``,
        ``name1``, ...) that is free or already holds the same rule is used.
        """
        esc_name = INVALID_RULE_CHARS_RE.sub("-", name)
        if esc_name not in self._rules or self._rules[esc_name] == rule:
            key = esc_name
        else:
            i = 0
            while (
                f"{esc_name}{i}" in self._rules
                and self._rules[f"{esc_name}{i}"] != rule
            ):
                i += 1
            key = f"{esc_name}{i}"
        self._rules[key] = rule
        return key

    def resolve_refs(self, schema: dict, url: str):
        """
        Resolves all $ref fields in the given schema, fetching any remote schemas,
        replacing $ref with absolute reference URL and populating self._refs with the
        respective referenced (sub)schema dictionaries.

        Local "#/..." refs are rewritten in place to f"{url}#/...". Returns the
        (mutated) schema.
        """

        def visit(n):
            if isinstance(n, list):
                return [visit(x) for x in n]
            elif isinstance(n, dict):
                ref = n.get("$ref")
                if ref is not None and ref not in self._refs:
                    if ref.startswith("https://"):
                        assert (
                            self._allow_fetch
                        ), "Fetching remote schemas is not allowed (use --allow-fetch for force)"
                        # Imported lazily: only needed when remote fetching is enabled.
                        import requests

                        frag_split = ref.split("#")
                        base_url = frag_split[0]

                        target = self._refs.get(base_url)
                        if target is None:
                            target = self.resolve_refs(
                                requests.get(ref).json(), base_url
                            )
                            self._refs[base_url] = target

                        if len(frag_split) == 1 or frag_split[-1] == "":
                            return target
                    elif ref.startswith("#/"):
                        target = schema
                        ref = f"{url}{ref}"
                        n["$ref"] = ref
                    else:
                        raise ValueError(f"Unsupported ref {ref}")

                    # Walk the JSON pointer after "#", one path segment at a time.
                    for sel in ref.split("#")[-1].split("/")[1:]:
                        assert (
                            target is not None and sel in target
                        ), f"Error resolving ref {ref}: {sel} not in {target}"
                        target = target[sel]

                    self._refs[ref] = target
                else:
                    for v in n.values():
                        visit(v)

            return n

        return visit(schema)

    def _generate_union_rule(self, name, alt_schemas):
        """Return a GBNF alternation of the rules generated for *alt_schemas*."""
        return " | ".join(
            (
                self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
                for i, alt_schema in enumerate(alt_schemas)
            )
        )

    def _visit_pattern(self, pattern, name):
        """
        Transforms a regular expression pattern into a GBNF rule.

        Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
        Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md

        Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.

        Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
        we define sub-rules to keep the output lean.

        Returns the name of the rule added under *name*.

        Raises:
            ValueError: for an invalid {..} quantifier, or for a "]" / "}" that
                does not close a group (previously this looped forever).
        """

        assert pattern.startswith("^") and pattern.endswith(
            "$"
        ), 'Pattern must start with "^" and end with "$"'
        pattern = pattern[1:-1]
        # Non-literal sub-expression -> rule created for it, so an expression
        # quantified more than once reuses one sub-rule.
        sub_rule_ids = {}

        i = 0
        length = len(pattern)

        def to_rule(s: Tuple[str, bool]) -> str:
            (txt, is_literal) = s
            return '"' + txt + '"' if is_literal else txt

        def transform() -> Tuple[str, bool]:
            """
            Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
            """
            nonlocal i

            start = i
            # For each component of this sequence, store its string representation and whether it's a literal.
            # We only need a flat structure here to apply repetition operators to the last item, and
            # to merge literals at the end (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
            # (GBNF's syntax is luckily very close to regular expressions!)
            seq: List[Tuple[str, bool]] = []

            def get_dot():
                if self._dotall:
                    rule = DOTALL
                else:
                    # Accept any character... except \n and \r line break chars (\x0A and \x0D)
                    rule = DOT
                return self._add_rule("dot", rule)

            def join_seq():
                # Merge runs of adjacent literals into one literal.
                ret = []
                for is_literal, g in groupby(seq, lambda x: x[1]):
                    if is_literal:
                        ret.append(("".join(x[0] for x in g), True))
                    else:
                        ret.extend(g)
                if len(ret) == 1:
                    return ret[0]
                # Join the merged items (joining `seq` here discarded the merge).
                return (" ".join(to_rule(x) for x in ret), False)

            while i < length:
                c = pattern[i]
                if c == ".":
                    seq.append((get_dot(), False))
                    i += 1
                elif c == "(":
                    i += 1
                    if i < length:
                        assert (
                            pattern[i] != "?"
                        ), f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
                    seq.append((f"({to_rule(transform())})", False))
                elif c == ")":
                    i += 1
                    assert (
                        start > 0 and pattern[start - 1] == "("
                    ), f"Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}"
                    return join_seq()
                elif c == "[":
                    square_brackets = c
                    i += 1
                    while i < length and pattern[i] != "]":
                        if pattern[i] == "\\":
                            square_brackets += pattern[i : i + 2]
                            i += 2
                        else:
                            square_brackets += pattern[i]
                            i += 1
                    assert (
                        i < length
                    ), f"Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}"
                    square_brackets += "]"
                    i += 1
                    seq.append((square_brackets, False))
                elif c == "|":
                    seq.append(("|", False))
                    i += 1
                elif c in ("*", "+", "?"):
                    seq[-1] = (to_rule(seq[-1]) + c, False)
                    i += 1
                elif c == "{":
                    curly_brackets = c
                    i += 1
                    while i < length and pattern[i] != "}":
                        curly_brackets += pattern[i]
                        i += 1
                    assert (
                        i < length
                    ), f"Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}"
                    curly_brackets += "}"
                    i += 1
                    nums = [s.strip() for s in curly_brackets[1:-1].split(",")]
                    min_times = 0
                    max_times = None
                    try:
                        if len(nums) == 1:
                            min_times = int(nums[0])
                            max_times = min_times
                        else:
                            assert len(nums) == 2
                            min_times = int(nums[0]) if nums[0] else 0
                            max_times = int(nums[1]) if nums[1] else None
                    except ValueError:
                        raise ValueError(
                            f"Invalid quantifier {curly_brackets} in /{pattern}/"
                        )

                    (sub, sub_is_literal) = seq[-1]

                    if not sub_is_literal:
                        # Give the quantified expression its own (shared) sub-rule.
                        sub_id = sub_rule_ids.get(sub)
                        if sub_id is None:
                            sub_id = self._add_rule(
                                f"{name}-{len(sub_rule_ids) + 1}", sub
                            )
                            sub_rule_ids[sub] = sub_id
                        sub = sub_id

                    seq[-1] = (
                        _build_repetition(
                            f'"{sub}"' if sub_is_literal else sub,
                            min_times,
                            max_times,
                            item_rule_is_literal=sub_is_literal,
                        ),
                        False,
                    )
                else:
                    literal = ""
                    while i < length:
                        if pattern[i] == "\\" and i < length - 1:
                            next_char = pattern[i + 1]
                            if next_char in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
                                i += 1
                                literal += pattern[i]
                                i += 1
                            else:
                                literal += pattern[i : i + 2]
                                i += 2
                        elif pattern[i] == '"' and not self._raw_pattern:
                            literal += '\\"'
                            i += 1
                        elif pattern[i] not in NON_LITERAL_SET and (
                            i == length - 1
                            or literal == ""
                            or pattern[i + 1] == "."
                            or pattern[i + 1] not in NON_LITERAL_SET
                        ):
                            # A character followed by a quantifier is left for
                            # the next unit, so the quantifier binds to it alone.
                            literal += pattern[i]
                            i += 1
                        else:
                            break
                    if not literal:
                        # Nothing was consumed (e.g. a stray "]" or "}"); without
                        # this check the outer loop would never advance.
                        raise ValueError(
                            f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
                        )
                    seq.append((literal, True))

            return join_seq()

        return self._add_rule(
            name,
            (
                to_rule(transform())
                if self._raw_pattern
                else '"\\"" ' + to_rule(transform()) + ' "\\"" space'
            ),
        )

    def _resolve_ref(self, ref):
        """Return the rule name for *ref*, visiting its target on first use.

        A ref that is already being resolved (a recursive schema) just returns
        its name so the grammar can refer to itself.
        """
        ref_name = ref.split("/")[-1]
        if ref_name not in self._rules and ref not in self._refs_being_resolved:
            self._refs_being_resolved.add(ref)
            resolved = self._refs[ref]
            ref_name = self.visit(resolved, ref_name)
            self._refs_being_resolved.remove(ref)
        return ref_name

    def _generate_constant_rule(self, value):
        """Return a GBNF literal matching the JSON serialization of *value*."""
        return self._format_literal(json.dumps(value))

    def visit(self, schema, name):
        """Generate the rule(s) for *schema* and return the name of its rule.

        Args:
            schema: a JSON schema dict whose $refs were prepared by
                ``resolve_refs``.
            name: base name for generated rules; "" for the root schema.
        """
        schema_type = schema.get("type")
        schema_format = schema.get("format")
        # Names clashing with builtins get a "-" suffix; "" becomes "root".
        rule_name = name + "-" if name in RESERVED_NAMES else name or "root"

        if (ref := schema.get("$ref")) is not None:
            return self._add_rule(rule_name, self._resolve_ref(ref))

        elif "oneOf" in schema or "anyOf" in schema:
            return self._add_rule(
                rule_name,
                self._generate_union_rule(name, schema.get("oneOf") or schema["anyOf"]),
            )

        elif isinstance(schema_type, list):
            return self._add_rule(
                rule_name,
                self._generate_union_rule(name, [{"type": t} for t in schema_type]),
            )

        elif "const" in schema:
            return self._add_rule(
                rule_name, self._generate_constant_rule(schema["const"])
            )

        elif "enum" in schema:
            rule = " | ".join((self._generate_constant_rule(v) for v in schema["enum"]))
            return self._add_rule(rule_name, rule)

        elif schema_type in (None, "object") and (
            "properties" in schema
            or (
                "additionalProperties" in schema
                and schema["additionalProperties"] is not True
            )
        ):
            required = set(schema.get("required", []))
            properties = list(schema.get("properties", {}).items())
            return self._add_rule(
                rule_name,
                self._build_object_rule(
                    properties, required, name, schema.get("additionalProperties")
                ),
            )

        elif schema_type in (None, "object") and "allOf" in schema:
            required = set()
            properties = []
            hybrid_name = name

            def add_component(comp_schema, is_required):
                if (ref := comp_schema.get("$ref")) is not None:
                    comp_schema = self._refs[ref]

                if "properties" in comp_schema:
                    for prop_name, prop_schema in comp_schema["properties"].items():
                        properties.append((prop_name, prop_schema))
                        if is_required:
                            required.add(prop_name)

            # Properties of allOf members are required; properties of the
            # alternatives of an anyOf nested in allOf are optional.
            for t in schema["allOf"]:
                if "anyOf" in t:
                    for tt in t["anyOf"]:
                        add_component(tt, is_required=False)
                else:
                    add_component(t, is_required=True)

            return self._add_rule(
                rule_name,
                self._build_object_rule(
                    properties, required, hybrid_name, additional_properties=[]
                ),
            )

        elif schema_type in (None, "array") and (
            "items" in schema or "prefixItems" in schema
        ):
            items = schema.get("items")
            # A missing/empty "items" defers to "prefixItems" when present. An
            # empty "items" schema ({}) without "prefixItems" used to raise
            # KeyError; it is now visited like any other item schema.
            if not items and "prefixItems" in schema:
                items = schema["prefixItems"]
            if isinstance(items, list):
                # Tuple form: one rule per position.
                return self._add_rule(
                    rule_name,
                    '"[" space '
                    + ' "," space '.join(
                        self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
                        for i, item in enumerate(items)
                    )
                    + ' "]" space',
                )
            else:
                item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
                min_items = schema.get("minItems", 0)
                max_items = schema.get("maxItems")
                return self._add_rule(
                    rule_name,
                    '"[" space '
                    + _build_repetition(
                        item_rule_name, min_items, max_items, separator_rule='"," space'
                    )
                    + ' "]" space',
                )

        elif schema_type in (None, "string") and "pattern" in schema:
            return self._visit_pattern(schema["pattern"], rule_name)

        elif schema_type in (None, "string") and re.match(
            r"^uuid[1-5]?$", schema_format or ""
        ):
            return self._add_primitive(
                "root" if rule_name == "root" else schema_format,
                PRIMITIVE_RULES["uuid"],
            )

        elif (
            schema_type in (None, "string")
            and f"{schema_format}-string" in STRING_FORMAT_RULES
        ):
            prim_name = f"{schema_format}-string"
            return self._add_rule(
                rule_name,
                self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]),
            )

        elif schema_type == "string" and (
            "minLength" in schema or "maxLength" in schema
        ):
            char_rule = self._add_primitive("char", PRIMITIVE_RULES["char"])
            min_len = schema.get("minLength", 0)
            max_len = schema.get("maxLength")

            return self._add_rule(
                rule_name,
                r'"\"" '
                + _build_repetition(char_rule, min_len, max_len)
                + r' "\"" space',
            )

        elif (schema_type == "object") or (len(schema) == 0):
            return self._add_rule(
                rule_name, self._add_primitive("object", PRIMITIVE_RULES["object"])
            )

        else:
            assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
            # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
            return self._add_primitive(
                "root" if rule_name == "root" else schema_type,
                PRIMITIVE_RULES[schema_type],
            )

    def _add_primitive(self, name: str, rule: BuiltinRule):
        """Add builtin *rule* as *name*, plus any of its deps not yet present.

        Returns the key under which *rule* was stored.
        """
        n = self._add_rule(name, rule.content)

        for dep in rule.deps:
            dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
            assert dep_rule, f"Rule {dep} not known"
            if dep not in self._rules:
                self._add_primitive(dep, dep_rule)
        return n

    def _build_object_rule(
        self,
        properties: List[Tuple[str, Any]],
        required: Set[str],
        name: str,
        additional_properties: Union[bool, Any],
    ):
        """Return the GBNF right-hand side for an object schema.

        Required properties come first, comma-separated, in sorted order
        (prop_order position, then declaration order). Optional properties
        follow, each optional. When *additional_properties* is True or a
        schema dict, a "*" entry (string key -> value rule) may repeat at the
        end.
        """
        prop_order = self._prop_order
        # sort by position in prop_order (if specified) then by original order
        sorted_props = [
            kv[0]
            for _, kv in sorted(
                enumerate(properties),
                key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]),
            )
        ]

        prop_kv_rule_names = {}
        for prop_name, prop_schema in properties:
            prop_rule_name = self.visit(
                prop_schema, f'{name}{"-" if name else ""}{prop_name}'
            )
            prop_kv_rule_names[prop_name] = self._add_rule(
                f'{name}{"-" if name else ""}{prop_name}-kv',
                rf'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}',
            )
        required_props = [k for k in sorted_props if k in required]
        optional_props = [k for k in sorted_props if k not in required]

        if additional_properties == True or isinstance(additional_properties, dict):
            sub_name = f'{name}{"-" if name else ""}additional'
            value_rule = self.visit(
                {} if additional_properties == True else additional_properties,
                f"{sub_name}-value",
            )
            prop_kv_rule_names["*"] = self._add_rule(
                f"{sub_name}-kv",
                self._add_primitive("string", PRIMITIVE_RULES["string"])
                + f' ":" space {value_rule}',
            )
            optional_props.append("*")

        rule = '"{" space '
        rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)

        if optional_props:
            rule += " ("
            if required_props:
                rule += ' "," space ( '

            def get_recursive_refs(ks, first_is_optional):
                # Rule for "the first of ks, then any later ones, in order".
                [k, *rest] = ks
                kv_rule_name = prop_kv_rule_names[k]
                if k == "*":
                    res = self._add_rule(
                        f'{name}{"-" if name else ""}additional-kvs',
                        f'{kv_rule_name} ( "," space ' + kv_rule_name + " )*",
                    )
                elif first_is_optional:
                    res = f'( "," space {kv_rule_name} )?'
                else:
                    res = kv_rule_name
                if len(rest) > 0:
                    res += " " + self._add_rule(
                        f'{name}{"-" if name else ""}{k}-rest',
                        get_recursive_refs(rest, first_is_optional=True),
                    )
                return res

            # One alternative per possible first optional property.
            rule += " | ".join(
                get_recursive_refs(optional_props[i:], first_is_optional=False)
                for i in range(len(optional_props))
            )
            if required_props:
                rule += " )"
            rule += " )?"

        rule += ' "}" space'

        return rule

    def format_grammar(self):
        """Render all rules as "name ::= rule" lines, sorted by name."""
        return "\n".join(
            f"{name} ::= {rule}"
            for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
        )
942
+
943
+
944
def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None):
    """Convert a JSON schema, given as JSON text, into GBNF grammar text.

    Args:
        schema: the JSON schema serialized as a JSON string.
        prop_order: optional property names to emit first, in this order.

    Returns:
        The grammar as newline-separated ``name ::= rule`` lines.

    Remote ``$ref`` fetching is disabled, and ``.`` in patterns does not match
    line breaks.
    """
    order_index = {}
    for position, prop_name in enumerate(prop_order or []):
        order_index[prop_name] = position

    parsed = json.loads(schema)
    converter = SchemaConverter(
        prop_order=order_index,
        allow_fetch=False,
        dotall=False,
        raw_pattern=False,
    )
    converter.visit(converter.resolve_refs(parsed, "stdin"), "")
    return converter.format_grammar()