renderers 0.1.8.dev1__tar.gz → 0.1.9.dev0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/.github/workflows/publish.yml +32 -8
  2. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/PKG-INFO +2 -1
  3. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/pyproject.toml +6 -0
  4. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/renderers/_version.py +2 -2
  5. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/renderers/base.py +299 -23
  6. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/renderers/deepseek_v3.py +79 -35
  7. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/renderers/default.py +6 -1
  8. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/renderers/glm45.py +101 -46
  9. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/renderers/glm5.py +101 -42
  10. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/renderers/gpt_oss.py +77 -16
  11. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/renderers/kimi_k2.py +111 -68
  12. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/renderers/kimi_k25.py +119 -60
  13. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/renderers/laguna_xs2.py +93 -43
  14. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/renderers/minimax_m2.py +118 -46
  15. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/renderers/nemotron3.py +101 -59
  16. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/renderers/qwen3.py +101 -52
  17. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/renderers/qwen35.py +127 -74
  18. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/renderers/qwen3_vl.py +122 -65
  19. renderers-0.1.9.dev0/tests/test_load_tokenizer_fastokens.py +172 -0
  20. renderers-0.1.9.dev0/tests/test_sampled_mask.py +119 -0
  21. renderers-0.1.9.dev0/tests/test_tokens_per_message.py +325 -0
  22. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/uv.lock +4 -5
  23. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/.github/workflows/style.yml +0 -0
  24. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/.github/workflows/test.yml +0 -0
  25. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/.gitignore +0 -0
  26. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/.pre-commit-config.yaml +0 -0
  27. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/LICENSE +0 -0
  28. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/README.md +0 -0
  29. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/examples/README.md +0 -0
  30. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/examples/sglang/multiturn_generate_sglang.py +0 -0
  31. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/examples/sglang/online_multiturn_sglang.py +0 -0
  32. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/examples/tinker/multiturn_generate_tinker.py +0 -0
  33. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/examples/transformers/multiturn_generate_transformers.py +0 -0
  34. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/examples/vllm/multiturn_generate_vllm.py +0 -0
  35. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/renderers/__init__.py +0 -0
  36. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/renderers/client.py +0 -0
  37. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/renderers/parsers.py +0 -0
  38. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/renderers/parsing.py +0 -0
  39. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/renderers/qwen36.py +0 -0
  40. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/tests/conftest.py +0 -0
  41. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/tests/test_bridge.py +0 -0
  42. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/tests/test_build_helpers.py +0 -0
  43. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/tests/test_client.py +0 -0
  44. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/tests/test_gpt_oss_harmony_parity.py +0 -0
  45. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/tests/test_incremental.py +0 -0
  46. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/tests/test_load_tokenizer.py +0 -0
  47. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/tests/test_message_indices.py +0 -0
  48. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/tests/test_multimodal.py +0 -0
  49. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/tests/test_parse_response.py +0 -0
  50. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/tests/test_parse_response_robustness.py +0 -0
  51. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/tests/test_parsers.py +0 -0
  52. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/tests/test_preserve_thinking.py +0 -0
  53. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/tests/test_qwen35_size_coverage.py +0 -0
  54. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/tests/test_render_ids.py +0 -0
  55. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/tests/test_roundtrip.py +0 -0
  56. {renderers-0.1.8.dev1 → renderers-0.1.9.dev0}/tests/test_tool_arg_type_preservation.py +0 -0
@@ -12,8 +12,16 @@ on:
12
12
  - "renderers-v*"
13
13
 
14
14
  jobs:
15
- publish:
15
+ # Build (no OIDC) → publish (OIDC only). The build job runs uv build with
16
+ # contents: read only so a poisoned build-time dep cannot mint the OIDC
17
+ # token. The publish job has id-token: write and the pypi-prod environment
18
+ # but no source checkout — it only downloads the prebuilt artifact and runs
19
+ # the SHA-pinned pypa publish action.
20
+ build:
21
+ if: github.event_name == 'workflow_dispatch' || startsWith(github.ref, 'refs/tags/renderers-v')
16
22
  runs-on: ubuntu-latest
23
+ permissions:
24
+ contents: read
17
25
  steps:
18
26
  - name: Checkout tagged release (dispatch)
19
27
  if: github.event_name == 'workflow_dispatch'
@@ -28,8 +36,7 @@ jobs:
28
36
  with:
29
37
  fetch-depth: 0
30
38
 
31
- - name: Resolve release tag
32
- id: release
39
+ - name: Validate release tag
33
40
  env:
