renderers 0.1.8.dev2__tar.gz → 0.1.9.dev0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/PKG-INFO +1 -1
  2. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/renderers/_version.py +2 -2
  3. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/renderers/base.py +173 -0
  4. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/renderers/deepseek_v3.py +28 -12
  5. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/renderers/default.py +6 -1
  6. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/renderers/glm45.py +28 -12
  7. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/renderers/glm5.py +28 -12
  8. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/renderers/gpt_oss.py +23 -4
  9. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/renderers/kimi_k2.py +28 -11
  10. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/renderers/kimi_k25.py +37 -20
  11. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/renderers/laguna_xs2.py +36 -19
  12. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/renderers/minimax_m2.py +28 -12
  13. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/renderers/nemotron3.py +28 -12
  14. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/renderers/qwen3.py +28 -12
  15. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/renderers/qwen35.py +37 -22
  16. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/renderers/qwen3_vl.py +15 -8
  17. renderers-0.1.9.dev0/tests/test_tokens_per_message.py +325 -0
  18. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/.github/workflows/publish.yml +0 -0
  19. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/.github/workflows/style.yml +0 -0
  20. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/.github/workflows/test.yml +0 -0
  21. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/.gitignore +0 -0
  22. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/.pre-commit-config.yaml +0 -0
  23. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/LICENSE +0 -0
  24. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/README.md +0 -0
  25. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/examples/README.md +0 -0
  26. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/examples/sglang/multiturn_generate_sglang.py +0 -0
  27. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/examples/sglang/online_multiturn_sglang.py +0 -0
  28. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/examples/tinker/multiturn_generate_tinker.py +0 -0
  29. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/examples/transformers/multiturn_generate_transformers.py +0 -0
  30. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/examples/vllm/multiturn_generate_vllm.py +0 -0
  31. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/pyproject.toml +0 -0
  32. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/renderers/__init__.py +0 -0
  33. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/renderers/client.py +0 -0
  34. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/renderers/parsers.py +0 -0
  35. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/renderers/parsing.py +0 -0
  36. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/renderers/qwen36.py +0 -0
  37. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/tests/conftest.py +0 -0
  38. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/tests/test_bridge.py +0 -0
  39. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/tests/test_build_helpers.py +0 -0
  40. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/tests/test_client.py +0 -0
  41. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/tests/test_gpt_oss_harmony_parity.py +0 -0
  42. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/tests/test_incremental.py +0 -0
  43. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/tests/test_load_tokenizer.py +0 -0
  44. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/tests/test_load_tokenizer_fastokens.py +0 -0
  45. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/tests/test_message_indices.py +0 -0
  46. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/tests/test_multimodal.py +0 -0
  47. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/tests/test_parse_response.py +0 -0
  48. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/tests/test_parse_response_robustness.py +0 -0
  49. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/tests/test_parsers.py +0 -0
  50. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/tests/test_preserve_thinking.py +0 -0
  51. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/tests/test_qwen35_size_coverage.py +0 -0
  52. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/tests/test_render_ids.py +0 -0
  53. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/tests/test_roundtrip.py +0 -0
  54. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/tests/test_sampled_mask.py +0 -0
  55. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/tests/test_tool_arg_type_preservation.py +0 -0
  56. {renderers-0.1.8.dev2 → renderers-0.1.9.dev0}/uv.lock +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: renderers
3
- Version: 0.1.8.dev2
3
+ Version: 0.1.9.dev0
4
4
  Summary: Chat template renderers — deterministic message-to-token conversion for LLM training
5
5
  License-Expression: Apache-2.0
6
6
  License-File: LICENSE
@@ -18,7 +18,7 @@ version_tuple: tuple[int | str, ...]
18
18
  commit_id: str | None
19
19
  __commit_id__: str | None
20
20
 
21
- __version__ = version = '0.1.8.dev2'
22
- __version_tuple__ = version_tuple = (0, 1, 8, 'dev2')
21
+ __version__ = version = '0.1.9.dev0'
22
+ __version_tuple__ = version_tuple = (0, 1, 9, 'dev0')
23
23
 
24
24
  __commit_id__ = commit_id = None
@@ -177,8 +177,162 @@ class RenderedTokens:
177
177
  token_ids: list[int] = field(default_factory=list)
178
178
  message_indices: list[int] = field(default_factory=list)
179
179
  sampled_mask: list[bool] = field(default_factory=list)
180
+ message_roles: list[str] = field(default_factory=list)
180
181
  multi_modal_data: "MultiModalData | None" = None
