freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,780 @@
1
+ """Multi-turn / tool GRPO rollout for TRL's experimental ``rollout_func`` (colocate vLLM).
2
+
3
+ TRL's ``GRPOTrainer`` generates a single assistant turn per prompt, which cannot drive a
4
+ Freesolo ``EnvironmentMultiTurn`` turn loop (model turn -> env reply -> ...). This
5
+ module supplies a ``rollout_func`` that:
6
+
7
+ * drives the env's turn loop via the adapter helpers (``new_rollout_state`` /
8
+ ``record_model_turn`` / ``env_reply`` / ``rollout_done``), so the *env* owns tool
9
+ execution, ``StatefulToolEnv`` state threading, and any simulated-user turns;
10
+ * returns the FULL interleaved token sequence as ``completion_ids`` together with an
11
+ ``env_mask`` that marks model-generated tokens (``1``, trained) vs tool/env tokens
12
+ (``0``, masked out of the loss). ``env_mask`` is TRL's documented mechanism for
13
+ multi-turn credit assignment (it is treated internally as the tool mask), so only the
14
+ policy's own tokens get advantage while the env tokens still provide context for the
15
+ forward pass;
16
+ * scores each rollout with the environment reward (``reward_from_messages``) and returns
17
+ it as an extra field consumed by a pass-through ``reward_func``.
18
+
19
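Token alignment: the running token sequence is built by concatenation and never re-rendered
as a whole. It is the model's sampled ids plus, for each env segment between two model turns,
the ids returned by the injected ``env_glue`` callable (close the assistant turn, render the
env reply, open the next generation prompt). This keeps the sequence aligned even for chat
templates that do not round-trip earlier turns (e.g. Qwen3's empty ``<think>`` block, which is
injected into the generation prompt but stripped from history). The ``env_glue`` ids are masked
out of the loss, so the model/env split is only correct if ``env_glue`` yields exactly the
tokens the template places between turns. How it derives them (and whether it checks a prefix
invariant) is decided by :func:`build_rollout_func`.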
24
+
25
+ The core (:func:`rollout_one`) is pure Python and takes injected ``render``/``generate``
26
+ callables so it can be unit-tested without a GPU/tokenizer; :func:`build_rollout_func` wires
27
+ the real tokenizer + the colocate vLLM engine into it at runtime.
28
+ """
29
+
30
+ from __future__ import annotations
31
+
32
+ import contextlib
33
+ import json
34
+ import queue
35
+ import threading
36
+ from collections import OrderedDict
37
+ from collections.abc import Callable
38
+ from concurrent.futures import ThreadPoolExecutor, as_completed
39
+ from typing import TypedDict
40
+
41
+
42
class RolloutResult(TypedDict, total=True):
    """Per-rollout payload handed back to TRL's ``rollout_func``.

    ``prompt_ids`` is the rendered initial prompt. ``completion_ids``, ``logprobs`` and
    ``env_mask`` are index-aligned over everything that follows it (model tokens and env/tool
    tokens interleaved). ``reward`` is the scalar environment score for the whole rollout.
    """

    prompt_ids: list[int]  # token ids of the rendered initial prompt
    completion_ids: list[int]  # model + env tokens following the prompt
    logprobs: list[float]  # sampling logprob per completion token (0.0 on env tokens)
    env_mask: list[int]  # 1 = model-generated (trained), 0 = env/tool (masked)
    reward: float  # environment reward for this rollout
50
+
51
+
52
# Key names common to a single RolloutResult and to the batched dict-of-lists produced by
# build_rollout_func. Spelled out literally on purpose instead of being read from
# RolloutResult.__annotations__: the batch accumulator maps every key to a LIST of values, so
# its key source should not be the single-rollout type, whose value types (float, list[int],
# ...) differ from the accumulator's.
_ROLLOUT_FIELDS: tuple[str, ...] = (
    "prompt_ids", "completion_ids", "logprobs", "env_mask", "reward",
)
63
+
64
+
65
def _prompt_key(prompt) -> str:
    """Return a deterministic string key for a dataset ``prompt`` value.

    JSON with ``sort_keys=True`` makes dict insertion order irrelevant, and ``default=str``
    stringifies values JSON cannot encode natively. Values that still cannot be serialized
    fall back to plain ``str(prompt)``: dict keys that cannot be sorted or encoded, or a
    self-referencing container.
    """
    try:
        key = json.dumps(prompt, sort_keys=True, default=str)
    except (TypeError, ValueError):
        # TypeError: unsortable / unsupported dict keys; ValueError: circular reference.
        key = str(prompt)
    return key
71
+
72
+
73
class _LRUCache:
    """Bounded least-recently-used map from ``str`` keys to ``list[int]`` token ids.

    Backs the render / env_glue closures. Once ``maxsize`` entries are held, inserting a new
    key evicts the entry that was read or written longest ago. The cache therefore keeps
    serving the most recent working set instead of freezing after the first ``maxsize`` keys,
    which is what a plain ``if len(d) < MAX`` admission guard would do. Both ``get`` hits and
    ``put`` refresh recency. There is no internal locking; an instance is meant to be used
    from a single thread by its owning closure.
    """

    __slots__ = ("_data", "maxsize")

    def __init__(self, maxsize: int):
        if maxsize <= 0:
            raise ValueError("LRU cache maxsize must be positive")
        self._data: OrderedDict[str, list[int]] = OrderedDict()
        self.maxsize = maxsize

    def get(self, key: str) -> list[int] | None:
        """Look up ``key``; a hit becomes most-recently-used. Returns ``None`` on a miss."""
        hit = self._data.get(key)
        if hit is None:
            return None
        self._data.move_to_end(key)
        return hit

    def put(self, key: str, value: list[int]) -> None:
        """Store ``value`` as the newest entry, evicting the oldest one(s) past capacity."""
        entries = self._data
        if key in entries:
            # Re-assigning an existing OrderedDict key keeps its old position, so move it first.
            entries.move_to_end(key)
        entries[key] = value
        while len(entries) > self.maxsize:
            entries.popitem(last=False)  # oldest (least-recently-used) first

    def __len__(self) -> int:
        return len(self._data)
109
+
110
+
111
def build_examples_index(rows: list[dict], prompt_of: Callable[[dict], object]) -> dict:
    """Index example rows by the key of their rendered ``prompt`` (for reward/answer lookup).

    Keys come from ``_prompt_key(prompt_of(row))``. When several rows share a key, the row
    seen last wins; :func:`index_collisions` lets the caller report how many rows were
    shadowed. Such duplicates only change which ``answer``/``info`` a shared prompt is
    scored against.
    """
    index: dict = {}
    for row in rows:
        # Later rows overwrite earlier ones that map to the same key.
        index[_prompt_key(prompt_of(row))] = row
    return index
119
+
120
+
121
def index_collisions(rows: list[dict], prompt_of: Callable[[dict], object]) -> int:
    """Count rows that :func:`build_examples_index` shadows because their prompt key repeats.

    Equal to ``len(rows)`` minus the number of distinct ``_prompt_key(prompt_of(row))`` values.
    """
    total = len(rows)
    distinct_keys = {_prompt_key(prompt_of(row)) for row in rows}
    return total - len(distinct_keys)
124
+
125
+
126
def rollout_one(
    *,
    example: dict,
    active_env,
    render: Callable[[list, bool], list[int]],
    generate: Callable[[list, int], tuple[list[int], list[float], str]],
    env_glue: Callable[[list], list[int]],
    max_turns: int,
    per_turn_max_tokens: int,
    engine_max_len: int | None = None,
) -> RolloutResult:
    """Run one multi-turn/tool rollout and return TRL ``rollout_func`` fields for it.

    Args:
        example: the dataset row carried into environment scoring.
        active_env: the Freesolo environment adapter. It drives the turn loop and scoring via
            ``new_rollout_state`` / ``record_model_turn`` / ``rollout_done`` / ``env_reply`` /
            ``reward``.
        render: ``render(messages, add_generation_prompt) -> token_ids`` (chat template). Used
            only for the INITIAL prompt.
        generate: ``generate(prefix_token_ids, max_tokens) -> (token_ids, token_logprobs,
            text)`` for one sampled assistant turn (model tokens + sampling logprobs + text).
            ``max_tokens`` bounds that turn so it can't overflow the engine context.
        env_glue: ``env_glue(env_messages) -> token_ids``. These are the tokens that CLOSE the
            just-finished assistant turn, render the env reply message(s), and OPEN the next
            generation prompt. The running token sequence is built incrementally from these
            (the model's generated ids + env glue), never by re-rendering the whole
            conversation. A chat template that does not round-trip prior turns (e.g. Qwen3's
            empty ``<think>`` block, which is injected into the generation prompt but stripped
            from history) therefore stays token-aligned.
        max_turns: hard cap on model turns (defense against a non-terminating env). Also
            forwarded to ``active_env.rollout_done``.
        per_turn_max_tokens: cap on new tokens for one assistant turn. It is further lowered
            to the remaining engine headroom and is never passed to ``generate`` below 1. This
            is the same clamp the continuous-batched path applies via ``_turn_budget``, so both
            paths request identical turn lengths.
        engine_max_len: engine context length in tokens. When truthy, the completion is capped
            at ``engine_max_len - len(prompt_ids) - 8`` tokens (8-token safety margin). When
            ``None`` or 0, there is no context-based cap.

    Returns:
        A :class:`RolloutResult` with ``prompt_ids``, ``completion_ids``, ``logprobs``,
        ``env_mask`` (the last three index-aligned) and the scalar ``reward``.

    Raises:
        KeyError: the env's initial rollout state has no list under ``prompt``/``messages``.
    """
    state = active_env.new_rollout_state(example)
    initial_messages = state.get("prompt") or state.get("messages")
    if not isinstance(initial_messages, list):
        raise KeyError("multi-turn rollout state must include prompt or messages")
    messages = [dict(m) for m in initial_messages]
    prompt_ids = render(messages, True)
    cur_ids = list(prompt_ids)  # invariant: cur_ids == prompt_ids + completion_ids so far
    # Per-rollout completion cap so prompt + accumulated completion never exceeds the colocate
    # engine's context (which would overflow the next generate()); leave a small margin.
    token_budget = (engine_max_len - len(prompt_ids) - 8) if engine_max_len else None
    completion_ids: list[int] = []
    logprobs: list[float] = []
    env_mask: list[int] = []

    turns = 0
    while True:
        # Bound THIS turn's generation by the remaining engine headroom so even a single
        # generate() can't push prompt+completion past the context (the cap below stops the
        # loop AFTER a turn; this stops the turn itself from overflowing).
        max_new = per_turn_max_tokens
        if token_budget is not None:
            remaining = token_budget - len(completion_ids)
            if remaining <= 0:
                break
            max_new = min(max_new, remaining)
        # Never ask the engine for fewer than one token. This mirrors _turn_budget, so the
        # single-rollout and continuous-batched paths cannot drift on a non-positive config.
        max_new = max(1, max_new)
        asst_ids, asst_lp, text = generate(cur_ids, max_new)
        completion_ids.extend(asst_ids)
        logprobs.extend(asst_lp)
        env_mask.extend([1] * len(asst_ids))  # the policy's own tokens: trained
        cur_ids.extend(asst_ids)
        active_env.record_model_turn(state, text)
        messages.append({"role": "assistant", "content": text})
        turns += 1

        if token_budget is not None and len(completion_ids) >= token_budget:
            break
        if turns >= max_turns or active_env.rollout_done(state, max_turns):
            break
        env_msgs = active_env.env_reply(messages, state)
        if not env_msgs:
            break
        messages.extend(env_msgs)
        # If the env step finished the episode, stop without appending next-generation glue:
        # there is no next model turn, and the glue would leave a trailing assistant prompt in
        # completion_ids.
        if active_env.rollout_done(state, max_turns):
            break

        # Env-segment tokens (close assistant turn + env reply + open next generation prompt),
        # computed incrementally. Masked (0) because they are not the policy's tokens, but kept
        # in completion_ids so the next turn conditions on them.
        glue = env_glue(env_msgs)
        # Don't append glue that would push prompt+completion past the engine budget; end the
        # rollout cleanly instead of returning an over-length sequence.
        if token_budget is not None and len(completion_ids) + len(glue) > token_budget:
            break
        completion_ids.extend(glue)
        logprobs.extend([0.0] * len(glue))
        env_mask.extend([0] * len(glue))
        cur_ids.extend(glue)

    # Score with the ACTUAL rollout state (not a fresh one) so reward funcs see the tool/env
    # state the rollout accumulated.
    reward = active_env.reward("", example, state)
    return {
        "prompt_ids": prompt_ids,
        "completion_ids": completion_ids,
        "logprobs": logprobs,
        "env_mask": env_mask,
        "reward": float(reward),
    }
236
+
237
+
238
class _RolloutState:
    """Per-rollout scratch state for the continuous-batched path (:func:`rollout_async`).

    Carries the same running values :func:`rollout_one` keeps in local variables: the prompt,
    the interleaved completion with its logprobs and env_mask, the turn count and the
    completion-token cap. Both paths can therefore fold turns the same way. ``done`` flags a
    rollout that must not be advanced again.
    """

    __slots__ = (
        "budget",
        "completion_ids",
        "cur_ids",
        "done",
        "env_mask",
        "example",
        "logprobs",
        "messages",
        "prompt_ids",
        "state",
        "turns",
    )

    def __init__(self, example, messages, prompt_ids, state, budget):
        # Inputs, stored as given.
        self.example = example
        self.messages = messages
        self.state = state
        self.prompt_ids = prompt_ids
        self.budget = budget  # completion-token cap (engine headroom), or None for no cap
        # Running accumulators.
        self.cur_ids = list(prompt_ids)  # always prompt_ids + completion_ids (a separate list)
        self.completion_ids: list[int] = []
        self.logprobs: list[float] = []
        self.env_mask: list[int] = []
        self.turns = 0
        self.done = False

    def result(self, reward: float) -> RolloutResult:
        """Package the accumulated fields as a :class:`RolloutResult` (``reward`` as float)."""
        payload: RolloutResult = {
            "prompt_ids": self.prompt_ids,
            "completion_ids": self.completion_ids,
            "logprobs": self.logprobs,
            "env_mask": self.env_mask,
            "reward": float(reward),
        }
        return payload
281
+
282
+
283
def _advance_after_turn(
    r: _RolloutState,
    asst_ids: list[int],
    asst_lp: list[float],
    text: str,
    *,
    active_env,
    env_glue: Callable[[list], list[int]],
    max_turns: int,
) -> None:
    """Apply one sampled assistant turn to ``r`` and, if the rollout continues, its env step.

    Follows the same sequence as one iteration of :func:`rollout_one`'s loop after
    ``generate``. ``r.done`` is set, and the function returns early, when any of these hold:

    * the completion budget is used up;
    * the turn cap is reached;
    * env ``rollout_done`` reports the episode over (checked before and after the env reply);
    * the env has no reply;
    * the next glue segment would not fit in the budget.
    """
    # 1) Fold in the model's own tokens: trained (mask 1), with their sampling logprobs.
    r.completion_ids += asst_ids
    r.logprobs += asst_lp
    r.env_mask += [1] * len(asst_ids)
    r.cur_ids += asst_ids
    active_env.record_model_turn(r.state, text)
    r.messages.append({"role": "assistant", "content": text})
    r.turns += 1

    # 2) Stop before the env step if the budget is spent, the turn cap is reached, or the env
    #    says the episode is over (short-circuit: rollout_done is only consulted if needed).
    budget_spent = r.budget is not None and len(r.completion_ids) >= r.budget
    if budget_spent or r.turns >= max_turns or active_env.rollout_done(r.state, max_turns):
        r.done = True
        return

    # 3) Env step. No reply, or a reply that itself ended the episode, means no next model turn.
    env_msgs = active_env.env_reply(r.messages, r.state)
    if not env_msgs:
        r.done = True
        return
    r.messages.extend(env_msgs)
    if active_env.rollout_done(r.state, max_turns):
        r.done = True
        return

    # 4) Env/glue tokens: kept as context for the next turn but masked out (0), logprob 0.0.
    glue = env_glue(env_msgs)
    if r.budget is not None and len(r.completion_ids) + len(glue) > r.budget:
        r.done = True  # appending would overflow the engine budget; end cleanly instead
        return
    r.completion_ids += glue
    r.logprobs += [0.0] * len(glue)
    r.env_mask += [0] * len(glue)
    r.cur_ids += glue
325
+
326
+
327
def _build_rollout_states(
    examples: list[dict],
    active_env,
    render: Callable[[list, bool], list[int]],
    engine_max_len: int | None,
) -> list[_RolloutState]:
    """Create the starting :class:`_RolloutState` for every example, in input order.

    For each example, the env's initial rollout state is created, its ``prompt`` (or else
    ``messages``) list is copied message-by-message and rendered with a generation prompt, and
    the completion budget is set. The budget is ``engine_max_len - len(prompt_ids) - 8`` when
    ``engine_max_len`` is truthy, else ``None``. This is the same starting point
    :func:`rollout_one` builds inline.

    Raises:
        KeyError: an example's initial rollout state has no list under ``prompt``/``messages``.
    """

    def _start(example: dict) -> _RolloutState:
        env_state = active_env.new_rollout_state(example)
        initial = env_state.get("prompt") or env_state.get("messages")
        if not isinstance(initial, list):
            raise KeyError("multi-turn rollout state must include prompt or messages")
        conversation = [dict(msg) for msg in initial]
        prompt_ids = render(conversation, True)
        headroom = (engine_max_len - len(prompt_ids) - 8) if engine_max_len else None
        return _RolloutState(example, conversation, prompt_ids, env_state, headroom)

    return [_start(example) for example in examples]
347
+
348
+
349
def _turn_budget(r: _RolloutState, per_turn_max_tokens: int) -> int | None:
    """Return the max new tokens for ``r``'s next assistant turn, or ``None`` if it has no room.

    The result is ``per_turn_max_tokens`` lowered to the remaining completion headroom
    (``r.budget - len(r.completion_ids)``), and never below 1. When the headroom is already
    exhausted, ``r.done`` is set and ``None`` is returned. With ``r.budget is None`` only the
    per-turn cap (floored at 1) applies.
    """
    if r.budget is None:
        return max(1, per_turn_max_tokens)
    headroom = r.budget - len(r.completion_ids)
    if headroom <= 0:
        r.done = True  # prompt/completion already fill the context: nothing more to generate
        return None
    return max(1, min(per_turn_max_tokens, headroom))
361
+
362
+
363
def _score_rollouts(active_env, rollouts: list[_RolloutState]) -> list[float]:
    """Score every rollout and return the rewards as floats, in the same order as ``rollouts``.

    Strategy (first match wins):

    * ``active_env.reward_many`` exists and is callable: one batched call with
      ``[(example, state), ...]``. Its result may be any iterable and must yield exactly one
      reward per rollout.
    * zero or one rollout, or ``active_env.reward_thread_safe`` is falsy: serial per-rollout
      ``active_env.reward("", example, state)`` calls.
    * otherwise: the same per-rollout calls fanned out over a thread pool of up to 16
      workers, so IO-bound rewards (judges, tools) overlap instead of running back to back.
      This includes envs that do not define ``reward_thread_safe`` at all, since the default
      assumption is thread safety.

    All paths reassemble results in input order; only scoring concurrency differs.

    Raises:
        RuntimeError: ``reward_many`` returned a different number of rewards than rollouts.
        Any exception raised by ``reward``/``reward_many`` propagates. In the concurrent case,
        the first failure encountered (in completion order) is re-raised after
        not-yet-started scorers are cancelled and running ones have finished.
    """
    reward_many = getattr(active_env, "reward_many", None)
    if callable(reward_many):
        # Materialize first so reward_many may return any iterable (e.g. a generator). A bare
        # len() on the raw result would reject that with a TypeError.
        rewards = list(reward_many([(r.example, r.state) for r in rollouts]))
        if len(rewards) != len(rollouts):
            raise RuntimeError("env.reward_many returned the wrong number of rewards")
        return [float(x) for x in rewards]

    def _score(r: _RolloutState) -> float:
        return float(active_env.reward("", r.example, r.state))

    # Serial for zero/one rollout, or when the env declares its reward NOT thread-safe (a
    # scorer that keeps mutable state or a thread-bound client must not be raced).
    if len(rollouts) <= 1 or not getattr(active_env, "reward_thread_safe", True):
        return [_score(r) for r in rollouts]

    scores: list[float] = [0.0] * len(rollouts)
    pool = ThreadPoolExecutor(max_workers=min(16, len(rollouts)))
    try:
        futures = {pool.submit(_score, r): i for i, r in enumerate(rollouts)}
        for fut in as_completed(futures):
            scores[futures[fut]] = fut.result()  # re-raises that scorer's exception
    finally:
        # On error: drop queued scorers and wait for running ones, so a failed step spends no
        # further judge/API calls and leaves no scorer running into the next step.
        pool.shutdown(wait=True, cancel_futures=True)
    return scores
395
+
396
+
397
def rollout_async(
    *,
    examples: list[dict],
    active_env,
    render: Callable[[list, bool], list[int]],
    submit: Callable[[str, list[int], int, bool], None],
    poll: Callable[[], list[tuple[str, list[int], list[float], str]]],
    busy: Callable[[], bool],
    env_glue: Callable[[list], list[int]],
    max_turns: int,
    per_turn_max_tokens: int,
    engine_max_len: int | None = None,
) -> list[RolloutResult]:
    """Run ``len(examples)`` multi-turn rollouts with CONTINUOUS-BATCHED generation (no turn barrier).

    Produces the same per-rollout token alignment and env_mask as one :func:`rollout_one` per
    example, returned in input order, but rollouts are NOT advanced in lockstep. Each assistant
    turn is an independent engine request. As soon as one finishes, its env step runs and its
    NEXT turn is submitted. The decode batch therefore stays full instead of stalling at every
    turn boundary until the slowest rollout's turn (and then every env reply) completes. This
    matters most for high-variance multi-turn workloads (rollouts of very different depths).

    Threading: the MAIN thread owns the engine (``submit`` / ``poll`` / ``busy``). A single
    WORKER thread owns the env (``_advance_after_turn`` + ``_turn_budget``). The two threads
    communicate only through two ``queue.Queue`` objects, and each next-turn prefix is handed
    over as a copy. The env is thus touched by exactly one thread and the engine by exactly
    one thread. The per-turn env work (env reply + glue render) can overlap engine decoding;
    the design relies on vLLM's ``step()`` releasing the GIL during the CUDA forward. A
    rollout keeps at most one request in flight, so its turns stay strictly sequential.

    Caveats vs. :func:`rollout_one`:
      * Every turn here is bounded by ``_turn_budget``, which never requests fewer than 1
        token. Alignment therefore matches ``rollout_one`` for ``per_turn_max_tokens >= 1``.
      * Rewards come from ``_score_rollouts``, which uses a batched ``active_env.reward_many``
        when the env provides one. They equal ``rollout_one``'s per-rollout ``reward()`` only
        if the env's ``reward_many`` agrees with its ``reward``.

    The engine is injected as three callables so the loop is unit-testable on CPU:
      * ``submit(req_id, prefix_ids, max_tokens, initial)`` enqueues one assistant-turn request.
        ``initial`` marks a turn-0 prompt, the only externally-rendered ids worth bounds-checking.
      * ``poll()`` returns ``(req_id, token_ids, logprobs, text)`` for every request that FINISHED
        since the last call (``[]`` if none finished this step).
      * ``busy()`` reports whether any request is still in flight.

    Raises:
        KeyError: from ``_build_rollout_states`` if an env initial state lacks a prompt list.
        Any exception raised on the worker thread (env step, glue render, turn budget) is
        re-raised on the calling thread. Exceptions from ``submit``/``poll``/``busy``/scoring
        propagate as-is.
    """
    rollouts = _build_rollout_states(examples, active_env, render, engine_max_len)
    # In-flight request id -> its rollout. Ids are "r0", "r1", ... and restart at "r0" on every
    # call. NOTE(review): poll() is assumed to return only ids submitted by THIS call; a
    # leftover request from an earlier call that exited with an exception would make
    # by_id.pop() below raise KeyError (or collide on id). Confirm the caller drains/aborts the
    # engine after a failed step.
    by_id: dict[str, _RolloutState] = {}
    counter = 0
    to_env: queue.Queue = queue.Queue()  # main -> worker: finished turns to fold + run the env step
    to_submit: queue.Queue = queue.Queue()  # worker -> main: ("next", r, prefix, max_new) | ("done", r) | ("error", exc)

    def do_submit(r: _RolloutState, prefix: list[int], max_new: int, initial: bool) -> None:
        # Main thread only: allocate a fresh request id, remember its rollout, hand it to the engine.
        nonlocal counter
        req_id = f"r{counter}"
        counter += 1
        by_id[req_id] = r
        submit(req_id, prefix, max_new, initial)

    def env_worker() -> None:
        # Owns the env: fold each finished turn, run its env step (env reply + glue render), and hand
        # the next-turn prefix (a copy) back to the main thread, or signal the rollout is done. An
        # env/template error here must propagate to the main thread (which owns the engine), not die
        # silently in this thread and hang the main loop waiting for a result that never comes.
        # Exits on the None sentinel or right after reporting an error.
        while True:
            item = to_env.get()
            if item is None:
                return
            r, asst_ids, asst_lp, text = item
            try:
                _advance_after_turn(
                    r, asst_ids, asst_lp, text,
                    active_env=active_env, env_glue=env_glue, max_turns=max_turns,
                )
                # None means "no next turn": either the env step finished the rollout or
                # _turn_budget found no headroom left (it also sets r.done).
                max_new = None if r.done else _turn_budget(r, per_turn_max_tokens)
            except Exception as exc:  # surfaced + re-raised on the main thread (engine owner)
                to_submit.put(("error", exc))
                return
            to_submit.put(("done", r) if max_new is None else ("next", r, list(r.cur_ids), max_new))

    worker = threading.Thread(target=env_worker, daemon=True)
    worker.start()
    n = len(rollouts)
    completed = 0

    def take(msg) -> None:
        # Main thread: apply one worker message (count a finished rollout, submit its next turn,
        # or re-raise the worker's error).
        nonlocal completed
        if msg[0] == "error":
            # Re-raise the worker's env/template error on the main thread (the engine owner);
            # the exception object still carries the worker-side traceback frames.
            err = msg[1]
            raise err.with_traceback(err.__traceback__)
        if msg[0] == "done":
            completed += 1
        else:
            _, r, prefix, max_new = msg
            do_submit(r, prefix, max_new, False)

    try:
        for r in rollouts:  # prime turn 0 on the main thread
            max_new = _turn_budget(r, per_turn_max_tokens)
            if max_new is None:
                completed += 1  # prompt already fills the context: finished with zero turns
            else:
                do_submit(r, list(r.cur_ids), max_new, r.turns == 0)
        while completed < n:
            progressed = False
            while True:  # submit every next-turn the worker has produced (and count finished ones)
                try:
                    take(to_submit.get_nowait())
                    progressed = True
                except queue.Empty:
                    break
            if completed >= n:
                break
            if busy():  # step the engine; hand finished turns to the worker (overlaps its env work)
                for req_id, asst_ids, asst_lp, text in poll():
                    to_env.put((by_id.pop(req_id), asst_ids, asst_lp, text))
            elif not progressed:
                # nothing in flight and nothing newly ready: the worker is mid-advance. Block on its
                # next output instead of spinning (every rollout is in exactly one stage, so this
                # can't deadlock: the only state with all queues + in-flight empty is all-done).
                with contextlib.suppress(queue.Empty):
                    take(to_submit.get(timeout=0.1))
    finally:
        # Always stop and join the worker. It first drains any turns already queued on to_env
        # (unless it already exited after reporting an error).
        # NOTE(review): on an exceptional exit, engine requests still in flight are NOT
        # cancelled here; the injected interface has no abort. Confirm the caller handles that.
        to_env.put(None)
        worker.join()

    # Score with the ACTUAL accumulated rollout state (matches rollout_one), batched per task.
    rewards = _score_rollouts(active_env, rollouts)
    return [r.result(rw) for r, rw in zip(rollouts, rewards, strict=True)]
522
+
523
+
524
def render_message_ids(tok, messages, add_generation_prompt: bool, *, thinking: bool) -> list[int]:
    """Apply the chat template to ``messages`` as text, then tokenize it to a flat ``list[int]``.

    The text-then-tokenize route is used because ``apply_chat_template(tokenize=True)`` returns
    differently shaped results across tokenizers, whereas ``tok(text).input_ids`` is reliably a
    flat id list (the same route as the single-turn ``render_prompt`` path).
    ``add_special_tokens=False`` because the template already emits special tokens.
    ``thinking`` is forwarded to the template as ``enable_thinking``. The GRPO rollout closure
    and mid-run eval both use this function, so they produce identical token alignment.
    """
    rendered = tok.apply_chat_template(
        messages,
        add_generation_prompt=add_generation_prompt,
        tokenize=False,
        enable_thinking=thinking,
    )
    encoding = tok(rendered, add_special_tokens=False)
    return list(map(int, encoding.input_ids))
540
+
541
+
542
def _engine_vocab_size(engine) -> int | None:
    """Return the colocate vLLM engine's vocabulary size, or ``None`` if it can't be determined.

    Feeds a cheap fail-loud bounds check on pre-tokenized prompt ids before they reach
    ``engine.generate``. The ``prompt_token_ids`` path does no bounds checking, so an
    out-of-range id would otherwise surface as an opaque CUDA illegal-access. The probes run
    in order:

    1. ``model_config.get_vocab_size()``
    2. ``model_config.get_hf_config_vocab_size()``
    3. ``model_config.hf_text_config.vocab_size``

    A probe that is missing, not callable, or raises falls through to the next one.
    """
    try:
        model_config = engine.llm_engine.model_config
    except Exception:
        return None

    # Lazily resolved, so the second getter is only looked up if the first one doesn't answer.
    probes = (
        getattr(model_config, name, None)
        for name in ("get_vocab_size", "get_hf_config_vocab_size")
    )
    for probe in filter(callable, probes):
        try:
            return int(probe())
        except Exception:
            pass  # try the next source

    try:
        return int(model_config.hf_text_config.vocab_size)
    except Exception:
        return None
564
+
565
+
566
def build_rollout_func(
    *,
    active_env,
    tok,
    examples_by_key: dict,
    max_completion: int,
    max_turns: int,
    temperature: float,
    top_p: float,
    stop: list[str] | None,
    thinking: bool,
    engine_max_len: int | None = None,
):
    """Return a TRL ``rollout_func`` closure that drives ``active_env`` on the colocate engine.

    The closure reaches the in-process vLLM engine through ``trainer.vllm_generation.llm`` and
    samples each assistant turn with per-token logprobs. It returns exactly ONE rollout per
    prompt in the slice TRL passes: TRL's ``RepeatSampler`` already repeats each unique prompt
    ``num_generations`` times before calling ``rollout_func`` (the consecutive repeats form the
    GRPO group), so the closure must NOT multiply by ``num_generations`` again.

    Args:
        active_env: environment forwarded to ``rollout_async`` to drive each rollout.
        tok: tokenizer used to render the chat template (initial prompt and inter-turn glue).
        examples_by_key: full example records keyed by ``_prompt_key(prompt)``; a prompt with
            no entry is rolled out as ``{"prompt": prompt}``.
        max_completion: per-turn ``max_tokens`` budget (clamped to >= 1 per request).
        max_turns: maximum assistant turns per rollout, forwarded to ``rollout_async``.
        temperature: sampling temperature used for every turn.
        top_p: nucleus-sampling ``top_p`` used for every turn.
        stop: optional stop strings applied to every turn (``None``/empty means none).
        thinking: passed to the chat template as ``enable_thinking``.
        engine_max_len: optional engine context limit, forwarded to ``rollout_async``.

    Returns:
        ``rollout_func(prompts, trainer)``, which returns a dict mapping each name in
        ``_ROLLOUT_FIELDS`` to a list with one entry per prompt, in input order.
    """
    from vllm import SamplingParams  # gpu-only; imported lazily so the module loads on CPU

    try:
        # FINAL_ONLY makes each manual add_request emit exactly one RequestOutput, at finish, with
        # the complete turn (matching LLM.generate); without it the engine streams a cumulative
        # output every step. Optional so the CPU import (stubbed vllm) still works — poll() filters
        # on `finished` either way.
        from vllm.sampling_params import RequestOutputKind

        _final_only_kind = RequestOutputKind.FINAL_ONLY
    except Exception:
        _final_only_kind = None

    # Both caches store IMMUTABLE tuples, and every call hands back a FRESH list. A cached entry
    # is shared by every rollout in a GRPO group (and across steps), so returning one shared
    # list object would let any in-place edit by a consumer (e.g. extending a prefix) silently
    # corrupt every later rollout that hits the same entry.
    _render_cache = _LRUCache(8192)

    def render(messages: list, add_generation_prompt: bool) -> list[int]:
        # The initial-prompt render is identical for every rollout in a GRPO group (they share one
        # prompt), so cache it by content instead of re-rendering num_generations times per step.
        # LRU-bounded: when full it EVICTS the least-recently-used entry rather than freezing, so a
        # long run with many distinct prompts keeps caching the recently-seen ones (a freeze-when-full
        # cache would stop admitting any new prompt after the cap and re-render them forever).
        cache_key = f"{add_generation_prompt}\x00{json.dumps(messages, sort_keys=True, default=str)}"
        cached = _render_cache.get(cache_key)
        if cached is not None:
            return list(cached)
        ids = render_message_ids(tok, messages, add_generation_prompt, thinking=thinking)
        _render_cache.put(cache_key, tuple(ids))
        return ids

    _glue_cache = _LRUCache(8192)

    def env_glue(env_messages: list) -> list[int]:
        # The inter-turn glue is a pure function of env_messages (+ this closure's tokenizer /
        # thinking). Within a GRPO group every rollout gets the SAME env reply each turn, and many
        # turns repeat env messages across rollouts and steps, so apply_chat_template would
        # otherwise re-render byte-identical glue dozens-to-hundreds of times — the dominant per-turn
        # CPU cost in the (otherwise overhead-bound) multi-turn rollout. Cache by env-message
        # content; LRU-bounded so an env whose every reply is unique can't grow it without limit and,
        # unlike a freeze-when-full cache, recently-seen glue stays cached over a long diverse run.
        cache_key = json.dumps(env_messages, sort_keys=True, default=str)
        cached = _glue_cache.get(cache_key)
        if cached is not None:
            return list(cached)
        # Tokens between two assistant turns: close the previous assistant turn, render the env
        # reply message(s), and open the next generation prompt. Derived by rendering a probe
        # assistant turn followed by the env messages (+ generation prompt) and taking everything
        # AFTER the probe content — so the glue is exactly the template's inter-turn wrapper,
        # whatever it is (Qwen's <|im_end|> + user turn + <|im_start|>assistant + <think> block).
        # This avoids re-rendering history (which Qwen3 does not round-trip) and matches how the
        # model actually conditioned during generation. The probe is plain text the template
        # inserts verbatim into assistant content; it must occur exactly once.
        probe = "flash-env-glue-probe"
        text = tok.apply_chat_template(
            [{"role": "assistant", "content": probe}, *env_messages],
            add_generation_prompt=True,
            tokenize=False,
            enable_thinking=thinking,
        )
        # Locate the probe to slice off the inter-turn glue. Fail LOUD with context if the
        # template did not insert the assistant content verbatim (some templates strip/escape it,
        # or could emit the probe more than once) instead of a bare "substring not found".
        first = text.find(probe)
        if first == -1 or text.find(probe, first + len(probe)) != -1:
            raise ValueError(
                "multi-turn env_glue could not uniquely locate its probe in the rendered chat "
                "template; this model's template does not insert assistant content verbatim, so "
                "token-aligned multi-turn rollout is unsupported for it (use a single-turn/tool "
                "env or a different model)."
            )
        glue_text = text[first + len(probe) :]
        glue = [int(t) for t in tok(glue_text, add_special_tokens=False).input_ids]
        _glue_cache.put(cache_key, tuple(glue))
        return glue

    def rollout_func(prompts, trainer):
        engine = trainer.vllm_generation.llm
        # The colocate engine is a vLLM `LLM`; its V1 `LLMEngine` exposes the public
        # add_request / step / has_unfinished_requests loop that lets us decode many rollouts'
        # turns CONTINUOUSLY (a finished turn's slot refills with another rollout's next turn)
        # instead of one synchronized batched decode per turn.
        llm_engine = engine.llm_engine
        # Colocate vLLM sleep mode (GRPOConfig.vllm_enable_sleep_mode, ON for large / long-context
        # runs) offloads BOTH the rollout engine's weights and its KV cache between steps. TRL's
        # rollout_func path (GRPOTrainer._generate) calls vllm_generation.sync_weights() — which
        # wakes only tags=["weights"] — and then hands control to this closure, but, UNLIKE TRL's
        # own single-turn generate() path, it never wakes tags=["kv_cache"]. So the first decode
        # below would run against a freed/offloaded KV cache and fault with CUDA "illegal memory
        # access" on step 0. Wake the KV cache here and re-sleep after the whole batch, mirroring
        # trl.generation.vllm_generation.generate (and trl.experimental.openenv). No-op when sleep
        # mode is off (small/short-context runs keep the engine resident). See flash issue #162.
        sleep_mode = bool(getattr(getattr(trainer, "args", None), "vllm_enable_sleep_mode", False))
        vocab_size = _engine_vocab_size(engine)
        active_ids: set[str] = set()  # submitted-but-not-finished requests, for abort-on-exit

        def submit(req_id: str, prefix_ids: list[int], max_tokens: int, initial: bool) -> None:
            """Enqueue one assistant-turn request on the colocate engine.

            Raises:
                ValueError: if ``prefix_ids`` is empty, or (when ``initial``) contains an id
                    below 0 or at/above the engine's known vocab size.
            """
            if not prefix_ids:
                # Fail loudly on a degenerate prompt instead of letting it reach the embedding gather
                # as an opaque async CUDA illegal-access (the failure mode #162 was first mistaken
                # for): the prompt_token_ids path does no bounds checking.
                raise ValueError("multi-turn rollout produced an empty prompt for engine.add_request()")
            if initial:
                # Turn-0 prefixes are the only externally-rendered initial prompts (later turns are
                # vLLM-generated / tokenizer glue, already in range); validate each, since the
                # prompt_token_ids path does no bounds checking and an out-of-range id would surface
                # as an opaque CUDA illegal-access.
                lo, hi = min(prefix_ids), max(prefix_ids)
                if lo < 0 or (vocab_size is not None and hi >= vocab_size):
                    raise ValueError(
                        f"multi-turn rollout prompt has out-of-range token id(s) [{lo}, {hi}] for "
                        f"vocab size {vocab_size} (tokenizer/model mismatch)"
                    )
            sp_kwargs = {
                "max_tokens": max(1, int(max_tokens)),
                "temperature": temperature,
                "top_p": top_p,
                "logprobs": 1,  # include the sampled token's logprob at each position
                "stop": list(stop) if stop else None,
            }
            if _final_only_kind is not None:
                sp_kwargs["output_kind"] = _final_only_kind
            llm_engine.add_request(
                req_id, {"prompt_token_ids": list(prefix_ids)}, SamplingParams(**sp_kwargs)
            )
            active_ids.add(req_id)

        def poll() -> list[tuple[str, list[int], list[float], str]]:
            """Advance the engine one step; return (req_id, token_ids, logprobs, text) for every
            request that finished this step (``[]`` if none did / a dummy batch ran)."""
            finished: list[tuple[str, list[int], list[float], str]] = []
            for out in llm_engine.step():
                if not getattr(out, "finished", False):
                    continue
                comp = out.outputs[0]
                token_ids = list(comp.token_ids)
                # comp.logprobs is a list (per position) of {token_id: Logprob}, or falsy when
                # absent; pull the sampled token's logprob at each position. A position with no
                # entry (missing list, list shorter than token_ids, or sampled id not present)
                # falls back to 0.0 rather than raising IndexError mid-rollout.
                per_pos = comp.logprobs or []
                n_pos = len(per_pos)
                lps: list[float] = []
                for pos, tid in enumerate(token_ids):
                    entry = per_pos[pos] if pos < n_pos else None
                    lp = entry.get(tid) if entry else None
                    lps.append(float(getattr(lp, "logprob", 0.0)) if lp is not None else 0.0)
                active_ids.discard(out.request_id)
                finished.append((out.request_id, token_ids, lps, comp.text))
            return finished

        def busy() -> bool:
            return bool(llm_engine.has_unfinished_requests())

        # Wake the KV cache for the whole batch (when sleep mode is on), then re-sleep so the
        # engine returns to its fully-offloaded state and the optimizer step has the freed memory
        # back. `woke` is set AFTER a successful wake so the finally re-sleeps ONLY when we actually
        # woke the engine — a wake_up() that raises leaves the engine asleep (its resting state),
        # and we must not then call sleep() on it; a failure DURING the rollout still re-sleeps.
        woke = False
        try:
            if sleep_mode:
                engine.wake_up(tags=["kv_cache"])
                woke = True
            # ONE rollout per prompt: TRL's RepeatSampler already repeats each unique prompt
            # num_generations times BEFORE handing the slice to rollout_func (trl 1.6/1.7:
            # `prompts = [x["prompt"] for x in inputs]`, no dedup), and it expects exactly
            # len(prompts) completions back — the GRPO group is the consecutive num_generations rows
            # of the same prompt. rollout_async returns one result per example in input order, so
            # the group stays aligned.
            examples = [examples_by_key.get(_prompt_key(p), {"prompt": p}) for p in prompts]
            rollouts = rollout_async(
                examples=examples,
                active_env=active_env,
                render=render,
                submit=submit,
                poll=poll,
                busy=busy,
                env_glue=env_glue,
                max_turns=max_turns,
                per_turn_max_tokens=max_completion,
                engine_max_len=engine_max_len,
            )
            out: dict[str, list] = {k: [] for k in _ROLLOUT_FIELDS}
            for r in rollouts:
                for k in out:
                    out[k].append(r[k])
            return out
        finally:
            # Abort any still-in-flight requests so a mid-rollout error (e.g. an env_glue/template
            # failure on a later turn) can't leak live requests into the engine and corrupt the
            # next GRPO step. No-op on the success path (every request finished -> active_ids empty).
            if active_ids:
                with contextlib.suppress(Exception):
                    llm_engine.abort_request(list(active_ids))
            if woke:
                engine.sleep(level=2)

    return rollout_func