34
41
  EVENT_NAME: ${{ github.event_name }}
35
42
  PUSHED_REF: ${{ github.ref_name }}
@@ -53,14 +60,31 @@ jobs:
53
60
  ;;
54
61
  esac
55
62
 
56
- echo "tag=$TAG" >> "$GITHUB_OUTPUT"
57
-
58
63
  - uses: astral-sh/setup-uv@v7
59
64
 
60
65
  - name: Build renderers
61
66
  run: uv build
62
67
 
68
+ - name: Upload dist artifacts
69
+ uses: actions/upload-artifact@v4
70
+ with:
71
+ name: dist
72
+ path: dist/
73
+ if-no-files-found: error
74
+ retention-days: 7
75
+
76
+ publish:
77
+ needs: build
78
+ runs-on: ubuntu-latest
79
+ environment: pypi-prod
80
+ permissions:
81
+ id-token: write
82
+ steps:
83
+ - name: Download dist artifacts
84
+ uses: actions/download-artifact@v4
85
+ with:
86
+ name: dist
87
+ path: dist/
88
+
63
89
  - name: Publish to PyPI
64
- env:
65
- PYPI_RENDERERS_TOKEN: ${{ secrets.PYPI_RENDERERS_TOKEN }}
66
- run: uv publish --token "$PYPI_RENDERERS_TOKEN" dist/*
90
+ uses: pypa/gh-action-pypi-publish@cef221092ed1bacb1cc03d23a2d87d1d172e277b # v1.14.0
@@ -1,10 +1,11 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: renderers
3
- Version: 0.1.8.dev1
3
+ Version: 0.1.9.dev0
4
4
  Summary: Chat template renderers — deterministic message-to-token conversion for LLM training
5
5
  License-Expression: Apache-2.0
6
6
  License-File: LICENSE
7
7
  Requires-Python: <3.14,>=3.10
8
+ Requires-Dist: fastokens>=0.1.1
8
9
  Requires-Dist: jinja2
9
10
  Requires-Dist: numpy
10
11
  Requires-Dist: openai-harmony>=0.0.8
@@ -24,6 +24,12 @@ dependencies = [
24
24
  # OpenAI's reference implementation keeps us byte-identical with vLLM
25
25
  # (which also uses it) and saves us mirroring a 330-line Jinja template.
26
26
  "openai-harmony>=0.0.8",
27
+ # Crusoe's Rust BPE tokenizer; ~10x faster encode vs HF's tokenizers.
28
+ # ``load_tokenizer`` patches it in by default for every supported model
29
+ # except a small denylist (DeepSeek-V3 family, MiniMax-M2 family). The
30
+ # patch is bracketed around ``from_pretrained``, so subsequent
31
+ # ``AutoTokenizer`` calls outside the renderers package stay vanilla.
32
+ "fastokens>=0.1.1",
27
33
  ]
28
34
 
29
35
  [tool.hatch.version]
@@ -18,7 +18,7 @@ version_tuple: tuple[int | str, ...]
18
18
  commit_id: str | None
19
19
  __commit_id__: str | None
20
20
 
21
- __version__ = version = '0.1.8.dev1'
22
- __version_tuple__ = version_tuple = (0, 1, 8, 'dev1')
21
+ __version__ = version = '0.1.9.dev0'
22
+ __version_tuple__ = version_tuple = (0, 1, 9, 'dev0')
23
23
 
24
24
  __commit_id__ = commit_id = None
@@ -148,8 +148,26 @@ class RenderedTokens:
148
148
  """Result of rendering messages to tokens.
149
149
 
150
150
  Each token carries an index into the original message list so callers can
151
- build per-token loss masks without re-rendering. Tokens from structural
152
- scaffolding (generation prompt, im_start/im_end wrapping) carry index -1.
151
+ build per-token loss masks without re-rendering. Tokens from structural
152
+ scaffolding the renderer adds outside any single message (e.g. the
153
+ trailing generation prompt) carry index ``-1``.
154
+
155
+ ``sampled_mask`` is a separate per-token signal: ``True`` if the model
156
+ would have produced this token at inference time (i.e. it appears in
157
+ the sampled completion), ``False`` if it is template-injected
158
+ scaffolding the model never emits (``<|im_start|>role\\n`` openers,
159
+ inter-turn ``\\n`` separators, system / user / tool content from
160
+ conversation history, etc.). This is distinct from
161
+ ``message_indices``: a token can belong to an assistant message
162
+ (``message_indices[k] >= 0``) and still be scaffolding the template
163
+ adds around the model's actual completion. SFT loss masks should AND
164
+ both: train on tokens whose role is trainable AND that the model
165
+ would actually sample.
166
+
167
+ Empty ``sampled_mask`` (``[]``) means the renderer doesn't provide
168
+ this signal — consumers should fall back to attribution-only
169
+ masking. ``DefaultRenderer`` leaves it empty because the Jinja
170
+ template is opaque; hand-coded renderers populate it.
153
171
 
154
172
  ``multi_modal_data`` is populated by multimodal renderers (e.g.
155
173
  ``Qwen3VLRenderer``) when image / video content parts are present;
@@ -158,8 +176,163 @@ class RenderedTokens:
158
176
 
159
177
  token_ids: list[int] = field(default_factory=list)
160
178
  message_indices: list[int] = field(default_factory=list)
179
+ sampled_mask: list[bool] = field(default_factory=list)
180
+ message_roles: list[str] = field(default_factory=list)
161
181
  multi_modal_data: "MultiModalData | None" = None
162
182
 
183
+ def tokens_per_message(
184
+ self, n_messages: int | None = None, *, sampled_only: bool = False
185
+ ) -> list[int]:
186
+ """Count rendered tokens attributed to each caller-relative message.
187
+
188
+ ``out[i]`` is the number of tokens with ``message_indices[k] == i``,
189
+ i.e. tokens the renderer attributed to ``messages[i]``. This
190
+ includes template scaffolding the renderer wraps around the
191
+ message — the ``<|im_start|>role\\n`` opener, the closing
192
+ ``<|im_end|>\\n``, etc. — because those are the renderer's own
193
+ attribution decision and are preserved verbatim here. Tokens with
194
+ ``message_indices[k] == -1`` (scaffolding outside any single
195
+ message, e.g. the trailing generation prompt) are not counted.
196
+
197
+ With ``sampled_only=True``, counts only tokens the model would
198
+ have emitted at inference (``sampled_mask[k] is True``). For
199
+ example, length-penalty signals in RL: the template wraps each
200
+ assistant turn in scaffolding tokens (e.g. ``<|im_start|>assistant\\n``,
201
+ ``<|im_end|>\\n``) that are constant-size and not chosen by the
202
+ model, so they shouldn't enter the penalty. For roles the model
203
+ never samples (``user``, ``tool``, ``system``), the
204
+ ``sampled_only`` count is zero by construction. Renderers that
205
+ don't populate ``sampled_mask`` (``DefaultRenderer`` — the Jinja
206
+ template is opaque) return all zeros under ``sampled_only=True``.
207
+
208
+ ``n_messages`` defaults to ``len(self.message_roles)``, which
209
+ every Renderer populates with the caller-relative message list
210
+ (caller's ``messages`` for ``render()``; ``new_messages`` for
211
+ ``bridge_to_next_turn()``). Pass it explicitly only to truncate
212
+ — indices outside ``[0, n_messages)`` are ignored, so passing a
213
+ smaller value won't raise; it just drops the tail. Values larger
214
+ than ``len(self.message_roles)`` are clamped, so the returned
215
+ list never claims more messages than the renderer attributed.
216
+
217
+ Works on results from both :meth:`Renderer.render` and
218
+ :meth:`Renderer.bridge_to_next_turn`. For a bridge result the
219
+ indices are relative to the new messages the bridge added, not
220
+ the full conversation history; the prior portion is uniformly
221
+ ``-1`` (and ``sampled_mask`` uniformly ``False``), so it
222
+ contributes nothing to either count.
223
+ """
224
+ if n_messages is None:
225
+ n_messages = len(self.message_roles)
226
+ else:
227
+ n_messages = min(n_messages, len(self.message_roles))
228
+ out = [0] * n_messages
229
+ if sampled_only:
230
+ if len(self.sampled_mask) != len(self.token_ids):
231
+ return out
232
+ for idx, sampled in zip(self.message_indices, self.sampled_mask):
233
+ if sampled and 0 <= idx < n_messages:
234
+ out[idx] += 1
235
+ else:
236
+ for idx in self.message_indices:
237
+ if 0 <= idx < n_messages:
238
+ out[idx] += 1
239
+ return out
240
+
241
+ def message_token_spans(self) -> list[tuple[int, int] | None]:
242
+ """Per-message ``(start, end)`` slices into :attr:`token_ids`.
243
+
244
+ ``out[i]`` is the half-open span ``[start, end)`` such that
245
+ ``token_ids[start:end]`` are the tokens attributed to
246
+ ``messages[i]`` (or ``new_messages[i]`` for a bridge result).
247
+ Messages that contributed no tokens get ``None``. Renderer
248
+ scaffolding outside any message (``message_indices[k] == -1``)
249
+ is not represented.
250
+
251
+ Hand-coded renderers emit each message's tokens contiguously,
252
+ so the span is well-defined. The implementation tolerates
253
+ non-contiguous attribution by returning the outer span
254
+ ``(first_k, last_k + 1)``; if you suspect interleaving, slice
255
+ ``message_indices`` yourself to verify.
256
+
257
+ Returns ``len(self.message_roles)`` entries when ``message_roles``
258
+ is populated. Otherwise infers the count from
259
+ ``max(message_indices) + 1`` — useful for manually-constructed
260
+ ``RenderedTokens`` in tests but only correct when the last
261
+ message contributed at least one token.
262
+
263
+ Cheap to call: single pass over ``message_indices``. Re-call
264
+ rather than caching the result if you mutate the dataclass.
265
+ """
266
+ if self.message_roles:
267
+ n_messages = len(self.message_roles)
268
+ else:
269
+ max_idx = -1
270
+ for idx in self.message_indices:
271
+ if idx > max_idx:
272
+ max_idx = idx
273
+ n_messages = max_idx + 1
274
+
275
+ firsts: list[int] = [-1] * n_messages
276
+ lasts: list[int] = [-1] * n_messages
277
+ for k, idx in enumerate(self.message_indices):
278
+ if 0 <= idx < n_messages:
279
+ if firsts[idx] == -1:
280
+ firsts[idx] = k
281
+ lasts[idx] = k
282
+
283
+ out: list[tuple[int, int] | None] = []
284
+ for i in range(n_messages):
285
+ if firsts[i] == -1:
286
+ out.append(None)
287
+ else:
288
+ out.append((firsts[i], lasts[i] + 1))
289
+ return out
290
+
291
def role_token_spans(self) -> dict[str, list[tuple[int, int]]]:
    """Group the per-message spans from ``message_token_spans`` by role.

    Every role in ``message_roles`` becomes a key, in first-appearance
    order. Its value lists the ``(start, end)`` spans of that role's
    messages in message order. Messages without tokens contribute no span,
    but their role is still present, possibly with an empty list. An empty
    ``message_roles`` yields an empty dict.

    Handy for per-role statistics over per-token signals, e.g. slicing
    ``logprobs[start:end]`` for each assistant turn to get per-turn
    perplexity.
    """
    grouped: dict[str, list[tuple[int, int]]] = {}
    for role, span in zip(self.message_roles, self.message_token_spans()):
        bucket = grouped.setdefault(role, [])
        if span is not None:
            bucket.append(span)
    return grouped
312
+
313
def tokens_by_role(self, *, sampled_only: bool = False) -> dict[str, int]:
    """Total ``tokens_per_message`` counts per role in ``message_roles``.

    ``tokens_by_role(sampled_only=True)["assistant"]`` gives the number of
    tokens the model itself emitted across all assistant turns, with
    template scaffolding excluded. ``tokens_by_role()["tool"]`` gives the
    raw count of tool-response tokens. With ``sampled_only=True`` that
    entry is zero, because the model never samples tool output.

    Every role in ``message_roles`` appears as a key even when its total is
    ``0``, so callers can index without a ``KeyError``. An empty
    ``message_roles`` yields an empty dict.
    """
    per_message = self.tokens_per_message(sampled_only=sampled_only)
    pairs = list(zip(self.message_roles, per_message))
    # Seed every role with 0 in first-appearance order, then accumulate.
    totals = dict.fromkeys((role for role, _ in pairs), 0)
    for role, count in pairs:
        totals[role] += count
    return totals
335
+
163
336
 
164
337
  class ToolCallParseStatus(str, enum.Enum):
165
338
  """Per-attempt outcome of parsing a single ``<tool_call>`` block.
@@ -339,6 +512,25 @@ class Renderer(Protocol):
339
512
  list so far with ``add_generation_prompt=True`` — except prev
340
513
  sampled tokens are kept verbatim rather than re-rendered).
341
514
 
515
+ Attribution on the returned ``RenderedTokens``:
516
+
517
+ - ``message_indices`` is ``-1`` over the entire prior portion
518
+ (length ``len(previous_ids)`` after :func:`trim_to_turn_close`)
519
+ because the bridge gets the prior as raw token lists with no
520
+ attribution. Over the bridge-added portion, indices are
521
+ relative to ``new_messages``: a token rendered as part of
522
+ ``new_messages[i]`` carries ``i``, and inter-turn separators /
523
+ the trailing generation prompt carry ``-1``. So
524
+ ``bridge.tokens_per_message(len(new_messages))`` gives the
525
+ per-new-message token count for length-penalty bookkeeping.
526
+ - ``sampled_mask`` is uniformly ``False`` across the entire
527
+ returned sequence. The bridge output is consumed as the next
528
+ turn's prompt; nothing it emits was model-sampled, and the
529
+ bridge has no way to recover which prior tokens were. If the
530
+ caller needs that distinction for the prior portion, they
531
+ have it directly: every token in ``prev_completion_ids`` was
532
+ sampled; every token in ``prev_prompt_ids`` was not.
533
+
342
534
  Text-only renderers return :class:`RenderedTokens` with
343
535
  ``multi_modal_data=None``. Multimodal renderers (see
344
536
  :class:`MultimodalRenderer`) populate ``multi_modal_data`` so
@@ -713,37 +905,108 @@ TRUSTED_REVISIONS: dict[str, str] = {
713
905
  }
714
906
 
715
907
 
716
- def load_tokenizer(model_name_or_path: str):
717
- """Load a tokenizer with the renderers-package security policy.
908
# Checkpoints that ``load_tokenizer`` must load with vanilla
# ``transformers.AutoTokenizer`` rather than through the fastokens patch.
# An audit ran every ``MODEL_RENDERER_MAP`` entry through both backends:
# 31 of 35 produced byte-identical output. The four entries below are the
# exceptions:
# - the DeepSeek-V3 family, which fastokens cannot load at all;
# - the MiniMax-M2 family, excluded defensively until an upstream
#   fastokens fix ships (see the per-entry notes).
FASTOKENS_INCOMPATIBLE: frozenset[str] = frozenset(
    (
        # fastokens 0.1.1 raises ``ValueError: pre-tokenizer error:
        # unsupported pre-tokenizer type: Metaspace``. DeepSeek uses
        # SentencePiece-style Metaspace pre-tokenization, which fastokens
        # does not implement yet.
        "deepseek-ai/DeepSeek-V3",
        "deepseek-ai/DeepSeek-V3-Base",
        # Pending https://github.com/crusoecloud/fastokens/pull/32:
        # ``unpatch_transformers`` leaves a stray attribute behind. For
        # MiniMax (``tokenizer_class = 'GPT2Tokenizer'`` -> slow-to-fast
        # conversion path), that attribute sends later vanilla loads down
        # a different code path. Re-audit and drop these entries once the
        # fix is released.
        "MiniMaxAI/MiniMax-M2",
        "MiniMaxAI/MiniMax-M2.5",
    )
)
934
+
935
+
936
def _patched_load(model_name_or_path: str, **kwargs):
    """Call ``AutoTokenizer.from_pretrained`` with fastokens patched in.

    ``fastokens.patch_transformers()`` is applied just before the load.
    ``fastokens.unpatch_transformers()`` runs afterwards whether or not the
    load succeeded. fastokens binds its backend per tokenizer instance, so
    the returned tokenizer keeps using fastokens for ``encode``/``decode``
    after the unpatch. Later ``AutoTokenizer.from_pretrained`` calls
    elsewhere go back to vanilla, which keeps the global side effect small.

    NOTE(review): patch/unpatch mutate process-global ``transformers``
    state. Concurrent calls from several threads, or a caller that had
    already applied the patch globally, could observe the patch being
    toggled underneath them. Confirm callers load tokenizers serially.
    """
    import fastokens
    from transformers import AutoTokenizer

    fastokens.patch_transformers()
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)
    finally:
        # Restore vanilla transformers even if the load raised.
        fastokens.unpatch_transformers()
    return tokenizer
718
954
 
719
- Default: ``trust_remote_code=False`` — the safe choice for every
720
- model in ``MODEL_RENDERER_MAP`` *except* the Kimi-K2 family.
721
955
 
722
- Models listed in ``TRUSTED_REVISIONS`` load with
723
- ``trust_remote_code=True`` AND ``revision=<pinned sha>`` — required
724
- because their tokenizer config has an ``auto_map.AutoTokenizer``
725
- entry pointing at a repo-supplied Python class
726
- (``tokenization_kimi.TikTokenTokenizer``). Pinning the revision
727
- means transformers executes only the reviewed commit's code, not
728
- whatever ``HEAD`` points at when the call fires.
956
def load_tokenizer(
    model_name_or_path: str,
    *,
    use_fastokens: bool = True,
):
    """Load a tokenizer under the renderers package's security and speed policy.

    **Security** -- ``trust_remote_code`` is ``False`` unless
    ``model_name_or_path`` is a key of ``TRUSTED_REVISIONS`` (the Moonshot
    Kimi-K2 family). Those load with ``trust_remote_code=True`` and the
    pinned ``revision``, so transformers executes only the reviewed
    commit's tokenizer code.

    **Speed** -- with ``use_fastokens=True`` (the default) the load goes
    through ``_patched_load``. That applies ``fastokens.patch_transformers()``
    only around ``from_pretrained``, giving roughly 10x faster encoding
    without affecting ``AutoTokenizer.from_pretrained`` calls elsewhere in
    the process.

    Models in ``FASTOKENS_INCOMPATIBLE`` always take the vanilla path.
    Under fastokens 0.1.1 the DeepSeek-V3 family fails to load, and the
    MiniMax-M2 family is excluded defensively pending an upstream fix.
    Pass ``use_fastokens=False`` to force the vanilla path for any model.

    Any other path, including unknown or fine-tuned checkpoints, uses
    ``trust_remote_code=False``. When ``use_fastokens`` is set, the
    fastokens path is tried first. If it raises any ``Exception``, an INFO
    message is logged and the load is retried once with vanilla
    ``AutoTokenizer``. An error from that retry propagates to the caller.
    """
    from transformers import AutoTokenizer

    pinned_revision = TRUSTED_REVISIONS.get(model_name_or_path)
    load_kwargs: dict[str, Any] = (
        {"trust_remote_code": False}
        if pinned_revision is None
        else {"trust_remote_code": True, "revision": pinned_revision}
    )

    wants_fast_path = (
        use_fastokens and model_name_or_path not in FASTOKENS_INCOMPATIBLE
    )
    if not wants_fast_path:
        return AutoTokenizer.from_pretrained(model_name_or_path, **load_kwargs)

    try:
        return _patched_load(model_name_or_path, **load_kwargs)
    except Exception as exc:
        # NOTE(review): this also fires for failures unrelated to fastokens
        # (e.g. a missing repo). The vanilla retry presumably fails the
        # same way, so the log hint may be misleading in that case.
        logger.info(
            "fastokens could not load %r (%s: %s); falling back to vanilla "
            "AutoTokenizer. Add this model to FASTOKENS_INCOMPATIBLE in "
            "renderers.base to suppress the retry.",
            model_name_or_path,
            type(exc).__name__,
            str(exc)[:160],
        )
        return AutoTokenizer.from_pretrained(model_name_or_path, **load_kwargs)
747
1010
 
748
1011
 
749
1012
  def _populate_registry():
@@ -947,12 +1210,25 @@ def build_training_sample(
947
1210
 
948
1211
  Single render() call + message_indices → per-token mask.
949
1212
  Replaces build_incremental_token_mask (O(N) renders → O(1)).
1213
+
1214
+ When the renderer populates ``rendered.sampled_mask``, the loss mask
1215
+ is the AND of role-based attribution and the sampled signal: only
1216
+ tokens the model would have produced at inference are trainable.
1217
+ This keeps SFT byte-aligned with the RL trajectory mask (where the
1218
+ prompt / completion split achieves the same effect structurally).
1219
+ Renderers that don't populate ``sampled_mask`` (empty list) fall
1220
+ back to attribution-only masking — every token attributed to a
1221
+ trainable role is trained on, including template-injected
1222
+ ``<|im_start|>role\\n`` openers.
950
1223
  """
951
1224
  rendered = renderer.render(messages, tools=tools)
1225
+ has_sampled_info = len(rendered.sampled_mask) == len(rendered.token_ids)
952
1226
  loss_mask: list[bool] = []
953
- for msg_idx in rendered.message_indices:
1227
+ for k, msg_idx in enumerate(rendered.message_indices):
954
1228
  if msg_idx < 0:
955
1229
  loss_mask.append(False)
1230
+ elif has_sampled_info and not rendered.sampled_mask[k]:
1231
+ loss_mask.append(False)
956
1232
  else:
957
1233
  loss_mask.append(role_to_mask(messages[msg_idx]))
958
1234
  return rendered.token_ids, loss_mask