181
182
 
183
+ def tokens_per_message(
184
+ self, n_messages: int | None = None, *, sampled_only: bool = False
185
+ ) -> list[int]:
186
+ """Count rendered tokens attributed to each caller-relative message.
187
+
188
+ ``out[i]`` is the number of tokens with ``message_indices[k] == i``,
189
+ i.e. tokens the renderer attributed to ``messages[i]``. This
190
+ includes template scaffolding the renderer wraps around the
191
+ message — the ``<|im_start|>role\\n`` opener, the closing
192
+ ``<|im_end|>\\n``, etc. — because those are the renderer's own
193
+ attribution decision and are preserved verbatim here. Tokens with
194
+ ``message_indices[k] == -1`` (scaffolding outside any single
195
+ message, e.g. the trailing generation prompt) are not counted.
196
+
197
+ With ``sampled_only=True``, counts only tokens the model would
198
+ have emitted at inference (``sampled_mask[k] is True``). This
199
+ suits length-penalty signals in RL: the template wraps each
200
+ assistant turn in scaffolding tokens (e.g. ``<|im_start|>assistant\\n``,
201
+ ``<|im_end|>\\n``) that are constant-size and not chosen by the
202
+ model, so they shouldn't enter the penalty. For roles the model
203
+ never samples (``user``, ``tool``, ``system``), the
204
+ ``sampled_only`` count is zero by construction. Renderers that
205
+ don't populate ``sampled_mask`` (``DefaultRenderer`` — the Jinja
206
+ template is opaque) return all zeros under ``sampled_only=True``.
207
+
208
+ ``n_messages`` defaults to ``len(self.message_roles)``, which
209
+ every Renderer populates with the caller-relative message list
210
+ (caller's ``messages`` for ``render()``; ``new_messages`` for
211
+ ``bridge_to_next_turn()``). Pass it explicitly only to truncate
212
+ — indices outside ``[0, n_messages)`` are ignored, so passing a
213
+ smaller value won't raise; it just drops the tail. Values larger
214
+ than ``len(self.message_roles)`` are clamped, so the returned
215
+ list never claims more messages than the renderer attributed.
216
+
217
+ Works on results from both :meth:`Renderer.render` and
218
+ :meth:`Renderer.bridge_to_next_turn`. For a bridge result the
219
+ indices are relative to the new messages the bridge added, not
220
+ the full conversation history; the prior portion is uniformly
221
+ ``-1`` (and ``sampled_mask`` uniformly ``False``), so it
222
+ contributes nothing to either count.
223
+ """
224
+ if n_messages is None:
225
+ n_messages = len(self.message_roles)
226
+ else:
227
+ n_messages = min(n_messages, len(self.message_roles))
228
+ out = [0] * n_messages
229
+ if sampled_only:
230
+ if len(self.sampled_mask) != len(self.token_ids):
231
+ return out
232
+ for idx, sampled in zip(self.message_indices, self.sampled_mask):
233
+ if sampled and 0 <= idx < n_messages:
234
+ out[idx] += 1
235
+ else:
236
+ for idx in self.message_indices:
237
+ if 0 <= idx < n_messages:
238
+ out[idx] += 1
239
+ return out
240
+
241
+ def message_token_spans(self) -> list[tuple[int, int] | None]:
242
+ """Per-message ``(start, end)`` slices into :attr:`token_ids`.
243
+
244
+ ``out[i]`` is the half-open span ``[start, end)`` such that
245
+ ``token_ids[start:end]`` are the tokens attributed to
246
+ ``messages[i]`` (or ``new_messages[i]`` for a bridge result).
247
+ Messages that contributed no tokens get ``None``. Renderer
248
+ scaffolding outside any message (``message_indices[k] == -1``)
249
+ is not represented.
250
+
251
+ Hand-coded renderers emit each message's tokens contiguously,
252
+ so the span is well-defined. The implementation tolerates
253
+ non-contiguous attribution by returning the outer span
254
+ ``(first_k, last_k + 1)``; if you suspect interleaving, slice
255
+ ``message_indices`` yourself to verify.
256
+
257
+ Returns ``len(self.message_roles)`` entries when ``message_roles``
258
+ is populated. Otherwise infers the count from
259
+ ``max(message_indices) + 1`` — useful for manually-constructed
260
+ ``RenderedTokens`` in tests but only correct when the last
261
+ message contributed at least one token.
262
+
263
+ Cheap to call: linear in ``len(message_indices)``. Re-call
264
+ rather than caching the result if you mutate the dataclass.
265
+ """
266
+ if self.message_roles:
267
+ n_messages = len(self.message_roles)
268
+ else:
269
+ max_idx = -1
270
+ for idx in self.message_indices:
271
+ if idx > max_idx:
272
+ max_idx = idx
273
+ n_messages = max_idx + 1
274
+
275
+ firsts: list[int] = [-1] * n_messages
276
+ lasts: list[int] = [-1] * n_messages
277
+ for k, idx in enumerate(self.message_indices):
278
+ if 0 <= idx < n_messages:
279
+ if firsts[idx] == -1:
280
+ firsts[idx] = k
281
+ lasts[idx] = k
282
+
283
+ out: list[tuple[int, int] | None] = []
284
+ for i in range(n_messages):
285
+ if firsts[i] == -1:
286
+ out.append(None)
287
+ else:
288
+ out.append((firsts[i], lasts[i] + 1))
289
+ return out
290
+
291
+ def role_token_spans(self) -> dict[str, list[tuple[int, int]]]:
292
+ """:meth:`message_token_spans` regrouped by ``message_roles``.
293
+
294
+ Maps each role appearing in :attr:`message_roles` to a list of
295
+ ``(start, end)`` spans — one per occurrence of that role, in
296
+ message order. A message with no tokens adds no span, though its
297
+ role stays a key. Returns an empty dict if :attr:`message_roles` is empty.
298
+
299
+ Intended for per-role statistics that operate on per-token
300
+ signals — e.g. ``logprobs[start:end]`` for each assistant span
301
+ to compute per-turn perplexity, or
302
+ ``attention[start:end]`` for tool-response attention analysis.
303
+ """
304
+ spans = self.message_token_spans()
305
+ out: dict[str, list[tuple[int, int]]] = {}
306
+ for role, span in zip(self.message_roles, spans):
307
+ if span is None:
308
+ out.setdefault(role, [])
309
+ continue
310
+ out.setdefault(role, []).append(span)
311
+ return out
312
+
313
+ def tokens_by_role(self, *, sampled_only: bool = False) -> dict[str, int]:
314
+ """Sum :meth:`tokens_per_message` grouped by ``message_roles``.
315
+
316
+ Convenience for length-penalty bookkeeping in RL trainers:
317
+ ``rendered.tokens_by_role(sampled_only=True)["assistant"]`` is
318
+ the count of tokens the model actually emitted across all
319
+ assistant turns — template scaffolding excluded.
320
+ ``rendered.tokens_by_role()["tool"]`` is the raw count of
321
+ tool-response tokens (``sampled_only`` is zero for ``tool`` by
322
+ construction since the model never samples those).
323
+
324
+ Roles present in :attr:`message_roles` always appear in the
325
+ returned dict, even with post-filter count ``0``, so callers
326
+ can index those roles without ``KeyError``. A role the conversation
326
+ lacks is not a key (use ``.get(role, 0)``). Returns an empty dict if
328
+ :attr:`message_roles` is empty.
329
+ """
330
+ counts = self.tokens_per_message(sampled_only=sampled_only)
331
+ out: dict[str, int] = {}
332
+ for role, n in zip(self.message_roles, counts):
333
+ out[role] = out.get(role, 0) + n
334
+ return out
335
+
182
336
 
