transformers-grammar-constraint-0.2.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,8 @@
+ Metadata-Version: 2.3
+ Name: transformers-grammar-constraint
+ Version: 0.2.1
+ Summary: Grammar-constrained LLM token generation via Lark + HuggingFace Transformers
+ Requires-Dist: lark>=1.3.1
+ Requires-Dist: torch>=2.10.0
+ Requires-Dist: transformers>=4.17.0,<6.0.0
+ Requires-Python: >=3.12
@@ -0,0 +1,40 @@
+ [project]
+ name = "transformers-grammar-constraint"
+ version = "0.2.1"
+ description = "Grammar-constrained LLM token generation via Lark + HuggingFace Transformers"
+ requires-python = ">=3.12"
+ dependencies = [
+     "lark>=1.3.1",
+     "torch>=2.10.0",
+     "transformers>=4.17.0,<6.0.0",
+ ]
+
+ [tool.uv]
+ package = true
+ exclude-newer = "3d"
+
+
+ [build-system]
+ requires = ["uv-build>=0.6,<0.12"]
+ build-backend = "uv_build"
+
+ [tool.pytest.ini_options]
+ pythonpath = ["src"]
+ addopts = "-m 'not slow'"
+ markers = [
+     "slow: requires GPU and network access to download model weights",
+ ]
+
+ [tool.ruff.lint]
+ select = ["I"]
+
+ [dependency-groups]
+ dev = [
+     "pytest>=8.0.0",
+     "hypothesis>=6.0.0",
+     "python-chess>=1.999",
+     "pytest-xdist>=3.0.0",
+     "pre-commit>=4.5.1",
+     "ruff>=0.15.9",
+     "ty>=0.0.29",
+ ]
@@ -0,0 +1,16 @@
+ # src/transformers_grammar_constraint/__init__.py
+ from transformers_grammar_constraint import (
+     compat as _compat,  # noqa: F401 — applies shim at import time
+ )
+ from transformers_grammar_constraint.grammar import Grammar, LarkGrammar
+ from transformers_grammar_constraint.grammars import JsonGrammar, PgnGrammar, SanGrammar
+ from transformers_grammar_constraint.processor import GrammarConstrainedLogitsProcessor
+
+ __all__ = [
+     "Grammar",
+     "LarkGrammar",
+     "GrammarConstrainedLogitsProcessor",
+     "JsonGrammar",
+     "SanGrammar",
+     "PgnGrammar",
+ ]
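
A hedged usage sketch of these exports (the model name and the GrammarConstrainedLogitsProcessor constructor signature are assumptions; the processor source is not part of this excerpt):

    from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
    from transformers_grammar_constraint import (
        GrammarConstrainedLogitsProcessor,
        JsonGrammar,
    )

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed model
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Assumed (grammar, tokenizer) signature; see processor.py in the full package.
    processor = GrammarConstrainedLogitsProcessor(JsonGrammar(), tokenizer)
    inputs = tokenizer("Emit a JSON object:", return_tensors="pt")
    out = model.generate(
        **inputs,
        logits_processor=LogitsProcessorList([processor]),
        max_new_tokens=64,
    )
    print(tokenizer.decode(out[0], skip_special_tokens=True))
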
@@ -0,0 +1,13 @@
+ # src/transformers_grammar_constraint/compat.py
+ """
+ Compatibility shim for transformers 4.x/5.x.
+
+ In transformers 4.x, LogitsWarper was a separate base class.
+ In 5.x it was removed; warper-style processors now subclass LogitsProcessor directly.
+ This module patches the missing name so user code importing LogitsWarper still works.
+ """
+
+ import transformers
+
+ if not hasattr(transformers, "LogitsWarper"):
+     setattr(transformers, "LogitsWarper", transformers.LogitsProcessor)
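
A quick way to see the shim in action (illustrative; on transformers 4.x the name exists natively and the patch is a no-op):

    import transformers
    import transformers_grammar_constraint  # noqa: F401  (importing applies the shim)

    # Without the shim, this import fails on transformers 5.x.
    from transformers import LogitsWarper
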
@@ -0,0 +1,457 @@
+ # src/transformers_grammar_constraint/grammar.py
+ """
+ Grammar base class.
+
+ Subclass LarkGrammar and implement `grammar_string` to define a Lark LALR grammar.
+ The base class handles Lark parser creation, character-level incremental validation,
+ state-based caching, and EOS detection.
+
+ `Grammar` is an abstract generic base; `LarkGrammar` is the concrete Lark-backed
+ implementation. DFA-backed grammars (e.g. SanGrammar) subclass `Grammar[_DfaState]`
+ directly.
+ """
+
+ import logging
+ import re
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from typing import Any, Generic, TypeVar
+
+ from lark import Lark, UnexpectedCharacters, UnexpectedEOF, UnexpectedToken
+ from lark.lexer import PatternRE, PatternStr
+
+ # Minimal set of characters used as probe suffixes when checking whether a
+ # decoded token can be the START of a regex terminal (e.g. is `"` a valid
+ # prefix of ESCAPED_STRING?). We try completing with each suffix in turn;
+ # the first successful fullmatch confirms the text is a valid prefix.
+ _REGEX_PROBE_SUFFIXES: tuple[str, ...] = (
+     '"',
+     "0",
+     "a",
+     ",",
+     "]",
+     "}",
+     "+",
+     "-",
+     "1",
+     ")",
+ )
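+
+ # Illustration of the probe idea with a simplified string-literal regex
+ # (not a terminal taken from any bundled grammar):
+ #     pat = re.compile(r'"[^"]*"')
+ #     pat.fullmatch('"ab')        # None: not a complete match on its own
+ #     pat.fullmatch('"ab' + '"')  # matches, so '"ab' is a valid prefix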
+
+ # Characters probed as single-char continuations in the Earley fast-path
+ # pre-filter. Covers printable ASCII plus common whitespace; non-ASCII
+ # first characters fall through and are skipped automatically.
+ _EARLEY_PROBE_CHARS: tuple[str, ...] = tuple(
+     chr(i)
+     for i in (*range(9, 14), *range(32, 127))  # HT, LF, VT, FF, CR, then space..~
+ )
+
+
+ @dataclass
+ class _TextState:
+     """Lightweight parser state for non-LALR (e.g. Earley) grammars.
+
+     Tracks accumulated text instead of an InteractiveParser object,
+     since parse_interactive() is only available for LALR.
+     """
+
+     text: str = ""
+
+
+ @dataclass
+ class _LALRState:
+     """Wrapper around an LALR InteractiveParser with explicit text tracking.
+
+     Replaces direct mutation of Lark private attributes (_accumulated_text,
+     _partial_prefix) with first-class dataclass fields, eliminating dependence
+     on Lark's internal API.
+     """
+
+     parser: Any  # InteractiveParser
+     accumulated_text: str = ""
+     partial_prefix: str = ""
+
+
+ logger = logging.getLogger(__name__)
+
+ _StateT = TypeVar("_StateT")
+
+ # Type alias for the Lark-backed state union used by LarkGrammar.
+ _LarkState = _LALRState | _TextState
+
+
+ class Grammar(ABC, Generic[_StateT]):
+     """Abstract generic base class for grammar-constrained generation.
+
+     Subclass `LarkGrammar` (and implement `grammar_string`) to define a Lark
+     EBNF grammar. For custom state machines, subclass `Grammar[YourStateType]`
+     directly and implement all abstract methods.
+
+     Example (Lark grammar)::
+
+         class MyGrammar(LarkGrammar):
+             @property
+             def grammar_string(self) -> str:
+                 return r'''
+                     start: greeting name
+                     greeting: "hello" | "hi"
+                     name: /[A-Z][a-z]+/
+                 '''
+     """
+
+     #: Override in subclasses to use a different Lark parser backend.
+     #: "lalr" is faster but cannot handle ambiguous grammars (e.g. SAN chess).
+     #: "earley" handles ambiguous grammars at the cost of speed.
+     #: "none" disables Lark entirely (used by DFA-backed grammars).
+     lark_parser: str = "lalr"
+
+     @property
+     @abstractmethod
+     def grammar_string(self) -> str:
+         """Return a Lark grammar string (EBNF-like Lark syntax)."""
+         ...
+
+     @abstractmethod
+     def fresh_state(self) -> _StateT:
+         """Return a fresh parser state at the start of the grammar."""
+         ...
+
+     @abstractmethod
+     def advance_state(self, state: _StateT, text: str) -> _StateT:
+         """Return a new state after feeding `text` through the parser.
+
+         Raises ValueError if the text is not a valid continuation.
+         """
+         ...
+
+     @abstractmethod
+     def _state_key(self, state: _StateT) -> tuple:
+         """Return a hashable key representing the current parser state."""
+         ...
+
+     @abstractmethod
+     def _make_state(self, text: str) -> _StateT | None:
+         """Try to create a parser state from `text`.
+
+         Returns a state if text is a valid prefix, None if text is invalid.
+         """
+         ...
+
+     @abstractmethod
+     def _is_accepting(self, state: _StateT) -> bool:
+         """Return True if the parser is in a grammar-complete (accepting) state."""
+         ...
+
+     @abstractmethod
+     def _is_partial_terminal_extension(self, state: _StateT, new_text: str) -> bool:
+         """Return True if new_text extends a partial terminal at the current state."""
+         ...
+
+     @abstractmethod
+     def get_valid_token_ids(
+         self,
+         state: _StateT,
+         tokenizer: Any,
+         *,
+         vocab: dict[str, int] | None = None,
+         decoded_vocab: dict[int, str] | None = None,
+     ) -> frozenset[int]:
+         """Return the set of token IDs that are valid continuations from `state`."""
+         ...
+
+     @abstractmethod
+     def is_valid_complete(self, text: str) -> bool:
+         """Return True if `text` is a complete, valid string for this grammar."""
+         ...
+
+
+ class LarkGrammar(Grammar[_LarkState]):
+     """Concrete Grammar backed by a Lark LALR/Earley parser.
+
+     Subclass this and implement `grammar_string` to define your grammar.
+
+     Example::
+
+         class MyGrammar(LarkGrammar):
+             @property
+             def grammar_string(self) -> str:
+                 return r'''
+                     start: greeting name
+                     greeting: "hello" | "hi"
+                     name: /[A-Z][a-z]+/
+                 '''
+     """
+
+     def __init__(self) -> None:
+         if self.lark_parser == "none":
+             return
+         self._parser = Lark(
+             self.grammar_string,
+             parser=self.lark_parser,
+             propagate_positions=False,
+         )
+         # Cache: state_key -> frozenset of valid token ids
+         self._cache: dict[tuple, frozenset[int]] = {}
+         # Precompute per-terminal information used in get_valid_token_ids.
+         # PatternStr terminals: literal string value (for prefix matching).
+         self._terminal_literals: dict[str, str] = {}
+         # PatternRE terminals: compiled regex pattern (for fullmatch + probe checks).
+         self._terminal_patterns: dict[str, re.Pattern[str]] = {}
+         for term in self._parser.terminals:
+             if isinstance(term.pattern, PatternStr):
+                 self._terminal_literals[term.name] = term.pattern.value
+             elif isinstance(term.pattern, PatternRE):
+                 try:
+                     self._terminal_patterns[term.name] = re.compile(
+                         term.pattern.value, re.DOTALL
+                     )
+                 except re.error:
+                     pass
+
+     def fresh_state(self) -> _LarkState:
+         """Return a fresh parser state at the start of the grammar."""
+         if self.lark_parser != "lalr":
+             return _TextState(text="")
+         ip = self._parser.parse_interactive("")
+         ip.exhaust_lexer()
+         return _LALRState(parser=ip)
+
+     def advance_state(self, state: _LarkState, text: str) -> _LarkState:
+         """Return a new state after feeding `text` through the parser.
+
+         Raises ValueError if the text is not a valid continuation.
+         """
+         if isinstance(state, _TextState):
+             new_text = state.text + text
+             new_state = self._make_state(new_text)
+             if new_state is None:
+                 raise ValueError(
+                     f"Text {text!r} is not a valid continuation for this grammar state"
+                 )
+             return new_state
+         new_text = self._state_text(state) + text
+         new_ip = self._parser.parse_interactive(new_text)
+         try:
+             new_ip.exhaust_lexer()
+         except (UnexpectedCharacters, UnexpectedToken):
+             raise ValueError(
+                 f"Text {text!r} is not a valid continuation for this grammar state"
+             ) from None
+         return _LALRState(parser=new_ip, accumulated_text=new_text)
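+
+     # Illustrative stepping, using the MyGrammar example from the docstring:
+     #     g = MyGrammar()
+     #     s0 = g.fresh_state()
+     #     s1 = g.advance_state(s0, "hello")  # valid prefix, returns a new state
+     #     s2 = g.advance_state(s1, "Bob")    # completes `start: greeting name`
+     #     g.advance_state(s2, "!")           # raises ValueError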
+
+     def _state_key(self, state: _LarkState) -> tuple:
+         """Return a hashable key representing the current parser state."""
+         if isinstance(state, _TextState):
+             return (state.text,)
+         return (tuple(state.parser.parser_state.state_stack), state.partial_prefix)
+
+     def _state_text(self, state: _LarkState) -> str:
+         """Return the accumulated text for a parser state."""
+         if isinstance(state, _TextState):
+             return state.text
+         return state.accumulated_text
+
+     def _make_state(self, text: str) -> _LarkState | None:
+         """Try to create a parser state from `text`.
+
+         Returns a state if text is a valid prefix, None if text is invalid.
+
+         For LALR: uses parse_interactive + exhaust_lexer; UnexpectedEOF means
+         a valid but incomplete prefix.
+
+         For Earley: uses parse(); UnexpectedEOF means a valid but incomplete prefix.
+         """
+         if self.lark_parser != "lalr":
+             try:
+                 self._parser.parse(text)
+                 return _TextState(text=text)  # complete valid parse
+             except UnexpectedEOF:
+                 return _TextState(text=text)  # valid but incomplete prefix
+             except Exception:
+                 return None
+         ip = self._parser.parse_interactive(text)
+         try:
+             ip.exhaust_lexer()
+             return _LALRState(parser=ip, accumulated_text=text)
+         except UnexpectedEOF:
+             return _LALRState(parser=ip, accumulated_text=text)
+         except (UnexpectedCharacters, UnexpectedToken):
+             return None
+
+     def _token_matches_terminal(self, term_name: str, full_candidate: str) -> bool:
+         """Return True if full_candidate is a complete or partial match of the terminal.
+
+         For PatternStr: checks whether the literal value starts with full_candidate,
+         which covers both a partial prefix and an exact match.
+
+         For PatternRE: checks whether full_candidate fully matches the compiled regex,
+         OR whether some short probe extension of full_candidate matches (indicating
+         that full_candidate is a valid partial prefix of the terminal).
+         """
+         term_str = self._terminal_literals.get(term_name)
+         if term_str is not None:
+             return term_str.startswith(full_candidate)
+         compiled = self._terminal_patterns.get(term_name)
+         if compiled is not None:
+             if compiled.fullmatch(full_candidate):
+                 return True
+             for suffix in _REGEX_PROBE_SUFFIXES:
+                 if compiled.fullmatch(full_candidate + suffix):
+                     return True
+         return False
+
+     def _is_partial_terminal_extension(self, state: _LarkState, new_text: str) -> bool:
+         """Return True if new_text extends a partial terminal at the current state.
+
+         Checks whether (partial_prefix + new_text) matches or is a valid prefix of
+         any terminal (PatternStr or PatternRE) expected by the LALR parser at the
+         current state. This handles the case where a vocabulary token is only part
+         of a multi-character terminal (e.g. the token `"` when the grammar expects
+         the string literal `"age"` or the regex terminal ESCAPED_STRING).
+         """
+         if isinstance(state, _TextState):
+             return False
+         full_suffix = state.partial_prefix + new_text
+         if not full_suffix:
+             return False
+         try:
+             expected = state.parser.accepts()
+         except Exception:
+             return False
+         return any(self._token_matches_terminal(t, full_suffix) for t in expected)
+
+     def _is_accepting(self, state: _LarkState) -> bool:
+         """Return True if the parser is in a grammar-complete (accepting) state."""
+         if isinstance(state, _TextState):
+             try:
+                 self._parser.parse(state.text)
+                 return True
+             except Exception:
+                 return False
+         try:
+             accepts = state.parser.accepts()
+             return "$END" in accepts
+         except Exception:
+             return False
+
+     def get_valid_token_ids(
+         self,
+         state: _LarkState,
+         tokenizer: Any,
+         *,
+         vocab: dict[str, int] | None = None,
+         decoded_vocab: dict[int, str] | None = None,
+     ) -> frozenset[int]:
+         """Return the set of token IDs that are valid continuations from `state`.
+
+         Includes the EOS token if the grammar is currently in an accepting state.
+         Uses state-based caching: identical LALR states reuse prior results.
+
+         Args:
+             state: Current parser state (from fresh_state or advance_state).
+             tokenizer: HuggingFace tokenizer with get_vocab() and decode() methods.
+             vocab: Optional precomputed result of tokenizer.get_vocab().
+             decoded_vocab: Optional precomputed mapping from token ID to decoded text.
+
+         Returns:
+             frozenset of valid token IDs. A dead end (no valid tokens) falls back
+             to the full vocabulary, with a warning logged, so generation never stalls.
+
+         Note:
+             For LALR grammars, token-to-terminal matching uses a probe-suffix heuristic to
+             detect valid regex-terminal prefixes (see `_REGEX_PROBE_SUFFIXES`). The probe
+             set is finite and may miss unusual terminals — tokens that are valid prefixes of
+             an uncovered regex pattern will be incorrectly excluded.
+         """
+         state_key = self._state_key(state)
+
+         if state_key in self._cache:
+             return self._cache[state_key]
+
+         _vocab: dict[str, int] = vocab if vocab is not None else tokenizer.get_vocab()
+         valid_ids: set[int] = set()
+
+         if isinstance(state, _TextState):
+             # Earley / non-LALR: fall back to a full re-parse for each candidate.
+             current_text = self._state_text(state)
+
+             # Fast pre-filter: probe ASCII characters one at a time to find
+             # which first characters lead to valid continuations. Tokens
+             # whose decoded form starts with an invalid character are skipped
+             # without a full re-parse, cutting O(vocab) Earley calls to
+             # O(128 + |filtered_vocab|) — critical for large vocabularies.
+             valid_first_chars: frozenset[str] = frozenset(
+                 c
+                 for c in _EARLEY_PROBE_CHARS
+                 if self._make_state(current_text + c) is not None
+             )
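+             # E.g. with a ~50k-token vocabulary and a state where only a few
+             # first characters are grammatical, the per-token re-parse below
+             # runs only for the small slice of tokens that pass this filter.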
+
+             for token_str, token_id in _vocab.items():
+                 if not token_str:
+                     continue
+                 decoded = (
+                     decoded_vocab[token_id]
+                     if decoded_vocab is not None
+                     else tokenizer.decode([token_id])
+                 )
+                 if not decoded:
+                     continue
+                 # Skip tokens whose first character cannot start a valid
+                 # continuation. Non-ASCII first characters are also filtered
+                 # because they will not be present in valid_first_chars.
+                 if decoded[0] not in valid_first_chars:
+                     continue
+                 if self._make_state(current_text + decoded) is not None:
+                     valid_ids.add(token_id)
+         else:
+             # LALR: use state.parser.accepts() to determine the expected terminals,
+             # then check each vocabulary token against those terminals directly.
+             # This avoids the false positive that arises when _make_state()
+             # re-parses the whole accumulated text and the greedy lexer produces
+             # a different tokenisation (e.g. "00" instead of "0" + "0") that
+             # happens to be valid even though the incremental parser has already
+             # committed past the token boundary.
+             partial = state.partial_prefix
+             try:
+                 expected_terms = frozenset(state.parser.accepts())
+             except Exception:
+                 expected_terms = frozenset()
+
+             for token_str, token_id in _vocab.items():
+                 if not token_str:
+                     continue
+                 decoded = (
+                     decoded_vocab[token_id]
+                     if decoded_vocab is not None
+                     else tokenizer.decode([token_id])
+                 )
+                 if not decoded:
+                     continue
+                 full_candidate = partial + decoded
+                 if any(
+                     self._token_matches_terminal(t, full_candidate)
+                     for t in expected_terms
+                 ):
+                     valid_ids.add(token_id)
+
+         # Include EOS if the grammar is in an accepting state.
+         if self._is_accepting(state) and hasattr(tokenizer, "eos_token_id"):
+             eos_id = tokenizer.eos_token_id
+             if eos_id is not None:
+                 valid_ids.add(eos_id)
+
+         if not valid_ids:
+             logger.warning(
+                 "Grammar reached a dead end: no valid tokens. "
+                 "Allowing all tokens to prevent hanging generation."
+             )
+             result = frozenset(_vocab.values())
+         else:
+             result = frozenset(valid_ids)
+
+         self._cache[state_key] = result
+         return result
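+
+     # Sketch of how a caller might apply the result as a logits mask
+     # (illustrative; the actual consumer is processor.py, not shown here):
+     #     allowed = grammar.get_valid_token_ids(state, tokenizer)
+     #     mask = torch.full_like(scores, float("-inf"))
+     #     mask[:, list(allowed)] = 0.0
+     #     scores = scores + mask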
+
+     def is_valid_complete(self, text: str) -> bool:
+         """Return True if `text` is a complete, valid string for this grammar."""
+         try:
+             self._parser.parse(text)
+             return True
+         except Exception:
+             return False
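
A small end-to-end sketch of the class above (the grammar is invented for illustration; only methods defined in grammar.py are used):

    from transformers_grammar_constraint import LarkGrammar

    class AbGrammar(LarkGrammar):
        @property
        def grammar_string(self) -> str:
            # One or more literal "ab" pairs: "ab", "abab", ...
            return r'''
                start: PAIR+
                PAIR: "ab"
            '''

    g = AbGrammar()
    assert g.is_valid_complete("abab")
    assert not g.is_valid_complete("aba")

    state = g.fresh_state()
    state = g.advance_state(state, "ab")  # valid continuation
    # g.advance_state(state, "ba") would raise ValueError
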
@@ -0,0 +1,6 @@
+ # src/transformers_grammar_constraint/grammars/__init__.py
+ from transformers_grammar_constraint.grammars.json_grammar import JsonGrammar
+ from transformers_grammar_constraint.grammars.pgn_grammar import PgnGrammar
+ from transformers_grammar_constraint.grammars.san_grammar import SanGrammar
+
+ __all__ = ["JsonGrammar", "SanGrammar", "PgnGrammar"]
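
The bundled grammars share the same interface. For instance (assuming JsonGrammar takes no constructor arguments, which this excerpt does not show):

    from transformers_grammar_constraint import JsonGrammar

    g = JsonGrammar()  # assumed no-arg constructor
    print(g.is_valid_complete('{"age": 30}'))  # expected: True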