freesolo-flash-dev 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash/__init__.py +29 -0
- flash/_channel.py +23 -0
- flash/_fileio.py +35 -0
- flash/_logging.py +49 -0
- flash/_update_check.py +266 -0
- flash/catalog.py +253 -0
- flash/cli/__init__.py +1 -0
- flash/cli/main/__init__.py +227 -0
- flash/cli/main/__main__.py +6 -0
- flash/cli/main/commands.py +636 -0
- flash/cli/main/envpush.py +317 -0
- flash/cli/main/render.py +599 -0
- flash/cli/main/training_doc.py +455 -0
- flash/client/__init__.py +14 -0
- flash/client/config.py +70 -0
- flash/client/http.py +372 -0
- flash/client/runtime_secrets.py +69 -0
- flash/client/specs.py +20 -0
- flash/cost/__init__.py +16 -0
- flash/cost/analytical.py +175 -0
- flash/cost/facts.py +114 -0
- flash/cost/spec.py +113 -0
- flash/cost/types.py +158 -0
- flash/engine/__init__.py +6 -0
- flash/engine/accounting.py +36 -0
- flash/engine/chalk_kernels.py +116 -0
- flash/engine/multiturn_rollout.py +780 -0
- flash/engine/recipe.py +86 -0
- flash/engine/vram.py +603 -0
- flash/engine/worker/__init__.py +2916 -0
- flash/engine/worker/__main__.py +4 -0
- flash/engine/worker/kernel_warmup.py +400 -0
- flash/engine/worker/lora.py +796 -0
- flash/engine/worker/packing.py +366 -0
- flash/engine/worker/perf.py +1048 -0
- flash/envs/__init__.py +10 -0
- flash/envs/adapter/__init__.py +883 -0
- flash/envs/adapter/rubric.py +222 -0
- flash/envs/base.py +52 -0
- flash/envs/registry.py +62 -0
- flash/mcp/__init__.py +1 -0
- flash/mcp/server.py +85 -0
- flash/providers/__init__.py +59 -0
- flash/providers/_auth.py +24 -0
- flash/providers/_http.py +230 -0
- flash/providers/_instance.py +416 -0
- flash/providers/_instance_bootstrap.py +517 -0
- flash/providers/_poll.py +311 -0
- flash/providers/allocator.py +193 -0
- flash/providers/base.py +431 -0
- flash/providers/hyperstack/__init__.py +127 -0
- flash/providers/hyperstack/api.py +522 -0
- flash/providers/hyperstack/auth.py +17 -0
- flash/providers/hyperstack/gpus.py +29 -0
- flash/providers/hyperstack/jobs/__init__.py +632 -0
- flash/providers/hyperstack/jobs/builders.py +122 -0
- flash/providers/hyperstack/preflight.py +23 -0
- flash/providers/hyperstack/pricing.py +26 -0
- flash/providers/hyperstack/train.py +25 -0
- flash/providers/lambdalabs/__init__.py +139 -0
- flash/providers/lambdalabs/api.py +261 -0
- flash/providers/lambdalabs/auth.py +18 -0
- flash/providers/lambdalabs/gpus.py +29 -0
- flash/providers/lambdalabs/jobs/__init__.py +724 -0
- flash/providers/lambdalabs/jobs/builders.py +118 -0
- flash/providers/lambdalabs/preflight.py +27 -0
- flash/providers/lambdalabs/pricing.py +51 -0
- flash/providers/lambdalabs/train.py +27 -0
- flash/providers/preflight.py +55 -0
- flash/providers/realized.py +80 -0
- flash/providers/runpod/__init__.py +130 -0
- flash/providers/runpod/api.py +186 -0
- flash/providers/runpod/auth.py +37 -0
- flash/providers/runpod/cost.py +57 -0
- flash/providers/runpod/gpus.py +46 -0
- flash/providers/runpod/jobs.py +956 -0
- flash/providers/runpod/keys.py +139 -0
- flash/providers/runpod/preflight.py +30 -0
- flash/providers/runpod/preload.py +915 -0
- flash/providers/runpod/pricing.py +18 -0
- flash/providers/runpod/slots.py +79 -0
- flash/providers/runpod/train/__init__.py +150 -0
- flash/providers/runpod/train/deps.py +395 -0
- flash/providers/runpod/train/endpoints.py +820 -0
- flash/py.typed +0 -0
- flash/runner/__init__.py +686 -0
- flash/runner/checkpoints.py +82 -0
- flash/runner/deploy.py +422 -0
- flash/runner/lifecycle.py +672 -0
- flash/schema/__init__.py +375 -0
- flash/schema/fields.py +331 -0
- flash/serve/__init__.py +1 -0
- flash/serve/deploy.py +326 -0
- flash/serve/pricing.py +60 -0
- flash/server/__init__.py +1 -0
- flash/server/__main__.py +20 -0
- flash/server/app.py +961 -0
- flash/server/auth.py +263 -0
- flash/server/billing.py +124 -0
- flash/server/checkpoints.py +110 -0
- flash/server/db.py +160 -0
- flash/server/environment_registry.py +102 -0
- flash/server/envs.py +360 -0
- flash/server/reconcile.py +163 -0
- flash/server/run_registry.py +150 -0
- flash/spec.py +333 -0
- freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
- freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
- freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
- freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
- freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,780 @@
|
|
|
1
|
+
"""Multi-turn / tool GRPO rollout for TRL's experimental ``rollout_func`` (colocate vLLM).
|
|
2
|
+
|
|
3
|
+
TRL's ``GRPOTrainer`` generates a single assistant turn per prompt, which cannot drive a
|
|
4
|
+
Freesolo ``EnvironmentMultiTurn`` turn loop (model turn -> env reply -> ...). This
|
|
5
|
+
module supplies a ``rollout_func`` that:
|
|
6
|
+
|
|
7
|
+
* drives the env's turn loop via the adapter helpers (``new_rollout_state`` /
|
|
8
|
+
``record_model_turn`` / ``env_reply`` / ``rollout_done``), so the *env* owns tool
|
|
9
|
+
execution, ``StatefulToolEnv`` state threading, and any simulated-user turns;
|
|
10
|
+
* returns the FULL interleaved token sequence as ``completion_ids`` together with an
|
|
11
|
+
``env_mask`` that marks model-generated tokens (``1``, trained) vs tool/env tokens
|
|
12
|
+
(``0``, masked out of the loss). ``env_mask`` is TRL's documented mechanism for
|
|
13
|
+
multi-turn credit assignment (it is treated internally as the tool mask), so only the
|
|
14
|
+
policy's own tokens get advantage while the env tokens still provide context for the
|
|
15
|
+
forward pass;
|
|
16
|
+
* scores each rollout with the environment reward (``reward_from_messages``) and returns
|
|
17
|
+
it as an extra field consumed by a pass-through ``reward_func``.
|
|
18
|
+
|
|
19
|
+
Token alignment is built by id-concatenation, never by re-rendering the conversation: the
running sequence is the rendered initial prompt, then each turn's generated ids followed by the
``env_glue`` ids that close that turn, render the env reply, and open the next generation
prompt. Because earlier turns are never re-tokenized, chat templates that do not round-trip
history (e.g. Qwen3's empty ``<think>`` block, injected into the generation prompt but stripped
from history) stay token-aligned. The remaining assumption is that ``env_glue`` returns exactly
the tokens the chat template places between two model turns; if it does not, the training
sequence diverges from what the template would render.
|
|
24
|
+
|
|
25
|
+
The core (:func:`rollout_one`) is pure Python and takes injected ``render``/``generate``
|
|
26
|
+
callables so it can be unit-tested without a GPU/tokenizer; :func:`build_rollout_func` wires
|
|
27
|
+
the real tokenizer + the colocate vLLM engine into it at runtime.
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
from __future__ import annotations
|
|
31
|
+
|
|
32
|
+
import contextlib
|
|
33
|
+
import json
|
|
34
|
+
import queue
|
|
35
|
+
import threading
|
|
36
|
+
from collections import OrderedDict
|
|
37
|
+
from collections.abc import Callable
|
|
38
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
39
|
+
from typing import TypedDict
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class RolloutResult(TypedDict):
    """Token-aligned fields returned per rollout for TRL's ``rollout_func``.

    ``completion_ids``, ``logprobs`` and ``env_mask`` are parallel lists with one entry per
    completion token: model-generated ids and env "glue" ids interleaved in sequence order.
    """

    prompt_ids: list[int]  # initial prompt rendered with the generation prompt opened
    completion_ids: list[int]  # every token after the prompt: model turns + env glue, in order
    logprobs: list[float]  # sampling logprob for model tokens; 0.0 placeholder for env glue
    env_mask: list[int]  # 1 = model-generated (trained), 0 = env/tool glue (masked from loss)
    reward: float  # scalar environment reward for the whole rollout
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# Field names shared between a single RolloutResult and the batched dict-of-lists that
|
|
53
|
+
# build_rollout_func returns. Kept as a plain tuple (not RolloutResult.__annotations__) so
|
|
54
|
+
# the batch accumulator's key source isn't a single-rollout type whose value types (float,
|
|
55
|
+
# list[int], ...) deliberately differ from the accumulator's list-of-those.
|
|
56
|
+
_ROLLOUT_FIELDS: tuple[str, ...] = (
|
|
57
|
+
"prompt_ids",
|
|
58
|
+
"completion_ids",
|
|
59
|
+
"logprobs",
|
|
60
|
+
"env_mask",
|
|
61
|
+
"reward",
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _prompt_key(prompt) -> str:
|
|
66
|
+
"""Stable key for mapping a dataset ``prompt`` value back to its example row."""
|
|
67
|
+
try:
|
|
68
|
+
return json.dumps(prompt, sort_keys=True, default=str)
|
|
69
|
+
except (TypeError, ValueError):
|
|
70
|
+
return str(prompt)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class _LRUCache:
|
|
74
|
+
"""Tiny bounded LRU cache (string key -> ``list[int]``) for the render / env_glue closures.
|
|
75
|
+
|
|
76
|
+
A plain ``len(d) < MAX`` guard FREEZES the cache once full: any new key after the cap is never
|
|
77
|
+
admitted, so later-repeated-but-diverse prompts/glue re-render forever and the cache stops paying
|
|
78
|
+
off over a long run. This evicts the least-recently-used entry on insert-when-full instead, so a
|
|
79
|
+
fixed-size window of recently-seen keys stays cached no matter how many distinct keys appear.
|
|
80
|
+
Recency is updated on every hit (``move_to_end``); not thread-safe (each cache is owned by a
|
|
81
|
+
single closure called from one thread).
|
|
82
|
+
"""
|
|
83
|
+
|
|
84
|
+
__slots__ = ("_data", "maxsize")
|
|
85
|
+
|
|
86
|
+
def __init__(self, maxsize: int):
|
|
87
|
+
if maxsize <= 0:
|
|
88
|
+
raise ValueError("LRU cache maxsize must be positive")
|
|
89
|
+
self.maxsize = maxsize
|
|
90
|
+
self._data: OrderedDict[str, list[int]] = OrderedDict()
|
|
91
|
+
|
|
92
|
+
def get(self, key: str) -> list[int] | None:
|
|
93
|
+
"""Return the cached value and mark it most-recently-used, or None on a miss."""
|
|
94
|
+
value = self._data.get(key)
|
|
95
|
+
if value is not None:
|
|
96
|
+
self._data.move_to_end(key)
|
|
97
|
+
return value
|
|
98
|
+
|
|
99
|
+
def put(self, key: str, value: list[int]) -> None:
|
|
100
|
+
"""Insert/refresh ``key`` as most-recently-used, evicting the oldest entry if at capacity."""
|
|
101
|
+
if key in self._data:
|
|
102
|
+
self._data.move_to_end(key)
|
|
103
|
+
self._data[key] = value
|
|
104
|
+
if len(self._data) > self.maxsize:
|
|
105
|
+
self._data.popitem(last=False) # drop the least-recently-used entry
|
|
106
|
+
|
|
107
|
+
def __len__(self) -> int:
|
|
108
|
+
return len(self._data)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def build_examples_index(rows: list[dict], prompt_of: Callable[[dict], object]) -> dict:
    """Index example rows by the stable key of their rendered ``prompt`` (for reward/answer lookup).

    When several rows produce the same prompt key, the LAST such row wins. Callers report the
    number of shadowed rows via :func:`index_collisions`. Duplicate prompts only change which
    ``answer``/``info`` a shared prompt is scored against.
    """
    index: dict = {}
    for row in rows:
        index[_prompt_key(prompt_of(row))] = row  # later duplicates overwrite earlier rows
    return index
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def index_collisions(rows: list[dict], prompt_of: Callable[[dict], object]) -> int:
    """Count the rows that :func:`build_examples_index` drops because their prompt key repeats."""
    total = len(rows)
    distinct = {_prompt_key(prompt_of(row)) for row in rows}
    return total - len(distinct)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def rollout_one(
    *,
    example: dict,
    active_env,
    render: Callable[[list, bool], list[int]],
    generate: Callable[[list, int], tuple[list[int], list[float], str]],
    env_glue: Callable[[list], list[int]],
    max_turns: int,
    per_turn_max_tokens: int,
    engine_max_len: int | None = None,
) -> RolloutResult:
    """Run one multi-turn/tool rollout and return TRL ``rollout_func`` fields for it.

    Shares its implementation with :func:`rollout_async`, so the single-rollout and
    continuous-batched paths cannot drift apart:

    * the initial state comes from :func:`_build_rollout_states`;
    * each turn's token cap comes from :func:`_turn_budget`;
    * each turn's fold + env step is done by :func:`_advance_after_turn`.

    Args:
        example: the dataset row carried into environment scoring.
        active_env: the Freesolo environment adapter (drives the turn loop + scoring).
        render: ``render(messages, add_generation_prompt) -> token_ids`` (chat template) — used
            only for the INITIAL prompt.
        generate: ``generate(prefix_token_ids, max_tokens) -> (token_ids, token_logprobs, text)``
            for one sampled assistant turn (model tokens + sampling logprobs + text).
            ``max_tokens`` bounds that turn so it can't overflow the engine context.
        env_glue: ``env_glue(env_messages) -> token_ids`` — the tokens that:
            CLOSE the just-finished assistant turn, render the env reply message(s), and
            OPEN the next generation prompt.
            The running token sequence is built incrementally from these (model ids + env
            glue), never by re-rendering the whole conversation. A chat template that does not
            round-trip prior turns therefore stays token-aligned (e.g. Qwen3's empty
            ``<think>`` block, injected into the generation prompt but stripped from history).
        max_turns: hard cap on model turns (defense against a non-terminating env).
        per_turn_max_tokens: cap on new tokens for a single assistant turn. It is clamped to at
            least 1, and to the remaining engine headroom when ``engine_max_len`` is set.
        engine_max_len: the generation engine's max context length. When truthy, prompt +
            completion is kept at most ``engine_max_len - 8`` tokens. ``None``/``0`` disables
            the cap.

    Returns:
        A dict with ``prompt_ids``, ``completion_ids``, ``logprobs``, ``env_mask`` (all
        token-aligned) and the scalar ``reward`` for this rollout.

    Raises:
        KeyError: the env's rollout state has neither a ``prompt`` nor a ``messages`` list.
    """
    # Same initial state (prompt render + engine budget) that the batched path builds.
    (r,) = _build_rollout_states([example], active_env, render, engine_max_len)
    while not r.done:
        max_new = _turn_budget(r, per_turn_max_tokens)
        if max_new is None:  # engine headroom exhausted (_turn_budget already set r.done)
            break
        asst_ids, asst_lp, text = generate(r.cur_ids, max_new)
        # Folds the model turn (env_mask 1) and runs the env step + glue (env_mask 0), or sets
        # r.done when the budget, the turn cap or the env ends the episode.
        _advance_after_turn(
            r,
            asst_ids,
            asst_lp,
            text,
            active_env=active_env,
            env_glue=env_glue,
            max_turns=max_turns,
        )

    # Score with the ACTUAL rollout state (not a fresh one) so reward funcs see the tool/env
    # state the rollout accumulated.
    return r.result(active_env.reward("", example, r.state))
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
class _RolloutState:
|
|
239
|
+
"""Mutable per-rollout accumulator for the continuous-batched rollout (:func:`rollout_async`).
|
|
240
|
+
|
|
241
|
+
Holds exactly the running fields :func:`rollout_one` keeps in locals, so the two paths produce
|
|
242
|
+
byte-identical token alignment / env_mask / reward — the only difference is that the async path
|
|
243
|
+
advances rollouts' turns as independent, continuously-batched engine requests.
|
|
244
|
+
"""
|
|
245
|
+
|
|
246
|
+
__slots__ = (
|
|
247
|
+
"budget",
|
|
248
|
+
"completion_ids",
|
|
249
|
+
"cur_ids",
|
|
250
|
+
"done",
|
|
251
|
+
"env_mask",
|
|
252
|
+
"example",
|
|
253
|
+
"logprobs",
|
|
254
|
+
"messages",
|
|
255
|
+
"prompt_ids",
|
|
256
|
+
"state",
|
|
257
|
+
"turns",
|
|
258
|
+
)
|
|
259
|
+
|
|
260
|
+
def __init__(self, example, messages, prompt_ids, state, budget):
|
|
261
|
+
self.example = example
|
|
262
|
+
self.messages = messages
|
|
263
|
+
self.prompt_ids = prompt_ids
|
|
264
|
+
self.cur_ids = list(prompt_ids) # invariant: cur_ids == prompt_ids + completion_ids so far
|
|
265
|
+
self.completion_ids: list[int] = []
|
|
266
|
+
self.logprobs: list[float] = []
|
|
267
|
+
self.env_mask: list[int] = []
|
|
268
|
+
self.state = state
|
|
269
|
+
self.turns = 0
|
|
270
|
+
self.budget = budget # max completion tokens (engine headroom), or None
|
|
271
|
+
self.done = False
|
|
272
|
+
|
|
273
|
+
def result(self, reward: float) -> RolloutResult:
|
|
274
|
+
return {
|
|
275
|
+
"prompt_ids": self.prompt_ids,
|
|
276
|
+
"completion_ids": self.completion_ids,
|
|
277
|
+
"logprobs": self.logprobs,
|
|
278
|
+
"env_mask": self.env_mask,
|
|
279
|
+
"reward": float(reward),
|
|
280
|
+
}
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def _advance_after_turn(
|
|
284
|
+
r: _RolloutState,
|
|
285
|
+
asst_ids: list[int],
|
|
286
|
+
asst_lp: list[float],
|
|
287
|
+
text: str,
|
|
288
|
+
*,
|
|
289
|
+
active_env,
|
|
290
|
+
env_glue: Callable[[list], list[int]],
|
|
291
|
+
max_turns: int,
|
|
292
|
+
) -> None:
|
|
293
|
+
"""Fold one freshly-sampled assistant turn into rollout ``r`` and run its env step, mirroring the
|
|
294
|
+
body of :func:`rollout_one`'s loop EXACTLY. Sets ``r.done`` when the rollout should stop. Used by
|
|
295
|
+
:func:`rollout_async` so the continuous-batched and single-rollout paths can never drift."""
|
|
296
|
+
r.completion_ids.extend(asst_ids)
|
|
297
|
+
r.logprobs.extend(asst_lp)
|
|
298
|
+
r.env_mask.extend([1] * len(asst_ids))
|
|
299
|
+
r.cur_ids.extend(asst_ids)
|
|
300
|
+
active_env.record_model_turn(r.state, text)
|
|
301
|
+
r.messages.append({"role": "assistant", "content": text})
|
|
302
|
+
r.turns += 1
|
|
303
|
+
if r.budget is not None and len(r.completion_ids) >= r.budget:
|
|
304
|
+
r.done = True
|
|
305
|
+
return
|
|
306
|
+
if r.turns >= max_turns or active_env.rollout_done(r.state, max_turns):
|
|
307
|
+
r.done = True
|
|
308
|
+
return
|
|
309
|
+
env_msgs = active_env.env_reply(r.messages, r.state)
|
|
310
|
+
if not env_msgs:
|
|
311
|
+
r.done = True
|
|
312
|
+
return
|
|
313
|
+
r.messages.extend(env_msgs)
|
|
314
|
+
if active_env.rollout_done(r.state, max_turns):
|
|
315
|
+
r.done = True
|
|
316
|
+
return
|
|
317
|
+
glue = env_glue(env_msgs)
|
|
318
|
+
if r.budget is not None and len(r.completion_ids) + len(glue) > r.budget:
|
|
319
|
+
r.done = True
|
|
320
|
+
return
|
|
321
|
+
r.completion_ids.extend(glue)
|
|
322
|
+
r.logprobs.extend([0.0] * len(glue))
|
|
323
|
+
r.env_mask.extend([0] * len(glue))
|
|
324
|
+
r.cur_ids.extend(glue)
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def _build_rollout_states(
|
|
328
|
+
examples: list[dict],
|
|
329
|
+
active_env,
|
|
330
|
+
render: Callable[[list, bool], list[int]],
|
|
331
|
+
engine_max_len: int | None,
|
|
332
|
+
) -> list[_RolloutState]:
|
|
333
|
+
"""Initialise one :class:`_RolloutState` per example (initial prompt rendered, engine budget
|
|
334
|
+
computed) for :func:`rollout_async`, starting from the same state :func:`rollout_one` builds
|
|
335
|
+
inline so the two paths stay byte-identical."""
|
|
336
|
+
rollouts: list[_RolloutState] = []
|
|
337
|
+
for example in examples:
|
|
338
|
+
state = active_env.new_rollout_state(example)
|
|
339
|
+
initial_messages = state.get("prompt") or state.get("messages")
|
|
340
|
+
if not isinstance(initial_messages, list):
|
|
341
|
+
raise KeyError("multi-turn rollout state must include prompt or messages")
|
|
342
|
+
messages = [dict(m) for m in initial_messages]
|
|
343
|
+
prompt_ids = render(messages, True)
|
|
344
|
+
budget = (engine_max_len - len(prompt_ids) - 8) if engine_max_len else None
|
|
345
|
+
rollouts.append(_RolloutState(example, messages, prompt_ids, state, budget))
|
|
346
|
+
return rollouts
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
def _turn_budget(r: _RolloutState, per_turn_max_tokens: int) -> int | None:
|
|
350
|
+
"""Max new tokens for ``r``'s next assistant turn, bounded by the remaining engine headroom so
|
|
351
|
+
prompt+completion can't overflow the context. Returns ``None`` and marks ``r.done`` when the
|
|
352
|
+
headroom is already exhausted. Identical cap for both rollout paths (no drift)."""
|
|
353
|
+
max_new = per_turn_max_tokens
|
|
354
|
+
if r.budget is not None:
|
|
355
|
+
remaining = r.budget - len(r.completion_ids)
|
|
356
|
+
if remaining <= 0: # prompt already fills the context -> this rollout is done
|
|
357
|
+
r.done = True
|
|
358
|
+
return None
|
|
359
|
+
max_new = min(max_new, remaining)
|
|
360
|
+
return max(1, max_new)
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def _score_rollouts(active_env, rollouts: list[_RolloutState]) -> list[float]:
|
|
364
|
+
"""Reward for each rollout, in order. Uses ``active_env.reward_many`` when the env provides it
|
|
365
|
+
(one batched, env-concurrent scoring call per task instead of a blocking call per rollout — the
|
|
366
|
+
win for judge/expensive-reward envs). Otherwise falls back to per-rollout ``active_env.reward()``,
|
|
367
|
+
run CONCURRENTLY in a thread pool when the env declares its reward thread-safe (PR #224) so an
|
|
368
|
+
IO-bound judge/tool reward still overlaps instead of N serial GPU-idle round-trips. Every path
|
|
369
|
+
reassembles in INPUT ORDER and yields identical values — only scoring concurrency differs."""
|
|
370
|
+
reward_many = getattr(active_env, "reward_many", None)
|
|
371
|
+
if callable(reward_many):
|
|
372
|
+
rewards = reward_many([(r.example, r.state) for r in rollouts])
|
|
373
|
+
if len(rewards) != len(rollouts):
|
|
374
|
+
raise RuntimeError("env.reward_many returned the wrong number of rewards")
|
|
375
|
+
return [float(x) for x in rewards]
|
|
376
|
+
|
|
377
|
+
def _score(r: _RolloutState) -> float:
|
|
378
|
+
return float(active_env.reward("", r.example, r.state))
|
|
379
|
+
|
|
380
|
+
# Serial for a single rollout, or when the env declares its reward NOT thread-safe (a scorer that
|
|
381
|
+
# keeps mutable state or a thread-bound client) — it worked serially and must not be raced.
|
|
382
|
+
if len(rollouts) <= 1 or not getattr(active_env, "reward_thread_safe", True):
|
|
383
|
+
return [_score(r) for r in rollouts]
|
|
384
|
+
# Concurrent. On the first reward error, cancel not-yet-started scorers and drain in-flight ones
|
|
385
|
+
# so a failed step spends no further judge/API calls and leaves no scorer running into the next.
|
|
386
|
+
pool = ThreadPoolExecutor(max_workers=min(16, len(rollouts)))
|
|
387
|
+
try:
|
|
388
|
+
futures = {pool.submit(_score, r): i for i, r in enumerate(rollouts)}
|
|
389
|
+
scores: list[float] = [0.0] * len(rollouts)
|
|
390
|
+
for fut in as_completed(futures):
|
|
391
|
+
scores[futures[fut]] = fut.result() # re-raises the first failed scorer
|
|
392
|
+
finally:
|
|
393
|
+
pool.shutdown(wait=True, cancel_futures=True)
|
|
394
|
+
return scores
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def rollout_async(
    *,
    examples: list[dict],
    active_env,
    render: Callable[[list, bool], list[int]],
    submit: Callable[[str, list[int], int, bool], None],
    poll: Callable[[], list[tuple[str, list[int], list[float], str]]],
    busy: Callable[[], bool],
    env_glue: Callable[[list], list[int]],
    max_turns: int,
    per_turn_max_tokens: int,
    engine_max_len: int | None = None,
) -> list[RolloutResult]:
    """Run ``len(examples)`` multi-turn rollouts with CONTINUOUS-BATCHED generation (no turn barrier).

    Same result as one :func:`rollout_one` per example — identical token alignment, env_mask,
    per-rollout reward and input order — but rollouts are NOT advanced in lockstep.

    Each rollout's assistant turn is an independent engine request. The moment one finishes,
    its env step runs and its NEXT turn is submitted, so the decode batch stays full. A
    turn-synchronized rollout would instead stall at each turn boundary while the slowest
    rollout's turn (and then every rollout's env reply) completes. For high-variance multi-turn
    workloads (rollouts of very different depths) this keeps the GPU busy across those
    boundaries.

    The work is split across two threads, so the per-turn ENV work (env reply + glue render —
    the overhead that bounds an otherwise GPU-light rollout) overlaps the GPU decode instead of
    blocking it:

    * the MAIN thread owns the engine (submit / poll / busy);
    * a single WORKER thread owns the env (``_advance_after_turn``).

    vLLM's ``step()`` runs the model forward in CUDA with the GIL released, so the worker
    advances finished turns DURING the decode of the still-running ones. The two threads share
    no mutable state: they communicate only through thread-safe queues, and each next-turn
    prefix is handed over as a copy. The env (not thread-safe) is therefore touched by exactly
    one thread, and the engine by exactly one thread.

    Results are byte-identical to one :func:`rollout_one` per example, in input order. A
    rollout keeps at most one request in flight, so its turns stay strictly sequential.

    The engine is injected as three callables so the loop is unit-testable on CPU:

    * ``submit(req_id, prefix_ids, max_tokens, initial)`` — enqueue one assistant-turn request.
      ``initial`` marks a turn-0 prompt, the only externally-rendered ids worth bounds-checking.
    * ``poll()`` — return ``(req_id, token_ids, logprobs, text)`` for every request that
      FINISHED since the last call (``[]`` if none finished this step).
    * ``busy()`` — whether any request is still in flight.

    Raises:
        KeyError: an env rollout state has neither a ``prompt`` nor a ``messages`` list.
        Any exception raised on the worker thread during an env step is re-raised on the
        calling (main) thread, as is any exception from ``submit`` / ``poll`` / ``busy`` or
        from scoring.

    NOTE(review): on an error, requests already submitted to the engine are not aborted here,
    and request ids restart at ``"r0"`` on every call. If ``submit``/``poll`` do not namespace
    or drain leftover requests, a later call could receive a stale result under a reused id —
    confirm the caller cleans the engine up after a failure.
    """
    rollouts = _build_rollout_states(examples, active_env, render, engine_max_len)
    by_id: dict[str, _RolloutState] = {}  # in-flight request id -> the rollout it belongs to
    counter = 0  # next request-id suffix (main thread only)
    to_env: queue.Queue = queue.Queue()  # main -> worker: finished turns to fold + run the env step
    to_submit: queue.Queue = queue.Queue()  # worker -> main: ("next", r, prefix, max_new) | ("done", r)

    def do_submit(r: _RolloutState, prefix: list[int], max_new: int, initial: bool) -> None:
        # Main thread only: register the rollout under a fresh id, then hand it to the engine.
        nonlocal counter
        req_id = f"r{counter}"
        counter += 1
        by_id[req_id] = r
        submit(req_id, prefix, max_new, initial)

    def env_worker() -> None:
        # Owns the env: fold each finished turn, run its env step (env reply + glue render), and
        # hand the next-turn prefix (a copy) back to the main thread — or signal the rollout is
        # done. An env/template error here must propagate to the main thread (which owns the
        # engine), not die silently in this thread and hang the main loop waiting for a result
        # that never comes.
        while True:
            item = to_env.get()
            if item is None:  # shutdown sentinel posted by the main thread's finally
                return
            r, asst_ids, asst_lp, text = item
            try:
                _advance_after_turn(
                    r, asst_ids, asst_lp, text,
                    active_env=active_env, env_glue=env_glue, max_turns=max_turns,
                )
                max_new = None if r.done else _turn_budget(r, per_turn_max_tokens)
            except Exception as exc:  # surfaced + re-raised on the main thread (engine owner)
                to_submit.put(("error", exc))
                return
            # list(r.cur_ids): the main thread gets its own copy of the prefix to submit.
            to_submit.put(("done", r) if max_new is None else ("next", r, list(r.cur_ids), max_new))

    worker = threading.Thread(target=env_worker, daemon=True)
    worker.start()
    n = len(rollouts)
    completed = 0  # rollouts that are finished (counted on the main thread only)

    def take(msg) -> None:
        # Apply one worker message on the main thread: re-raise, count a finish, or submit.
        nonlocal completed
        if msg[0] == "error":
            # Re-raise the worker's env/template error on the main thread (the engine owner),
            # preserving the ORIGINAL worker traceback so the stack points at the real failing line.
            err = msg[1]
            raise err.with_traceback(err.__traceback__)
        if msg[0] == "done":
            completed += 1
        else:
            _, r, prefix, max_new = msg
            do_submit(r, prefix, max_new, False)

    try:
        for r in rollouts:  # prime turn 0 on the main thread
            max_new = _turn_budget(r, per_turn_max_tokens)
            if max_new is None:  # prompt alone exhausts the engine budget
                completed += 1
            else:
                do_submit(r, list(r.cur_ids), max_new, r.turns == 0)
        while completed < n:
            progressed = False
            while True:  # submit every next-turn the worker has produced (and count finished ones)
                try:
                    take(to_submit.get_nowait())
                    progressed = True
                except queue.Empty:
                    break
            if completed >= n:
                break
            if busy():  # step the engine; hand finished turns to the worker (overlaps its env work)
                for req_id, asst_ids, asst_lp, text in poll():
                    to_env.put((by_id.pop(req_id), asst_ids, asst_lp, text))
            elif not progressed:
                # nothing in flight and nothing newly ready: the worker is mid-advance — block on its
                # next output instead of spinning (every rollout is in exactly one stage, so this
                # can't deadlock: the only state with all queues + in-flight empty is all-done).
                # NOTE(review): if the worker thread died from a BaseException that is not an
                # Exception, it posts nothing and this loop keeps waiting in 0.1 s slices forever;
                # consider also checking worker.is_alive() here.
                with contextlib.suppress(queue.Empty):
                    take(to_submit.get(timeout=0.1))
    finally:
        # Always stop the worker (it drains any already-queued turns first), even on error.
        to_env.put(None)
        worker.join()

    # Score with the ACTUAL accumulated rollout state (matches rollout_one), batched per task.
    rewards = _score_rollouts(active_env, rollouts)
    return [r.result(rw) for r, rw in zip(rollouts, rewards, strict=True)]
|
|
522
|
+
|
|
523
|
+
|
|
524
|
+
def render_message_ids(tok, messages, add_generation_prompt: bool, *, thinking: bool) -> list[int]:
    """Turn a chat ``messages`` list into a flat ``list[int]`` of token ids.

    Two steps on purpose: the chat template is applied to produce TEXT
    (``tokenize=False``), and that text is then encoded with ``tok(...)``. The output shape of
    ``apply_chat_template(tokenize=True)`` differs between tokenizers, while
    ``tok(text).input_ids`` is reliably a flat sequence. This mirrors the single-turn
    ``render_prompt`` path.

    ``add_special_tokens=False`` is passed because the template already emits the special
    tokens. Both the GRPO rollout closure and mid-run eval use this helper, so they share one
    token alignment.

    Args:
        tok: tokenizer exposing ``apply_chat_template`` and ``__call__``.
        messages: chat messages handed to the template unchanged.
        add_generation_prompt: whether the template should open a new assistant turn.
        thinking: forwarded to the template as ``enable_thinking``.

    Returns:
        Token ids, each coerced to a Python ``int``.
    """
    template_kwargs = {
        "add_generation_prompt": add_generation_prompt,
        "tokenize": False,
        "enable_thinking": thinking,
    }
    rendered = tok.apply_chat_template(messages, **template_kwargs)
    encoding = tok(rendered, add_special_tokens=False)
    return list(map(int, encoding.input_ids))
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
def _engine_vocab_size(engine) -> int | None:
|
|
543
|
+
"""Best-effort vocab size of the colocate vLLM engine, or None if it can't be read.
|
|
544
|
+
|
|
545
|
+
Used only for a cheap fail-loud bounds check on the pre-tokenized prompt ids before they
|
|
546
|
+
reach ``engine.generate`` (the ``prompt_token_ids`` path does no bounds checking, so an
|
|
547
|
+
out-of-range id would otherwise surface as an opaque CUDA illegal-access). Never raises.
|
|
548
|
+
"""
|
|
549
|
+
try:
|
|
550
|
+
mc = engine.llm_engine.model_config
|
|
551
|
+
except Exception:
|
|
552
|
+
return None
|
|
553
|
+
for attr in ("get_vocab_size", "get_hf_config_vocab_size"):
|
|
554
|
+
getter = getattr(mc, attr, None)
|
|
555
|
+
if callable(getter):
|
|
556
|
+
try:
|
|
557
|
+
return int(getter())
|
|
558
|
+
except Exception:
|
|
559
|
+
pass
|
|
560
|
+
try:
|
|
561
|
+
return int(mc.hf_text_config.vocab_size)
|
|
562
|
+
except Exception:
|
|
563
|
+
return None
|
|
564
|
+
|
|
565
|
+
|
|
566
|
+
def build_rollout_func(
    *,
    active_env,
    tok,
    examples_by_key: dict,
    max_completion: int,
    max_turns: int,
    temperature: float,
    top_p: float,
    stop: list[str] | None,
    thinking: bool,
    engine_max_len: int | None = None,
):
    """Return a TRL ``rollout_func`` closure that drives ``active_env`` on the colocate engine.

    The closure reaches the in-process vLLM engine through ``trainer.vllm_generation.llm`` and
    samples each assistant turn with per-token logprobs. It returns exactly ONE rollout per
    prompt in the slice TRL passes: TRL's ``RepeatSampler`` already repeats each unique prompt
    ``num_generations`` times before calling ``rollout_func`` (the consecutive repeats form the
    GRPO group), so the closure must NOT multiply by ``num_generations`` again.

    Args:
        active_env: environment handed to ``rollout_async``.
        tok: tokenizer used for chat-template rendering and tokenization.
        examples_by_key: dataset rows keyed by ``_prompt_key(prompt)``. A prompt with no entry
            falls back to ``{"prompt": prompt}``.
        max_completion: per-turn token budget, forwarded as ``per_turn_max_tokens``.
        max_turns: forwarded to ``rollout_async`` as ``max_turns``.
        temperature: vLLM sampling setting applied to every assistant turn.
        top_p: vLLM sampling setting applied to every assistant turn.
        stop: vLLM stop strings applied to every assistant turn (``None``/empty = no stops).
        thinking: forwarded to the chat template as ``enable_thinking``.
        engine_max_len: forwarded unchanged to ``rollout_async``.

    Returns:
        ``rollout_func(prompts, trainer)``, which returns a ``dict`` mapping each
        ``_ROLLOUT_FIELDS`` key to a list with one entry per prompt, in input order.
    """
    from vllm import SamplingParams  # gpu-only; imported lazily so the module loads on CPU

    try:
        # FINAL_ONLY makes each manual add_request emit exactly one RequestOutput, at finish, with
        # the complete turn (matching LLM.generate); without it the engine streams a cumulative
        # output every step. Optional so the CPU import (stubbed vllm) still works — poll() filters
        # on `finished` either way.
        from vllm.sampling_params import RequestOutputKind

        _final_only_kind = RequestOutputKind.FINAL_ONLY
    except Exception:
        _final_only_kind = None

    _render_cache = _LRUCache(8192)

    def render(messages: list, add_generation_prompt: bool) -> list[int]:
        # The initial-prompt render is identical for every rollout in a GRPO group (they share one
        # prompt), so cache it by content instead of re-rendering num_generations times per step.
        # LRU-bounded: when full it EVICTS the least-recently-used entry rather than freezing, so a
        # long run with many distinct prompts keeps caching the recently-seen ones (a freeze-when-full
        # cache would stop admitting any new prompt after the cap and re-render them forever).
        cache_key = f"{add_generation_prompt}\x00{json.dumps(messages, sort_keys=True, default=str)}"
        cached = _render_cache.get(cache_key)
        if cached is None:
            cached = render_message_ids(tok, messages, add_generation_prompt, thinking=thinking)
            _render_cache.put(cache_key, cached)
        # Hand out a COPY. The cached list is shared by every rollout of the group and by later
        # steps, so returning the cached object would let one caller's in-place edit (e.g.
        # extending its token prefix) corrupt the cache and every sibling rollout. The copy is
        # O(len), the same order as the key's json.dumps above.
        return list(cached)

    _glue_cache = _LRUCache(8192)

    def env_glue(env_messages: list) -> list[int]:
        # The inter-turn glue is a pure function of env_messages (+ this closure's tokenizer /
        # thinking). Within a GRPO group every rollout gets the SAME env reply each turn, and many
        # turns repeat env messages across rollouts and steps, so apply_chat_template would
        # otherwise re-render byte-identical glue dozens-to-hundreds of times — the dominant per-turn
        # CPU cost in the (otherwise overhead-bound) multi-turn rollout. Cache by env-message
        # content; LRU-bounded so an env whose every reply is unique can't grow it without limit and,
        # unlike a freeze-when-full cache, recently-seen glue stays cached over a long diverse run.
        cache_key = json.dumps(env_messages, sort_keys=True, default=str)
        cached = _glue_cache.get(cache_key)
        if cached is not None:
            # Copy for the same aliasing reason as render(): callers must never hold the cached list.
            return list(cached)
        # Tokens between two assistant turns: close the previous assistant turn, render the env
        # reply message(s), and open the next generation prompt. Derived by rendering a probe
        # assistant turn followed by the env messages (+ generation prompt) and taking everything
        # AFTER the probe content — so the glue is exactly the template's inter-turn wrapper,
        # whatever it is (Qwen's <|im_end|> + user turn + <|im_start|>assistant + <think> block).
        # This avoids re-rendering history (which Qwen3 does not round-trip) and matches how the
        # model actually conditioned during generation. The probe is plain text the template
        # should insert verbatim into assistant content, so it must appear exactly once.
        probe = "flash-env-glue-probe"
        text = tok.apply_chat_template(
            [{"role": "assistant", "content": probe}, *env_messages],
            add_generation_prompt=True,
            tokenize=False,
            enable_thinking=thinking,
        )
        # Locate the probe to slice off the inter-turn glue. Fail LOUD with context if the
        # template did not insert the assistant content verbatim (some templates strip/escape it,
        # or could emit the probe more than once) instead of a bare "substring not found".
        first = text.find(probe)
        if first == -1 or text.find(probe, first + len(probe)) != -1:
            raise ValueError(
                "multi-turn env_glue could not uniquely locate its probe in the rendered chat "
                "template; this model's template does not insert assistant content verbatim, so "
                "token-aligned multi-turn rollout is unsupported for it (use a single-turn/tool "
                "env or a different model)."
            )
        glue_text = text[first + len(probe) :]
        glue = [int(t) for t in tok(glue_text, add_special_tokens=False).input_ids]
        _glue_cache.put(cache_key, glue)
        return list(glue)

    def rollout_func(prompts, trainer):
        engine = trainer.vllm_generation.llm
        # The colocate engine is a vLLM `LLM`; its V1 `LLMEngine` exposes the public
        # add_request / step / has_unfinished_requests loop that lets us decode many rollouts'
        # turns CONTINUOUSLY (a finished turn's slot refills with another rollout's next turn)
        # instead of one synchronized batched decode per turn.
        llm_engine = engine.llm_engine
        # Colocate vLLM sleep mode (GRPOConfig.vllm_enable_sleep_mode, ON for large / long-context
        # runs) offloads BOTH the rollout engine's weights and its KV cache between steps. TRL's
        # rollout_func path (GRPOTrainer._generate) calls vllm_generation.sync_weights() — which
        # wakes only tags=["weights"] — and then hands control to this closure, but, UNLIKE TRL's
        # own single-turn generate() path, it never wakes tags=["kv_cache"]. So the first decode
        # below would run against a freed/offloaded KV cache and fault with CUDA "illegal memory
        # access" on step 0. Wake the KV cache here and re-sleep after the whole batch, mirroring
        # trl.generation.vllm_generation.generate (and trl.experimental.openenv). No-op when sleep
        # mode is off (small/short-context runs keep the engine resident). See flash issue #162.
        sleep_mode = bool(getattr(getattr(trainer, "args", None), "vllm_enable_sleep_mode", False))
        vocab_size = _engine_vocab_size(engine)
        active_ids: set[str] = set()  # submitted-but-not-finished requests, for abort-on-exit

        def submit(req_id: str, prefix_ids: list[int], max_tokens: int, initial: bool) -> None:
            """Enqueue one assistant-turn request on the colocate engine.

            Raises ValueError on an empty prefix, or (for ``initial`` turn-0 prefixes) on token
            ids that are negative or >= the engine vocab size.
            """
            if not prefix_ids:
                # Fail loudly on a degenerate prompt instead of letting it reach the embedding gather
                # as an opaque async CUDA illegal-access (the failure mode #162 was first mistaken
                # for): the prompt_token_ids path does no bounds checking.
                raise ValueError("multi-turn rollout produced an empty prompt for engine.add_request()")
            if initial:
                # Turn-0 prefixes are the only externally-rendered initial prompts (later turns are
                # vLLM-generated / tokenizer glue, already in range); validate each, since the
                # prompt_token_ids path does no bounds checking and an out-of-range id would surface
                # as an opaque CUDA illegal-access.
                lo, hi = min(prefix_ids), max(prefix_ids)
                if lo < 0 or (vocab_size is not None and hi >= vocab_size):
                    raise ValueError(
                        f"multi-turn rollout prompt has out-of-range token id(s) [{lo}, {hi}] for "
                        f"vocab size {vocab_size} (tokenizer/model mismatch)"
                    )
            sp_kwargs = {
                "max_tokens": max(1, int(max_tokens)),
                "temperature": temperature,
                "top_p": top_p,
                "logprobs": 1,  # include the sampled token's logprob at each position
                "stop": list(stop) if stop else None,
            }
            if _final_only_kind is not None:
                sp_kwargs["output_kind"] = _final_only_kind
            llm_engine.add_request(
                req_id, {"prompt_token_ids": list(prefix_ids)}, SamplingParams(**sp_kwargs)
            )
            active_ids.add(req_id)

        def poll() -> list[tuple[str, list[int], list[float], str]]:
            """Advance the engine one step; return (req_id, token_ids, logprobs, text) for every
            request that finished this step (``[]`` if none did / a dummy batch ran)."""
            finished: list[tuple[str, list[int], list[float], str]] = []
            for out in llm_engine.step():
                if not getattr(out, "finished", False):
                    continue
                comp = out.outputs[0]
                token_ids = list(comp.token_ids)
                # comp.logprobs is a list (per position) of {token_id: Logprob}; pull the sampled
                # token's logprob at each position. Read once per output. When logprobs are absent
                # (None/empty), every position falls back to 0.0.
                lp_rows = comp.logprobs
                lps: list[float] = []
                for pos, tid in enumerate(token_ids):
                    entry = lp_rows[pos] if lp_rows else None
                    lp = entry.get(tid) if entry else None
                    lps.append(float(getattr(lp, "logprob", 0.0)) if lp is not None else 0.0)
                active_ids.discard(out.request_id)
                finished.append((out.request_id, token_ids, lps, comp.text))
            return finished

        def busy() -> bool:
            """True while the engine still has submitted requests that have not finished."""
            return bool(llm_engine.has_unfinished_requests())

        # Wake the KV cache for the whole batch (see the sleep-mode note above), then re-sleep so
        # the engine returns to its fully-offloaded state and the optimizer step has the freed
        # memory back. `woke` is set AFTER a successful wake so the finally re-sleeps ONLY when we
        # actually woke the engine — a wake_up() that raises leaves the engine asleep (its resting
        # state), and we must not then call sleep() on it; a failure DURING the rollout still
        # re-sleeps.
        woke = False
        try:
            if sleep_mode:
                engine.wake_up(tags=["kv_cache"])
                woke = True
            # ONE rollout per prompt: TRL's RepeatSampler already repeats each unique prompt
            # num_generations times BEFORE handing the slice to rollout_func (trl 1.6/1.7:
            # `prompts = [x["prompt"] for x in inputs]`, no dedup), and it expects exactly
            # len(prompts) completions back — the GRPO group is the consecutive num_generations rows
            # of the same prompt. rollout_async returns one result per example in input order, so
            # the group stays aligned.
            examples = [examples_by_key.get(_prompt_key(p), {"prompt": p}) for p in prompts]
            rollouts = rollout_async(
                examples=examples,
                active_env=active_env,
                render=render,
                submit=submit,
                poll=poll,
                busy=busy,
                env_glue=env_glue,
                max_turns=max_turns,
                per_turn_max_tokens=max_completion,
                engine_max_len=engine_max_len,
            )
            out: dict[str, list] = {k: [] for k in _ROLLOUT_FIELDS}
            for r in rollouts:
                for k in out:
                    out[k].append(r[k])
            return out
        finally:
            # Abort any still-in-flight requests so a mid-rollout error (e.g. an env_glue/template
            # failure on a later turn) can't leak live requests into the engine and corrupt the
            # next GRPO step. No-op on the success path (every request finished -> active_ids empty).
            if active_ids:
                with contextlib.suppress(Exception):
                    llm_engine.abort_request(list(active_ids))
            if woke:
                engine.sleep(level=2)

    return rollout_func
|