183
337
  class ToolCallParseStatus(str, enum.Enum):
184
338
  """Per-attempt outcome of parsing a single ``<tool_call>`` block.
@@ -358,6 +512,25 @@ class Renderer(Protocol):
358
512
  list so far with ``add_generation_prompt=True`` — except prev
359
513
  sampled tokens are kept verbatim rather than re-rendered).
360
514
 
515
+ Attribution on the returned ``RenderedTokens``:
516
+
517
+ - ``message_indices`` is ``-1`` over the entire prior portion
518
+ (length ``len(previous_ids)`` after :func:`trim_to_turn_close`)
519
+ because the bridge gets the prior as raw token lists with no
520
+ attribution. Over the bridge-added portion, indices are
521
+ relative to ``new_messages``: a token rendered as part of
522
+ ``new_messages[i]`` carries ``i``, and inter-turn separators /
523
+ the trailing generation prompt carry ``-1``. So
524
+ ``bridge.tokens_per_message(len(new_messages))`` gives the
525
+ per-new-message token count for length-penalty bookkeeping.
526
+ - ``sampled_mask`` is uniformly ``False`` across the entire
527
+ returned sequence. The bridge output is consumed as the next
528
+ turn's prompt; nothing it emits was model-sampled, and the
529
+ bridge has no way to recover which prior tokens were. If the
530
+ caller needs that distinction for the prior portion, they
531
+ have it directly: every token in ``prev_completion_ids`` was
532
+ sampled; every token in ``prev_prompt_ids`` was not.
533
+
361
534
  Text-only renderers return :class:`RenderedTokens` with
362
535
  ``multi_modal_data=None``. Multimodal renderers (see
363
536
  :class:`MultimodalRenderer`) populate ``multi_modal_data`` so
@@ -210,7 +210,10 @@ class DeepSeekV3Renderer:
210
210
  emit_text("<think>\n", -1, is_sampled=False)
211
211
 
212
212
  return RenderedTokens(
213
- token_ids=tokens, message_indices=indices, sampled_mask=sampled
213
+ token_ids=tokens,
214
+ message_indices=indices,
215
+ sampled_mask=sampled,
216
+ message_roles=[m.get("role") or "" for m in messages],
214
217
  )
215
218
 
216
219
  def render_ids(
@@ -271,22 +274,29 @@ class DeepSeekV3Renderer:
271
274
  return None
272
275
 
273
276
  ext: list[int] = []
274
-
275
- # Bridge output is consumed as the next turn's prompt — the
276
- # caller blanket-masks it via ``prompt_mask=[False]*N``, so we
277
- # don't track sampled_mask here. Local helpers accept the kwarg
278
- # for signature compatibility with ``_render_tool`` and ignore
279
- # it; the returned ``RenderedTokens`` leaves ``sampled_mask``
280
- # empty.
277
+ ext_indices: list[int] = []
278
+ ext_sampled: list[bool] = []
279
+
280
+ # Bridge populates ``message_indices`` (relative to ``new_messages``)
281
+ # and ``sampled_mask`` (uniformly ``False`` — every token the
282
+ # bridge emits is template scaffolding for the next prompt, not
283
+ # something the model sampled). Downstream consumers can run
284
+ # :meth:`RenderedTokens.tokens_per_message` on the bridge output
285
+ # to get per-new-message token counts without re-rendering.
281
286
  def emit_special(
282
- token_id: int, _msg_idx: int = -1, *, is_sampled: bool = False
287
+ token_id: int, msg_idx: int = -1, *, is_sampled: bool = False
283
288
  ) -> None:
284
289
  ext.append(token_id)
290
+ ext_indices.append(msg_idx)
291
+ ext_sampled.append(is_sampled)
285
292
 
286
293
  def emit_text(
287
- text: str, _msg_idx: int = -1, *, is_sampled: bool = False
294
+ text: str, msg_idx: int = -1, *, is_sampled: bool = False
288
295
  ) -> None:
289
- ext.extend(self._encode(text))
296
+ ids = self._encode(text)
297
+ ext.extend(ids)
298
+ ext_indices.extend([msg_idx] * len(ids))
299
+ ext_sampled.extend([is_sampled] * len(ids))
290
300
 
291
301
  for i, msg in enumerate(new_messages):
292
302
  role = msg.get("role")
@@ -329,7 +339,13 @@ class DeepSeekV3Renderer:
329
339
  if self._enable_thinking:
330
340
  emit_text("<think>\n", -1)
331
341
 
332
- return RenderedTokens(token_ids=previous_ids + ext)
342
+ total_len = len(previous_ids) + len(ext)
343
+ return RenderedTokens(
344
+ token_ids=previous_ids + ext,
345
+ message_indices=[-1] * len(previous_ids) + ext_indices,
346
+ sampled_mask=[False] * total_len,
347
+ message_roles=[m.get("role") or "" for m in new_messages],
348
+ )
333
349
 
334
350
  # ------------------------------------------------------------------
335
351
  # Assistant rendering
@@ -143,7 +143,12 @@ class DefaultRenderer:
143
143
  token_ids = full_ids
144
144
  message_indices.extend([-1] * len(gen_tokens))
145
145
 
146
- return RenderedTokens(token_ids=token_ids, message_indices=message_indices)
146
+ message_roles = [m.get("role") or "" for m in messages]
147
+ return RenderedTokens(
148
+ token_ids=token_ids,
149
+ message_indices=message_indices,
150
+ message_roles=message_roles,
151
+ )
147
152
 
148
153
  def _apply(self, messages, *, tools=None, add_generation_prompt=False) -> list[int]:
149
154
  kwargs = dict(self._chat_template_kwargs)
@@ -203,7 +203,10 @@ class GLM45Renderer:
203
203
  emit_special(self._think_end, -1, is_sampled=False)
204
204
 
205
205
  return RenderedTokens(
206
- token_ids=tokens, message_indices=indices, sampled_mask=sampled
206
+ token_ids=tokens,
207
+ message_indices=indices,
208
+ sampled_mask=sampled,
209
+ message_roles=[m.get("role") or "" for m in messages],
207
210
  )
208
211
 
209
212
  def render_ids(
@@ -271,22 +274,29 @@ class GLM45Renderer:
271
274
  last_prev = previous_ids[-1]
272
275
 
273
276
  ext: list[int] = []
274
-
275
- # Bridge output is consumed as the next turn's prompt — the
276
- # caller blanket-masks it via ``prompt_mask=[False]*N``, so we
277
- # don't track sampled_mask here. Local helpers accept the kwarg
278
- # for signature compatibility with ``_render_tool`` and ignore
279
- # it; the returned ``RenderedTokens`` leaves ``sampled_mask``
280
- # empty.
277
+ ext_indices: list[int] = []
278
+ ext_sampled: list[bool] = []
279
+
280
+ # Bridge populates ``message_indices`` (relative to ``new_messages``)
281
+ # and ``sampled_mask`` (uniformly ``False`` — every token the
282
+ # bridge emits is template scaffolding for the next prompt, not
283
+ # something the model sampled). Downstream consumers can run
284
+ # :meth:`RenderedTokens.tokens_per_message` on the bridge output
285
+ # to get per-new-message token counts without re-rendering.
281
286
  def emit_special(
282
- token_id: int, _msg_idx: int = -1, *, is_sampled: bool = False
287
+ token_id: int, msg_idx: int = -1, *, is_sampled: bool = False
283
288
  ) -> None:
284
289
  ext.append(token_id)
290
+ ext_indices.append(msg_idx)
291
+ ext_sampled.append(is_sampled)
285
292
 
286
293
  def emit_text(
287
- text: str, _msg_idx: int = -1, *, is_sampled: bool = False
294
+ text: str, msg_idx: int = -1, *, is_sampled: bool = False
288
295
  ) -> None:
289
- ext.extend(self._encode(text))
296
+ ids = self._encode(text)
297
+ ext.extend(ids)
298
+ ext_indices.extend([msg_idx] * len(ids))
299
+ ext_sampled.extend([is_sampled] * len(ids))
290
300
 
291
301
  for i, msg in enumerate(new_messages):
292
302
  role = msg.get("role")
@@ -318,7 +328,13 @@ class GLM45Renderer:
318
328
  emit_special(self._think, -1)
319
329
  emit_special(self._think_end, -1)
320
330
 
321
- return RenderedTokens(token_ids=previous_ids + ext)
331
+ total_len = len(previous_ids) + len(ext)
332
+ return RenderedTokens(
333
+ token_ids=previous_ids + ext,
334
+ message_indices=[-1] * len(previous_ids) + ext_indices,
335
+ sampled_mask=[False] * total_len,
336
+ message_roles=[m.get("role") or "" for m in new_messages],
337
+ )
322
338
 
323
339
  def _render_assistant(
324
340
  self,
@@ -220,7 +220,10 @@ class GLM5Renderer:
220
220
  emit_special(self._think_end, -1, is_sampled=False)
221
221
 
222
222
  return RenderedTokens(
223
- token_ids=tokens, message_indices=indices, sampled_mask=sampled
223
+ token_ids=tokens,
224
+ message_indices=indices,
225
+ sampled_mask=sampled,
226
+ message_roles=[m.get("role") or "" for m in messages],
224
227
  )
225
228
 
226
229
  def render_ids(
@@ -292,22 +295,29 @@ class GLM5Renderer:
292
295
  last_prev = previous_ids[-1]
293
296
 
294
297
  ext: list[int] = []
295
-
296
- # Bridge output is consumed as the next turn's prompt — the
297
- # caller blanket-masks it via ``prompt_mask=[False]*N``, so we
298
- # don't track sampled_mask here. Local helpers accept the kwarg
299
- # for signature compatibility with ``_render_assistant`` /
300
- # ``_render_tool`` and ignore it; the returned ``RenderedTokens``
301
- # leaves ``sampled_mask`` empty.
298
+ ext_indices: list[int] = []
299
+ ext_sampled: list[bool] = []
300
+
301
+ # Bridge populates ``message_indices`` (relative to ``new_messages``)
302
+ # and ``sampled_mask`` (uniformly ``False`` — every token the
303
+ # bridge emits is template scaffolding for the next prompt, not
304
+ # something the model sampled). Downstream consumers can run
305
+ # :meth:`RenderedTokens.tokens_per_message` on the bridge output
306
+ # to get per-new-message token counts without re-rendering.
302
307
  def emit_special(
303
- token_id: int, _msg_idx: int = -1, *, is_sampled: bool = False
308
+ token_id: int, msg_idx: int = -1, *, is_sampled: bool = False
304
309
  ) -> None:
305
310
  ext.append(token_id)
311
+ ext_indices.append(msg_idx)
312
+ ext_sampled.append(is_sampled)
306
313
 
307
314
  def emit_text(
308
- text: str, _msg_idx: int = -1, *, is_sampled: bool = False
315
+ text: str, msg_idx: int = -1, *, is_sampled: bool = False
309
316
  ) -> None:
310
- ext.extend(self._encode(text))
317
+ ids = self._encode(text)
318
+ ext.extend(ids)
319
+ ext_indices.extend([msg_idx] * len(ids))
320
+ ext_sampled.extend([is_sampled] * len(ids))
311
321
 
312
322
  for i, msg in enumerate(new_messages):
313
323
  role = msg.get("role")
@@ -340,7 +350,13 @@ class GLM5Renderer:
340
350
  else:
341
351
  emit_special(self._think_end, -1)
342
352
 
343
- return RenderedTokens(token_ids=previous_ids + ext)
353
+ total_len = len(previous_ids) + len(ext)
354
+ return RenderedTokens(
355
+ token_ids=previous_ids + ext,
356
+ message_indices=[-1] * len(previous_ids) + ext_indices,
357
+ sampled_mask=[False] * total_len,
358
+ message_roles=[m.get("role") or "" for m in new_messages],
359
+ )
344
360
 
345
361
  def _render_assistant(
346
362
  self,
@@ -333,7 +333,10 @@ class GptOssRenderer:
333
333
  emit([self._message], -1, is_sampled=False)
334
334
 
335
335
  return RenderedTokens(
336
- token_ids=tokens, message_indices=indices, sampled_mask=sampled
336
+ token_ids=tokens,
337
+ message_indices=indices,
338
+ sampled_mask=sampled,
339
+ message_roles=[m.get("role") or "" for m in messages],
337
340
  )
338
341
 
339
342
  def render_ids(
@@ -400,22 +403,38 @@ class GptOssRenderer:
400
403
  if previous_ids is None:
401
404
  return None
402
405
 
406
+ # Bridge populates ``message_indices`` (relative to ``new_messages``)
407
+ # and ``sampled_mask`` (uniformly ``False``). The harmony encoder
408
+ # renders each ``new_messages[i]`` as a single block, so every
409
+ # token in that block carries index ``i``; the trailing
410
+ # generation prompt uses ``-1``.
403
411
  ext: list[int] = []
404
- for msg in new_messages:
412
+ ext_indices: list[int] = []
413
+ for i, msg in enumerate(new_messages):
405
414
  role = msg.get("role")
406
415
  if role not in ("tool", "user", "system", "developer"):
407
416
  return None
408
417
  for hm in self._to_harmony_messages(msg):
409
- ext.extend(self._enc.render(hm))
418
+ ids = self._enc.render(hm)
419
+ ext.extend(ids)
420
+ ext_indices.extend([i] * len(ids))
410
421
 
411
422
  # Generation prompt: <|start|>assistant<|channel|>analysis<|message|>
423
+ gen_before = len(ext)
412
424
  ext.append(self._start)
413
425
  ext.extend(self._encode("assistant"))
414
426
  ext.append(self._channel)
415
427
  ext.extend(self._encode("analysis"))
416
428
  ext.append(self._message)
429
+ ext_indices.extend([-1] * (len(ext) - gen_before))
417
430
 
418
- return RenderedTokens(token_ids=previous_ids + ext)
431
+ total_len = len(previous_ids) + len(ext)
432
+ return RenderedTokens(
433
+ token_ids=previous_ids + ext,
434
+ message_indices=[-1] * len(previous_ids) + ext_indices,
435
+ sampled_mask=[False] * total_len,
436
+ message_roles=[m.get("role") or "" for m in new_messages],
437
+ )
419
438
 
420
439
  # ── message conversion ───────────────────────────────────────────────────
421
440
 
@@ -270,7 +270,10 @@ class KimiK2Renderer:
270
270
  emit_special(self._im_middle, -1, is_sampled=False)
271
271
 
272
272
  return RenderedTokens(
273
- token_ids=token_ids, message_indices=indices, sampled_mask=sampled
273
+ token_ids=token_ids,
274
+ message_indices=indices,
275
+ sampled_mask=sampled,
276
+ message_roles=[m.get("role") or "" for m in messages],
274
277
  )
275
278
 
276
279
  def render_ids(
@@ -331,21 +334,29 @@ class KimiK2Renderer:
331
334
  return None
332
335
 
333
336
  ext: list[int] = []
334
-
335
- # Bridge output is consumed as the next turn's prompt — the caller
336
- # blanket-masks it via ``prompt_mask=[False]*N``, so we don't track
337
- # sampled_mask here. Local helpers accept the kwarg for signature
338
- # compatibility with ``_render_tool`` and ignore it; the returned
339
- # ``RenderedTokens`` leaves ``sampled_mask`` empty.
337
+ ext_indices: list[int] = []
338
+ ext_sampled: list[bool] = []
339
+
340
+ # Bridge populates ``message_indices`` (relative to ``new_messages``)
341
+ # and ``sampled_mask`` (uniformly ``False`` — every token the
342
+ # bridge emits is template scaffolding for the next prompt, not
343
+ # something the model sampled). Downstream consumers can run
344
+ # :meth:`RenderedTokens.tokens_per_message` on the bridge output
345
+ # to get per-new-message token counts without re-rendering.
340
346
  def emit_special(
341
- token_id: int, _msg_idx: int = -1, *, is_sampled: bool = False
347
+ token_id: int, msg_idx: int = -1, *, is_sampled: bool = False
342
348
  ) -> None:
343
349
  ext.append(token_id)
350
+ ext_indices.append(msg_idx)
351
+ ext_sampled.append(is_sampled)
344
352
 
345
353
  def emit_text(
346
- text: str, _msg_idx: int = -1, *, is_sampled: bool = False
354
+ text: str, msg_idx: int = -1, *, is_sampled: bool = False
347
355
  ) -> None:
348
- ext.extend(self._encode(text))
356
+ ids = self._encode(text)
357
+ ext.extend(ids)
358
+ ext_indices.extend([msg_idx] * len(ids))
359
+ ext_sampled.extend([is_sampled] * len(ids))
349
360
 
350
361
  for i, msg in enumerate(new_messages):
351
362
  role = msg.get("role")
@@ -388,7 +399,13 @@ class KimiK2Renderer:
388
399
  emit_text("assistant", -1, is_sampled=False)
389
400
  emit_special(self._im_middle, -1, is_sampled=False)
390
401
 
391
- return RenderedTokens(token_ids=previous_ids + ext)
402
+ total_len = len(previous_ids) + len(ext)
403
+ return RenderedTokens(
404
+ token_ids=previous_ids + ext,
405
+ message_indices=[-1] * len(previous_ids) + ext_indices,
406
+ sampled_mask=[False] * total_len,
407
+ message_roles=[m.get("role") or "" for m in new_messages],
408
+ )
392
409
 
393
410
  def _render_assistant(
394
411
  self,
@@ -906,6 +906,7 @@ class KimiK25Renderer:
906
906
  token_ids=tokens,
907
907
  message_indices=indices,
908
908
  sampled_mask=sampled,
909
+ message_roles=[m.get("role") or "" for m in messages],
909
910
  multi_modal_data=mm_data,
910
911
  )
911
912
 
@@ -995,44 +996,52 @@ class KimiK25Renderer:
995
996
  return None
996
997
 
997
998
  # Seed combined-token list with prior turn so placeholder offsets
998
- # are absolute in the bridged sequence.
999
+ # are absolute in the bridged sequence. Parallel
1000
+ # ``indices``/``sampled`` are seeded with ``-1``/``False`` for the
1001
+ # prior portion — the bridge has no attribution info for
1002
+ # ``previous_ids``. Bridge-added tokens get proper ``msg_idx``
1003
+ # (relative to ``new_messages``) and uniformly ``False``
1004
+ # ``sampled``: nothing the bridge emits was model-sampled.
999
1005
  tokens: list[int] = list(previous_ids)
1006
+ indices: list[int] = [-1] * len(previous_ids)
1007
+ sampled: list[bool] = [False] * len(previous_ids)
1000
1008
  new_hashes: dict[str, list[str]] = {}
1001
1009
  new_placeholders: dict[str, list[PlaceholderRange]] = {}
1002
1010
  new_items: dict[str, list[dict[str, Any]]] = {}
1003
1011
 
1004
- # Bridge output is consumed as the next turn's prompt — the caller
1005
- # blanket-masks it via ``prompt_mask=[False]*N``, so we don't track
1006
- # sampled_mask here. Local helpers accept the kwarg for signature
1007
- # compatibility with ``_render_tool_body`` / ``_emit_content`` and
1008
- # ignore it; the returned ``RenderedTokens`` leaves ``sampled_mask``
1009
- # empty.
1010
1012
  def emit_special(
1011
- token_id: int, _msg_idx: int = -1, *, is_sampled: bool = False
1013
+ token_id: int, msg_idx: int = -1, *, is_sampled: bool = False
1012
1014
  ) -> None:
1013
1015
  tokens.append(token_id)
1016
+ indices.append(msg_idx)
1017
+ sampled.append(is_sampled)
1014
1018
 
1015
1019
  def emit_text(
1016
- text: str, _msg_idx: int = -1, *, is_sampled: bool = False
1020
+ text: str, msg_idx: int = -1, *, is_sampled: bool = False
1017
1021
  ) -> None:
1018
- tokens.extend(self._encode(text))
1022
+ ids = self._encode(text)
1023
+ tokens.extend(ids)
1024
+ indices.extend([msg_idx] * len(ids))
1025
+ sampled.extend([is_sampled] * len(ids))
1019
1026
 
1020
1027
  def emit_ids(
1021
- ids: list[int], _msg_idx: int = -1, *, is_sampled: bool = False
1028
+ ids: list[int], msg_idx: int = -1, *, is_sampled: bool = False
1022
1029
  ) -> None:
1023
1030
  tokens.extend(ids)
1031
+ indices.extend([msg_idx] * len(ids))
1032
+ sampled.extend([is_sampled] * len(ids))
1024
1033
 
1025
1034
  def emit_image(
1026
- part: dict[str, Any], _msg_idx: int = -1, *, is_sampled: bool = False
1035
+ part: dict[str, Any], msg_idx: int = -1, *, is_sampled: bool = False
1027
1036
  ) -> None:
1028
1037
  _, out, _num_patches, h = self._process_image(part)
1029
- emit_special(self._media_begin)
1030
- emit_text("image")
1031
- emit_special(self._media_content)
1038
+ emit_special(self._media_begin, msg_idx)
1039
+ emit_text("image", msg_idx)
1040
+ emit_special(self._media_content, msg_idx)
1032
1041
  offset = len(tokens)
1033
- emit_special(self._media_pad)
1034
- emit_special(self._media_end)
1035
- emit_text("\n")
1042
+ emit_special(self._media_pad, msg_idx)
1043
+ emit_special(self._media_end, msg_idx)
1044
+ emit_text("\n", msg_idx)
1036
1045
  new_hashes.setdefault("image", []).append(h)
1037
1046
  new_placeholders.setdefault("image", []).append(
1038
1047
  PlaceholderRange(offset=offset, length=1)
@@ -1113,8 +1122,14 @@ class KimiK25Renderer:
1113
1122
  for modality, vals in new_items.items():
1114
1123
  merged_items.setdefault(modality, []).extend(vals)
1115
1124
 
1125
+ bridge_roles = [m.get("role") or "" for m in new_messages]
1116
1126
  if not (merged_hashes or merged_placeholders or merged_items):
1117
- return RenderedTokens(token_ids=tokens)
1127
+ return RenderedTokens(
1128
+ token_ids=tokens,
1129
+ message_indices=indices,
1130
+ sampled_mask=sampled,
1131
+ message_roles=bridge_roles,
1132
+ )
1118
1133
 
1119
1134
  mm_data = MultiModalData(
1120
1135
  mm_hashes=merged_hashes,
@@ -1123,7 +1138,9 @@ class KimiK25Renderer:
1123
1138
  )
1124
1139
  return RenderedTokens(
1125
1140
  token_ids=tokens,
1126
- message_indices=[-1] * len(tokens),
1141
+ message_indices=indices,
1142
+ sampled_mask=sampled,
1143
+ message_roles=bridge_roles,
1127
1144
  multi_modal_data=mm_data,
1128
1145
  )
1129
1146