transformers-grammar-constraint 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- transformers_grammar_constraint-0.2.1/PKG-INFO +8 -0
- transformers_grammar_constraint-0.2.1/pyproject.toml +40 -0
- transformers_grammar_constraint-0.2.1/src/transformers_grammar_constraint/__init__.py +16 -0
- transformers_grammar_constraint-0.2.1/src/transformers_grammar_constraint/compat.py +13 -0
- transformers_grammar_constraint-0.2.1/src/transformers_grammar_constraint/grammar.py +457 -0
- transformers_grammar_constraint-0.2.1/src/transformers_grammar_constraint/grammars/__init__.py +6 -0
- transformers_grammar_constraint-0.2.1/src/transformers_grammar_constraint/grammars/json_grammar.py +286 -0
- transformers_grammar_constraint-0.2.1/src/transformers_grammar_constraint/grammars/pgn_grammar.py +220 -0
- transformers_grammar_constraint-0.2.1/src/transformers_grammar_constraint/grammars/san_grammar.py +320 -0
- transformers_grammar_constraint-0.2.1/src/transformers_grammar_constraint/processor.py +252 -0
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: transformers-grammar-constraint
|
|
3
|
+
Version: 0.2.1
|
|
4
|
+
Summary: Grammar-constrained LLM token generation via Lark + HuggingFace Transformers
|
|
5
|
+
Requires-Dist: lark>=1.3.1
|
|
6
|
+
Requires-Dist: torch>=2.10.0
|
|
7
|
+
Requires-Dist: transformers>=4.17.0,<6.0.0
|
|
8
|
+
Requires-Python: >=3.12
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "transformers-grammar-constraint"
|
|
3
|
+
version = "0.2.1"
|
|
4
|
+
description = "Grammar-constrained LLM token generation via Lark + HuggingFace Transformers"
|
|
5
|
+
requires-python = ">=3.12"
|
|
6
|
+
dependencies = [
|
|
7
|
+
"lark>=1.3.1",
|
|
8
|
+
"torch>=2.10.0",
|
|
9
|
+
"transformers>=4.17.0,<6.0.0",
|
|
10
|
+
]
|
|
11
|
+
|
|
12
|
+
[tool.uv]
|
|
13
|
+
package = true
|
|
14
|
+
exclude-newer = "3d"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
[build-system]
|
|
18
|
+
requires = ["uv-build>=0.6,<0.12"]
|
|
19
|
+
build-backend = "uv_build"
|
|
20
|
+
|
|
21
|
+
[tool.pytest.ini_options]
|
|
22
|
+
pythonpath = ["src"]
|
|
23
|
+
addopts = "-m 'not slow'"
|
|
24
|
+
markers = [
|
|
25
|
+
"slow: requires GPU and network access to download model weights",
|
|
26
|
+
]
|
|
27
|
+
|
|
28
|
+
[tool.ruff.lint]
|
|
29
|
+
select = ["I"]
|
|
30
|
+
|
|
31
|
+
[dependency-groups]
|
|
32
|
+
dev = [
|
|
33
|
+
"pytest>=8.0.0",
|
|
34
|
+
"hypothesis>=6.0.0",
|
|
35
|
+
"python-chess>=1.999",
|
|
36
|
+
"pytest-xdist>=3.0.0",
|
|
37
|
+
"pre-commit>=4.5.1",
|
|
38
|
+
"ruff>=0.15.9",
|
|
39
|
+
"ty>=0.0.29",
|
|
40
|
+
]
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# src/transformers_grammar_constraint/__init__.py
|
|
2
|
+
from transformers_grammar_constraint import (
|
|
3
|
+
compat as _compat, # noqa: F401 — applies shim at import time
|
|
4
|
+
)
|
|
5
|
+
from transformers_grammar_constraint.grammar import Grammar, LarkGrammar
|
|
6
|
+
from transformers_grammar_constraint.grammars import JsonGrammar, PgnGrammar, SanGrammar
|
|
7
|
+
from transformers_grammar_constraint.processor import GrammarConstrainedLogitsProcessor
|
|
8
|
+
|
|
9
|
+
# Public API surface of the package. `compat` is imported above purely for
# its import-time side effect (the transformers 4.x/5.x shim) and is
# deliberately not re-exported here.
__all__ = [
    "Grammar",
    "LarkGrammar",
    "GrammarConstrainedLogitsProcessor",
    "JsonGrammar",
    "SanGrammar",
    "PgnGrammar",
]
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# src/transformers_grammar_constraint/compat.py
|
|
2
|
+
"""
|
|
3
|
+
Compatibility shim for transformers 4.x/5.x.
|
|
4
|
+
|
|
5
|
+
In transformers 4.x, LogitsWarper was a separate base class.
|
|
6
|
+
In 5.x it was removed; warper-style processors now subclass LogitsProcessor directly.
|
|
7
|
+
This module patches the missing name so user code importing LogitsWarper still works.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import transformers
|
|
11
|
+
|
|
12
|
+
# Restore the `LogitsWarper` name on transformers 5.x, where the class was
# folded into LogitsProcessor. On 4.x the attribute already exists and the
# assignment is skipped, so real 4.x behavior is untouched.
if not hasattr(transformers, "LogitsWarper"):
    transformers.LogitsWarper = transformers.LogitsProcessor
|
|
@@ -0,0 +1,457 @@
|
|
|
1
|
+
# src/transformers_grammar_constraint/grammar.py
|
|
2
|
+
"""
|
|
3
|
+
Grammar base class.
|
|
4
|
+
|
|
5
|
+
Subclass LarkGrammar and implement `grammar_string` to define a Lark LALR grammar.
|
|
6
|
+
The base class handles Lark parser creation, character-level incremental validation,
|
|
7
|
+
state-based caching, and EOS detection.
|
|
8
|
+
|
|
9
|
+
`Grammar` is an abstract generic base; `LarkGrammar` is the concrete Lark-backed
|
|
10
|
+
implementation. DFA-backed grammars (e.g. SanGrammar) subclass `Grammar[_DfaState]`
|
|
11
|
+
directly.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import logging
|
|
15
|
+
import re
|
|
16
|
+
from abc import ABC, abstractmethod
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
from typing import Any, Generic, TypeVar
|
|
19
|
+
|
|
20
|
+
from lark import Lark, UnexpectedCharacters, UnexpectedEOF, UnexpectedToken
|
|
21
|
+
from lark.lexer import PatternRE, PatternStr
|
|
22
|
+
|
|
23
|
+
# Minimal set of characters used as probe suffixes when checking whether a
|
|
24
|
+
# decoded token can be the START of a regex terminal (e.g. is `"` a valid
|
|
25
|
+
# prefix of ESCAPED_STRING?). We try completing with each suffix in turn;
|
|
26
|
+
# the first successful fullmatch confirms the text is a valid prefix.
|
|
27
|
+
_REGEX_PROBE_SUFFIXES: tuple[str, ...] = (
|
|
28
|
+
'"',
|
|
29
|
+
"0",
|
|
30
|
+
"a",
|
|
31
|
+
",",
|
|
32
|
+
"]",
|
|
33
|
+
"}",
|
|
34
|
+
"+",
|
|
35
|
+
"-",
|
|
36
|
+
"1",
|
|
37
|
+
")",
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
# Characters probed as single-char continuations in the Earley fast-path
|
|
41
|
+
# pre-filter. Covers printable ASCII plus common whitespace; non-ASCII
|
|
42
|
+
# first characters fall through and are skipped automatically.
|
|
43
|
+
_EARLEY_PROBE_CHARS: tuple[str, ...] = tuple(
|
|
44
|
+
chr(i)
|
|
45
|
+
for i in range(9, 127) # HT, LF, VT, FF, CR, then space..~
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
|
|
50
|
+
class _TextState:
|
|
51
|
+
"""Lightweight parser state for non-LALR (e.g. Earley) grammars.
|
|
52
|
+
|
|
53
|
+
Tracks accumulated text instead of an InteractiveParser object,
|
|
54
|
+
since parse_interactive() is only available for LALR.
|
|
55
|
+
"""
|
|
56
|
+
|
|
57
|
+
text: str = ""
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@dataclass
|
|
61
|
+
class _LALRState:
|
|
62
|
+
"""Wrapper around an LALR InteractiveParser with explicit text tracking.
|
|
63
|
+
|
|
64
|
+
Replaces direct mutation of Lark private attributes (_accumulated_text,
|
|
65
|
+
_partial_prefix) with first-class dataclass fields, eliminating dependence
|
|
66
|
+
on Lark's internal API.
|
|
67
|
+
"""
|
|
68
|
+
|
|
69
|
+
parser: Any # InteractiveParser
|
|
70
|
+
accumulated_text: str = ""
|
|
71
|
+
partial_prefix: str = ""
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# Module-level logger; used to warn when token filtering hits a dead end.
logger = logging.getLogger(__name__)

# Parser-state type parameter for the generic Grammar base class.
_StateT = TypeVar("_StateT")

# Type alias for the Lark-backed state union used by LarkGrammar.
_LarkState = _LALRState | _TextState
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class Grammar(ABC, Generic[_StateT]):
    """Abstract generic base class for grammar-constrained generation.

    Subclass `LarkGrammar` (and implement `grammar_string`) to define a Lark
    EBNF grammar. For custom state machines, subclass `Grammar[YourStateType]`
    directly and implement all abstract methods.

    Example (Lark grammar)::

        class MyGrammar(LarkGrammar):
            @property
            def grammar_string(self) -> str:
                return r'''
                    start: greeting name
                    greeting: "hello" | "hi"
                    name: /[A-Z][a-z]+/
                '''
    """

    #: Override in subclasses to use a different Lark parser backend.
    #: "lalr" is faster but cannot handle ambiguous grammars (e.g. SAN chess).
    #: "earley" handles ambiguous grammars at the cost of speed.
    #: "none" disables Lark entirely (used by DFA-backed grammars).
    lark_parser: str = "lalr"

    @property
    @abstractmethod
    def grammar_string(self) -> str:
        """Return a Lark grammar string (EBNF-like Lark syntax).

        NOTE(review): this is abstract on the base class, so DFA-backed
        subclasses (lark_parser == "none") must still implement it even
        though no Lark parser is ever built for them — confirm intended.
        """
        ...

    @abstractmethod
    def fresh_state(self) -> _StateT:
        """Return a fresh parser state at the start of the grammar."""
        ...

    @abstractmethod
    def advance_state(self, state: _StateT, text: str) -> _StateT:
        """Return a new state after feeding `text` through the parser.

        Raises ValueError if the text is not a valid continuation.
        """
        ...

    @abstractmethod
    def _state_key(self, state: _StateT) -> tuple:
        """Return a hashable key representing the current parser state.

        Used as the cache key for `get_valid_token_ids` results, so two
        states with equal keys must admit the same set of continuations.
        """
        ...

    @abstractmethod
    def _make_state(self, text: str) -> _StateT | None:
        """Try to create a parser state from `text`.

        Returns a state if text is a valid prefix, None if text is invalid.
        """
        ...

    @abstractmethod
    def _is_accepting(self, state: _StateT) -> bool:
        """Return True if the parser is in a grammar-complete (accepting) state."""
        ...

    @abstractmethod
    def _is_partial_terminal_extension(self, state: _StateT, new_text: str) -> bool:
        """Return True if new_text extends a partial terminal at the current state."""
        ...

    @abstractmethod
    def get_valid_token_ids(
        self,
        state: _StateT,
        tokenizer: Any,
        *,
        vocab: dict[str, int] | None = None,
        decoded_vocab: dict[int, str] | None = None,
    ) -> frozenset[int]:
        """Return the set of token IDs that are valid continuations from `state`.

        `vocab` / `decoded_vocab` are optional precomputed views of the
        tokenizer vocabulary; when omitted the tokenizer is queried directly.
        """
        ...

    @abstractmethod
    def is_valid_complete(self, text: str) -> bool:
        """Return True if `text` is a complete, valid string for this grammar."""
        ...
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
class LarkGrammar(Grammar[_LarkState]):
    """Concrete Grammar backed by Lark LALR/Earley parser.

    Subclass this and implement `grammar_string` to define your grammar.

    Example::

        class MyGrammar(LarkGrammar):
            @property
            def grammar_string(self) -> str:
                return r'''
                    start: greeting name
                    greeting: "hello" | "hi"
                    name: /[A-Z][a-z]+/
                '''
    """

    def __init__(self) -> None:
        # DFA-backed subclasses set lark_parser = "none" and build no Lark
        # parser at all; they override every method that would touch it.
        if self.lark_parser == "none":
            return
        self._parser = Lark(
            self.grammar_string,
            parser=self.lark_parser,
            propagate_positions=False,
        )
        # Cache: state_key -> frozenset of valid token ids
        self._cache: dict[tuple, frozenset[int]] = {}
        # Precompute per-terminal information used in get_valid_token_ids.
        # PatternStr terminals: literal string value (for prefix matching).
        self._terminal_literals: dict[str, str] = {}
        # PatternRE terminals: compiled regex pattern (for fullmatch + probe checks).
        self._terminal_patterns: dict[str, re.Pattern[str]] = {}
        for term in self._parser.terminals:
            if isinstance(term.pattern, PatternStr):
                self._terminal_literals[term.name] = term.pattern.value
            elif isinstance(term.pattern, PatternRE):
                try:
                    self._terminal_patterns[term.name] = re.compile(
                        term.pattern.value, re.DOTALL
                    )
                except re.error:
                    # Presumably a Lark-flavored pattern Python's re rejects;
                    # such terminals simply get no regex-based matching.
                    pass

    def fresh_state(self) -> _LarkState:
        """Return a fresh parser state at the start of the grammar."""
        if self.lark_parser != "lalr":
            return _TextState(text="")
        ip = self._parser.parse_interactive("")
        ip.exhaust_lexer()
        return _LALRState(parser=ip)

    def advance_state(self, state: _LarkState, text: str) -> _LarkState:
        """Return a new state after feeding `text` through the parser.

        Raises ValueError if the text is not a valid continuation.
        """
        if isinstance(state, _TextState):
            # Earley path: validity is decided by a full re-parse of the
            # concatenated text in _make_state.
            new_text = state.text + text
            new_state = self._make_state(new_text)
            if new_state is None:
                raise ValueError(
                    f"Text {text!r} is not a valid continuation for this grammar state"
                )
            return new_state
        # LALR path: rebuild an interactive parser over the full text rather
        # than feeding tokens into the existing one.
        new_text = self._state_text(state) + text
        new_ip = self._parser.parse_interactive(new_text)
        try:
            new_ip.exhaust_lexer()
        except (UnexpectedCharacters, UnexpectedToken):
            # NOTE(review): unlike _make_state, UnexpectedEOF is NOT caught
            # here, so advancing into text that ends mid-token may propagate
            # UnexpectedEOF instead of raising ValueError — confirm intended.
            raise ValueError(
                f"Text {text!r} is not a valid continuation for this grammar state"
            )
        # NOTE(review): partial_prefix is left at "" here and is never
        # assigned anywhere in this module; presumably processor.py
        # maintains it — verify against the caller.
        return _LALRState(parser=new_ip, accumulated_text=new_text)

    def _state_key(self, state: _LarkState) -> tuple:
        """Return a hashable key representing the current parser state."""
        if isinstance(state, _TextState):
            # Earley states have no parser object; the text itself is the key.
            return (state.text,)
        # LALR: the automaton's state stack plus any pending partial terminal
        # fully determines which continuations are valid.
        return (tuple(state.parser.parser_state.state_stack), state.partial_prefix)

    def _state_text(self, state: _LarkState) -> str:
        """Return the accumulated text for a parser state."""
        if isinstance(state, _TextState):
            return state.text
        return state.accumulated_text

    def _make_state(self, text: str) -> _LarkState | None:
        """Try to create a parser state from `text`.

        Returns a state if text is a valid prefix, None if text is invalid.

        For LALR: uses parse_interactive + exhaust_lexer; UnexpectedEOF means
        valid but incomplete prefix.

        For Earley: uses parse(); UnexpectedEOF means valid but incomplete prefix.
        """
        if self.lark_parser != "lalr":
            try:
                self._parser.parse(text)
                return _TextState(text=text)  # complete valid parse
            except UnexpectedEOF:
                return _TextState(text=text)  # valid but incomplete prefix
            except Exception:
                # Broad catch is deliberate: any other parse failure simply
                # means the text is not a valid prefix.
                return None
        ip = self._parser.parse_interactive(text)
        try:
            ip.exhaust_lexer()
            return _LALRState(parser=ip, accumulated_text=text)
        except UnexpectedEOF:
            return _LALRState(parser=ip, accumulated_text=text)
        except (UnexpectedCharacters, UnexpectedToken):
            return None

    def _token_matches_terminal(self, term_name: str, full_candidate: str) -> bool:
        """Return True if full_candidate is a valid complete or partial match of terminal.

        For PatternStr: checks whether term_value starts with full_candidate (partial
        prefix) or full_candidate equals term_value (exact match).

        For PatternRE: checks whether full_candidate fully matches the compiled regex
        OR whether there exists a short probe extension of full_candidate that matches
        (indicating full_candidate is a valid partial prefix of the terminal).

        NOTE(review): the PatternStr branch rejects candidates that complete
        the literal and then run past it into the next terminal (startswith
        fails once full_candidate is longer than the literal) — confirm that
        such multi-terminal tokens are meant to be excluded.
        """
        term_str = self._terminal_literals.get(term_name)
        if term_str is not None:
            return term_str.startswith(full_candidate)
        compiled = self._terminal_patterns.get(term_name)
        if compiled is not None:
            if compiled.fullmatch(full_candidate):
                return True
            # Probe-suffix heuristic: if some short extension fullmatches,
            # full_candidate is a valid prefix of the terminal.
            for suffix in _REGEX_PROBE_SUFFIXES:
                if compiled.fullmatch(full_candidate + suffix):
                    return True
        return False

    def _is_partial_terminal_extension(self, state: _LarkState, new_text: str) -> bool:
        """Return True if new_text extends a partial terminal at the current state.

        Checks whether (partial_prefix + new_text) matches or is a valid prefix of
        any terminal (PatternStr or PatternRE) expected by the LALR parser at the
        current state. This handles the case where a vocabulary token is only part
        of a multi-character terminal (e.g. the token `"` when the grammar expects
        the string literal `"age"` or the regex terminal ESCAPED_STRING).
        """
        if isinstance(state, _TextState):
            # Earley states validate by full re-parse; no partial tracking.
            return False
        full_suffix = state.partial_prefix + new_text
        if not full_suffix:
            return False
        try:
            expected = state.parser.accepts()
        except Exception:
            # accepts() can fail on exotic parser states; treat as "no".
            return False
        return any(self._token_matches_terminal(t, full_suffix) for t in expected)

    def _is_accepting(self, state: _LarkState) -> bool:
        """Return True if the parser is in a grammar-complete (accepting) state."""
        if isinstance(state, _TextState):
            # Earley: accepting iff the accumulated text parses completely.
            try:
                self._parser.parse(state.text)
                return True
            except Exception:
                return False
        # LALR: accepting iff the end-of-input marker is a legal next token.
        try:
            accepts = state.parser.accepts()
            return "$END" in accepts
        except Exception:
            return False

    def get_valid_token_ids(
        self,
        state: _LarkState,
        tokenizer: Any,
        *,
        vocab: dict[str, int] | None = None,
        decoded_vocab: dict[int, str] | None = None,
    ) -> frozenset[int]:
        """Return the set of token IDs that are valid continuations from `state`.

        Includes EOS token if the grammar is currently in an accepting state.
        Uses state-based caching: identical LALR states reuse prior results.

        Args:
            state: Current InteractiveParser state (from fresh_state or advance_state).
            tokenizer: HuggingFace tokenizer with get_vocab() and decode() methods.

        Returns:
            frozenset of valid token IDs. Empty frozenset signals a dead end.

        Note:
            For LALR grammars, token-to-terminal matching uses a probe-suffix heuristic to
            detect valid regex-terminal prefixes (see `_REGEX_PROBE_SUFFIXES`). The probe
            set is finite and may miss unusual terminals — tokens that are valid prefixes of
            an uncovered regex pattern will be incorrectly excluded.
        """
        state_key = self._state_key(state)

        if state_key in self._cache:
            return self._cache[state_key]

        _vocab: dict[str, int] = vocab if vocab is not None else tokenizer.get_vocab()
        valid_ids: set[int] = set()

        if isinstance(state, _TextState):
            # Earley / non-LALR: fall back to full re-parse for each candidate.
            current_text = self._state_text(state)

            # Fast pre-filter: probe ASCII characters one at a time to find
            # which first characters lead to valid continuations. Tokens
            # whose decoded form starts with an invalid character are skipped
            # without a full re-parse, cutting O(vocab) Earley calls to
            # O(128 + |filtered_vocab|) — critical for large vocabularies.
            valid_first_chars: frozenset[str] = frozenset(
                c
                for c in _EARLEY_PROBE_CHARS
                if self._make_state(current_text + c) is not None
            )

            for token_str, token_id in _vocab.items():
                if not token_str:
                    continue
                decoded = (
                    decoded_vocab[token_id]
                    if decoded_vocab is not None
                    else tokenizer.decode([token_id])
                )
                if not decoded:
                    continue
                # Skip tokens whose first character cannot start a valid
                # continuation. Non-ASCII first characters are also filtered
                # because they will not be present in valid_first_chars.
                if decoded[0] not in valid_first_chars:
                    continue
                if self._make_state(current_text + decoded) is not None:
                    valid_ids.add(token_id)
        else:
            # LALR: use state.accepts() to determine the expected terminals, then
            # check each vocabulary token against those terminals directly. This
            # avoids the false-positive that arises when _make_state() re-parses
            # the whole accumulated text and the greedy lexer produces a different
            # tokenisation (e.g. "00" instead of "0" + "0") that happens to be
            # valid even though the incremental parser has already committed past
            # the token boundary.
            partial = state.partial_prefix
            try:
                expected_terms = frozenset(state.parser.accepts())
            except Exception:
                # No expected terminals -> nothing matches below; EOS may
                # still be added if the state is accepting.
                expected_terms = frozenset()

            for token_str, token_id in _vocab.items():
                if not token_str:
                    continue
                decoded = (
                    decoded_vocab[token_id]
                    if decoded_vocab is not None
                    else tokenizer.decode([token_id])
                )
                if not decoded:
                    continue
                full_candidate = partial + decoded
                if any(
                    self._token_matches_terminal(t, full_candidate)
                    for t in expected_terms
                ):
                    valid_ids.add(token_id)

        # Include EOS if grammar is in accepting state
        if self._is_accepting(state) and hasattr(tokenizer, "eos_token_id"):
            eos_id = tokenizer.eos_token_id
            if eos_id is not None:
                valid_ids.add(eos_id)

        if not valid_ids:
            logger.warning(
                "Grammar reached a dead end: no valid tokens. "
                "Allowing all tokens to prevent hanging generation."
            )
            result = frozenset(_vocab.values())
        else:
            result = frozenset(valid_ids)

        self._cache[state_key] = result
        return result

    def is_valid_complete(self, text: str) -> bool:
        """Return True if `text` is a complete, valid string for this grammar."""
        try:
            self._parser.parse(text)
            return True
        except Exception:
            # Any parse failure (lexing, parsing, or premature EOF) means
            # the text is not a complete sentence of the grammar.
            return False
|
transformers_grammar_constraint-0.2.1/src/transformers_grammar_constraint/grammars/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
# src/transformers_grammar_constraint/grammars/__init__.py
|
|
2
|
+
from transformers_grammar_constraint.grammars.json_grammar import JsonGrammar
|
|
3
|
+
from transformers_grammar_constraint.grammars.pgn_grammar import PgnGrammar
|
|
4
|
+
from transformers_grammar_constraint.grammars.san_grammar import SanGrammar
|
|
5
|
+
|
|
6
|
+
# Built-in grammars re-exported at subpackage level (also re-exported by the
# top-level transformers_grammar_constraint package).
__all__ = ["JsonGrammar", "SanGrammar", "PgnGrammar"]
|