torchrl-nightly 2025.6.19__cp313-cp313-macosx_10_13_universal2.whl → 2025.6.21__cp313-cp313-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrl/_torchrl.cpython-313-darwin.so +0 -0
- torchrl/collectors/collectors.py +49 -24
- torchrl/collectors/llm/base.py +13 -6
- torchrl/collectors/llm/ray_collector.py +3 -0
- torchrl/data/__init__.py +2 -0
- torchrl/data/datasets/minari_data.py +1 -1
- torchrl/data/llm/__init__.py +2 -0
- torchrl/data/llm/chat.py +59 -9
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/replay_buffers/ray_buffer.py +15 -1
- torchrl/data/replay_buffers/replay_buffers.py +50 -11
- torchrl/data/replay_buffers/samplers.py +98 -21
- torchrl/data/replay_buffers/storages.py +29 -2
- torchrl/envs/llm/__init__.py +2 -0
- torchrl/envs/llm/chat.py +4 -1
- torchrl/envs/llm/reward/gsm8k.py +15 -8
- torchrl/envs/llm/transforms/__init__.py +2 -1
- torchrl/envs/llm/transforms/kl.py +240 -4
- torchrl/envs/transforms/transforms.py +11 -27
- torchrl/modules/llm/policies/transformers_wrapper.py +71 -15
- torchrl/modules/llm/policies/vllm_wrapper.py +38 -5
- torchrl/objectives/llm/__init__.py +2 -1
- torchrl/objectives/llm/sft.py +465 -0
- torchrl/objectives/ppo.py +35 -12
- torchrl/version.py +2 -2
- {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/METADATA +1 -1
- {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/RECORD +30 -28
- {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/LICENSE +0 -0
- {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/WHEEL +0 -0
- {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/top_level.txt +0 -0
torchrl/_torchrl.cpython-313-darwin.so
CHANGED
Binary file
torchrl/collectors/collectors.py
CHANGED
@@ -352,8 +352,8 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):
                 self._iterator = iter(self)
             out = next(self._iterator)
             # if any, we don't want the device ref to be passed in distributed settings
-            if out is not None:
-                out.clear_device_()
+            if out is not None and (out.device != "cpu"):
+                out = out.copy().clear_device_()
             return out
         except StopIteration:
             return None
@@ -892,7 +892,10 @@ class SyncDataCollector(DataCollectorBase):
             and hasattr(self.postproc, "to")
             and self.storing_device
         ):
-            self.postproc.to(self.storing_device)
+            postproc = self.postproc.to(self.storing_device)
+            if postproc is not self.postproc and postproc is not None:
+                self.postproc = postproc
+
         if frames_per_batch % self.n_env != 0 and RL_WARNINGS:
             warnings.warn(
                 f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), "
@@ -1253,9 +1256,9 @@ class SyncDataCollector(DataCollectorBase):
                 yield
                 continue
             self._increment_frames(tensordict_out.numel())
-            if self.verbose:
-                torchrl_logger.info("Collector: postproc.")
             tensordict_out = self._postproc(tensordict_out)
+            if self.verbose:
+                torchrl_logger.info("Collector: postproc done.")
             if self.return_same_td:
                 # This is used with multiprocessed collectors to use the buffers
                 # stored in the tensordict.
@@ -1765,8 +1768,9 @@ class _MultiDataCollector(DataCollectorBase):
         .. warning:: `policy_factory` is currently not compatible with multiprocessed data
             collectors.

-        frames_per_batch (int): A keyword-only argument representing the
-            total number of elements in a batch.
+        frames_per_batch (int, Sequence[int]): A keyword-only argument representing the
+            total number of elements in a batch. If a sequence is provided, represents the number of elements in a
+            batch per worker. Total number of elements in a batch is then the sum over the sequence.
         total_frames (int, optional): A keyword-only argument representing the
             total number of frames returned by the collector
             during its lifespan. If the ``total_frames`` is not divisible by
@@ -1923,7 +1927,7 @@ class _MultiDataCollector(DataCollectorBase):
         policy_factory: Callable[[], Callable]
         | list[Callable[[], Callable]]
         | None = None,
-        frames_per_batch: int,
+        frames_per_batch: int | Sequence[int],
         total_frames: int | None = -1,
         device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
         storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
@@ -1959,6 +1963,22 @@ class _MultiDataCollector(DataCollectorBase):
         self.closed = True
         self.num_workers = len(create_env_fn)

+        if (
+            isinstance(frames_per_batch, Sequence)
+            and len(frames_per_batch) != self.num_workers
+        ):
+            raise ValueError(
+                "If `frames_per_batch` is provided as a sequence, it should contain exactly one value per worker."
+                f"Got {len(frames_per_batch)} values for {self.num_workers} workers."
+            )
+
+        self._frames_per_batch = frames_per_batch
+        total_frames_per_batch = (
+            sum(frames_per_batch)
+            if isinstance(frames_per_batch, Sequence)
+            else frames_per_batch
+        )
+
         self.set_truncated = set_truncated
         self.num_sub_threads = num_sub_threads
         self.num_threads = num_threads
@@ -2076,11 +2096,11 @@ class _MultiDataCollector(DataCollectorBase):
         if total_frames is None or total_frames < 0:
             total_frames = float("inf")
         else:
-            remainder = total_frames % frames_per_batch
+            remainder = total_frames % total_frames_per_batch
             if remainder != 0 and RL_WARNINGS:
                 warnings.warn(
-                    f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). "
-                    f"This means {frames_per_batch - remainder} additional frames will be collected. "
+                    f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({total_frames_per_batch}). "
+                    f"This means {total_frames_per_batch - remainder} additional frames will be collected. "
                     "To silence this message, set the environment variable RL_WARNINGS to False."
                 )
         self.total_frames = (
@@ -2091,7 +2111,8 @@ class _MultiDataCollector(DataCollectorBase):
         self.max_frames_per_traj = (
             int(max_frames_per_traj) if max_frames_per_traj is not None else 0
         )
-        self.requested_frames_per_batch = frames_per_batch
+
+        self.requested_frames_per_batch = total_frames_per_batch
         self.reset_when_done = reset_when_done
         if split_trajs is None:
             split_trajs = False
@@ -2221,8 +2242,7 @@ class _MultiDataCollector(DataCollectorBase):
         )
         return storing_device, policy_device, env_device

-    @property
-    def frames_per_batch_worker(self):
+    def frames_per_batch_worker(self, worker_idx: int | None = None) -> int:
         raise NotImplementedError

     @property
@@ -2281,7 +2301,7 @@ class _MultiDataCollector(DataCollectorBase):
             "create_env_kwargs": env_fun_kwargs,
             "policy": policy,
             "max_frames_per_traj": self.max_frames_per_traj,
-            "frames_per_batch": self.frames_per_batch_worker,
+            "frames_per_batch": self.frames_per_batch_worker(worker_idx=i),
             "reset_at_each_iter": self.reset_at_each_iter,
             "policy_device": policy_device,
             "storing_device": storing_device,
@@ -2773,8 +2793,9 @@ class MultiSyncDataCollector(_MultiDataCollector):
             policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
         )

-    @property
-    def frames_per_batch_worker(self):
+    def frames_per_batch_worker(self, worker_idx: int | None) -> int:
+        if worker_idx is not None and isinstance(self._frames_per_batch, Sequence):
+            return self._frames_per_batch[worker_idx]
         if self.requested_frames_per_batch % self.num_workers != 0 and RL_WARNINGS:
             warnings.warn(
                 f"frames_per_batch {self.requested_frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers},"
@@ -2855,9 +2876,9 @@ class MultiSyncDataCollector(_MultiDataCollector):
             use_buffers = self._use_buffers
             if self.replay_buffer is not None:
                 idx = new_data
-                workers_frames[idx] = (
-                    workers_frames[idx] + self.frames_per_batch_worker
-                )
+                workers_frames[idx] = workers_frames[
+                    idx
+                ] + self.frames_per_batch_worker(worker_idx=idx)
                 continue
             elif j == 0 or not use_buffers:
                 try:
@@ -2903,7 +2924,12 @@ class MultiSyncDataCollector(_MultiDataCollector):

             if self.replay_buffer is not None:
                 yield
-                self._frames +=
+                self._frames += sum(
+                    [
+                        self.frames_per_batch_worker(worker_idx)
+                        for worker_idx in range(self.num_workers)
+                    ]
+                )
                 continue

             # we have to correct the traj_ids to make sure that they don't overlap
@@ -3156,8 +3182,7 @@ class MultiaSyncDataCollector(_MultiDataCollector):
             policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
        )

-    @property
-    def frames_per_batch_worker(self):
+    def frames_per_batch_worker(self, worker_idx: int | None = None) -> int:
         return self.requested_frames_per_batch

     def _get_from_queue(self, timeout=None) -> tuple[int, int, TensorDictBase]:
@@ -3221,7 +3246,7 @@ class MultiaSyncDataCollector(_MultiDataCollector):
             if self.split_trajs:
                 out = split_trajectories(out, prefix="collector")
             else:
-                worker_frames = self.frames_per_batch_worker
+                worker_frames = self.frames_per_batch_worker()
             self._frames += worker_frames
             workers_frames[idx] = workers_frames[idx] + worker_frames
             if self.postprocs:
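
A minimal usage sketch of the per-worker batching described in the updated docstring (illustrative only, not taken from the package; the environment and frame counts are arbitrary):

    from torchrl.collectors import MultiSyncDataCollector
    from torchrl.envs import GymEnv

    collector = MultiSyncDataCollector(
        create_env_fn=[lambda: GymEnv("CartPole-v1"), lambda: GymEnv("CartPole-v1")],
        policy=None,  # None falls back to a random policy
        frames_per_batch=[64, 128],  # one entry per worker; a length mismatch raises ValueError
        total_frames=1920,
    )
    for batch in collector:
        ...  # each batch holds 64 + 128 = 192 frames in total
    collector.shutdown()
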
torchrl/collectors/llm/base.py
CHANGED
@@ -242,6 +242,11 @@ class LLMCollector(SyncDataCollector):
         else:
             self.policy_version_tracker = None

+    def set_postproc(self, postproc: Callable[[TensorDictBase], TensorDictBase]):
+        if self.postproc is not None:
+            raise RuntimeError("Postproc already set")
+        self.postproc = postproc
+
     def increment_version(self):
         """Increment the policy version."""
         if self.policy_version_tracker is not None:
@@ -361,9 +366,10 @@ class LLMCollector(SyncDataCollector):
                 )
                 self._yield_queues[idx].clear()
                 result = self._trajectory_queue.popleft()
-
-
-
+                if self.verbose:
+                    torchrl_logger.info(
+                        f"LLMCollector: Yielding completed trajectory with shape {result.shape}."
+                    )
                 return result

         started = False
@@ -422,9 +428,10 @@ class LLMCollector(SyncDataCollector):
             self.env.async_step_and_maybe_reset_send(env_input)

         result = self._trajectory_queue.popleft()
-
-
-
+        if self.verbose:
+            torchrl_logger.info(
+                f"LLMCollector: Yielding completed trajectory with shape {result.shape}."
+            )
         return result

     as_remote = as_remote
torchrl/collectors/llm/ray_collector.py
CHANGED
@@ -134,6 +134,9 @@ class RayLLMCollector(LLMCollector):
             verbose=verbose,
         )

+    def set_postproc(self, postproc: Callable[[TensorDictBase], TensorDictBase]):
+        return ray.get(self._collector.set_postproc.remote(postproc))
+
     def _next_remote(self) -> None:
         return self._collector.next.remote()

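
A short illustrative sketch of the new `set_postproc` hook added to `LLMCollector` and forwarded by `RayLLMCollector` (assumed usage, not from the diff; `collector` stands for an already-constructed collector and `tag_batch` is a hypothetical helper):

    import torch
    from tensordict import TensorDictBase

    def tag_batch(td: TensorDictBase) -> TensorDictBase:
        # Toy post-processing step: flag every trajectory yielded by the collector.
        td.set("postprocessed", torch.ones(td.shape, dtype=torch.bool))
        return td

    collector.set_postproc(tag_batch)  # raises RuntimeError if a postproc was already set

On the Ray variant the call is proxied through `ray.get`, so it blocks until the remote collector has registered the hook.
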
torchrl/data/__init__.py
CHANGED
@@ -17,6 +17,7 @@ from .llm import (
     RolloutFromModel,
     TensorDictTokenizer,
     TokenizedDatasetLoader,
+    TopKRewardSelector,
 )
 from .map import (
     BinaryToDecimal,
@@ -116,6 +117,7 @@ __all__ = [
     "Categorical",
     "Choice",
     "ContentBase",
+    "TopKRewardSelector",
     "Composite",
     "CompositeSpec",
     "ConstantKLController",
torchrl/data/datasets/minari_data.py
CHANGED
@@ -350,7 +350,7 @@ class MinariExperienceReplay(BaseDatasetExperienceReplay):
         # Add a "done" entry
         if self.split_trajs:
             with td_data.unlock_():
-                from torchrl.
+                from torchrl.collectors.utils import split_trajectories

                 td_data = split_trajectories(td_data).memmap_(self.data_path)
         with open(self.metadata_path, "w") as metadata_file:
torchrl/data/llm/__init__.py
CHANGED
@@ -13,6 +13,7 @@ from .dataset import (
 )
 from .prompt import PromptData, PromptTensorDictTokenizer
 from .reward import PairwiseDataset, RewardData
+from .topk import TopKRewardSelector
 from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel

 __all__ = [
@@ -30,4 +31,5 @@ __all__ = [
     "TokenizedDatasetLoader",
     "create_infinite_iterator",
     "get_dataloader",
+    "TopKRewardSelector",
 ]
torchrl/data/llm/chat.py
CHANGED
@@ -11,19 +11,27 @@ from typing import Literal

 import torch

-
-
+from tensordict import (
+    lazy_stack,
+    LazyStackedTensorDict,
+    list_to_stack,
+    TensorClass,
+    TensorDict,
+)
 from tensordict.utils import _maybe_correct_neg_dim
-
 from torchrl._utils import logger as torchrl_logger


 _CHAT_TEMPLATES = {
     "chatml_format": """{% for message in messages %}
+{%- if message['role'] == 'assistant' %}
+{% generation %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endgeneration %}
+{%- else %}
 {{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
+{%- endif %}
 {% endfor %}
 {%- if add_generation_prompt %}
-{{- '<|im_start|>assistant\n' }}
+{% generation %}{{- '<|im_start|>assistant\n' }}{% endgeneration %}
 {%- endif %}
 """,
     "qwen": """
@@ -283,7 +291,7 @@ class History(TensorClass["nocast"]):

         Keyword Args:
             tokenizer (transformers.PreTrainedTokenizer | transformers.AutoProcessor): The tokenizer to use.
-            add_generation_prompt (bool, optional): Whether to add a generation prompt. Defaults to `True`.
+            add_generation_prompt (bool, optional): Whether to add a generation prompt (e.g. `"<|im_start|>assistant"`). Defaults to `True`.
             chat_template (str, optional): The chat template to use. Defaults to the tokenizer's default template.
             chat_template_name (Literal["chatml_format", "qwen"], optional): The name of the chat template to use.
                 Prevalent over `tokenizer.chat_template`. Defaults to `None`.
@@ -294,6 +302,7 @@ class History(TensorClass["nocast"]):
             return_tensors (str | None, optional): The type of tensors to return. Defaults to "pt".
             return_dict (bool, optional): Whether to return a dictionary. Defaults to `False`.
             return_assistant_tokens_mask (bool, optional): Whether to return a mask of the assistant generated tokens.
+                If `True`, the mask will be written to the `assistant_masks` key.
                 For tokens generated by the assistant, the mask will contain `1`.
                 For user and system tokens, the mask will contain `0`.
                 This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
@@ -316,6 +325,11 @@ class History(TensorClass["nocast"]):
                 raise RuntimeError(
                     "You must specify a tokenizer to use when chat_template is not specified."
                 )
+            elif "qwen" in getattr(tokenizer, "name_or_path", "").lower():
+                # We prefer our implementation of the Qwen template,
+                # since it accounts for the assistant's masking.
+                chat_template = _CHAT_TEMPLATES["qwen"]
+                chat_template_name = None
             else:
                 chat_template = tokenizer.chat_template
             if chat_template is None:
@@ -334,7 +348,7 @@ class History(TensorClass["nocast"]):
             return_dict = False

         if self.ndim > 1:
-            return [
+            result = [
                 self[i].apply_chat_template(
                     tokenizer=tokenizer,
                     add_generation_prompt=add_generation_prompt,
@@ -351,12 +365,16 @@ class History(TensorClass["nocast"]):
                 )
                 for i in range(self.batch_size[0])
             ]
+            if return_dict:
+                return lazy_stack(result)
+            else:
+                return result
         self_flat = self.view(-1)
         # tolist_first=True is needed to avoid having a list of dict of dicts, but a list of dicts of lists of dicts
         self_flat = self_flat.tolist(tolist_first=True)
         # Remove the "<none>" role
         self_flat = [item for item in self_flat if item["role"] != "<none>"]
-        return tokenizer.apply_chat_template(
+        result = tokenizer.apply_chat_template(
             conversation=self_flat,
             add_generation_prompt=add_generation_prompt,
             chat_template=chat_template,
@@ -369,6 +387,16 @@ class History(TensorClass["nocast"]):
             return_assistant_tokens_mask=return_assistant_tokens_mask,
             **kwargs,
         )
+        if not isinstance(result, (torch.Tensor, list, str)):
+            result = TensorDict.from_dict(result, auto_batch_size=True, batch_dims=1)
+            # If self has a batch_dims of 1, we have just the time dimension, so we need to remove the batch dim from the result
+            if self.batch_dims == 1:
+                if result.batch_size[0] != 1:
+                    raise RuntimeError(
+                        f"Expected a batch size of 1, got {result.batch_size[0]}."
+                    )
+                result = result.squeeze(0)
+        return result

     @classmethod
     def from_text(
@@ -376,10 +404,20 @@ class History(TensorClass["nocast"]):
         text: str | list[str],
         chat_template_name: Literal["chatml_format", "qwen"] | None = None,
         chat_template: str | None = None,
+        tokenizer: transformers.AutoTokenizer # noqa: F821
+        | transformers.AutoProcessor # noqa: F821
+        | None = None,
     ) -> History:
-        if chat_template_name
+        if chat_template_name is None and chat_template is None:
+            if "qwen" in getattr(tokenizer, "name_or_path", "").lower():
+                # We can automatically detect the template name from the tokenizer
+                # and use the precoded parser.
+                chat_template_name = "qwen"
+            else:
+                chat_template_name = "chatml_format"
+        elif chat_template_name in ("chatml_format",):
             func = cls._inv_chatml
-        elif chat_template_name
+        elif chat_template_name in ("qwen",):
             func = cls._inv_qwen
         else:
             raise NotImplementedError(
@@ -736,3 +774,15 @@ class History(TensorClass["nocast"]):
         }

         return Composite(defaults, shape=shape[:-1], data_cls=cls)
+
+    @classmethod
+    def from_chats(cls, chats: list[list[dict]]) -> History:
+        """Create a History object from a list of chats.
+
+        Args:
+            chats (list[list[dict]]): A list of chats, where each chat is a list of dictionaries.
+        """
+        if isinstance(chats[0], dict):
+            return lazy_stack([cls(**chat) for chat in chats])
+        else:
+            return lazy_stack([cls.from_chats(chat) for chat in chats])
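
A sketch of how the new pieces fit together (assumed usage, not from the diff; the tokenizer checkpoint is only an example): `History.from_chats` builds a lazy-stacked `History`, and `apply_chat_template` with `return_assistant_tokens_mask=True` relies on the `{% generation %}` markers added to the bundled templates:

    from transformers import AutoTokenizer
    from torchrl.data.llm.chat import History

    chats = [
        [
            {"role": "user", "content": "Hi!"},
            {"role": "assistant", "content": "Hello! How can I help?"},
        ],
    ]
    history = History.from_chats(chats)  # History with batch_size [1, 2]
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    out = history.apply_chat_template(
        tokenizer=tokenizer,
        add_generation_prompt=False,
        return_dict=True,
        return_assistant_tokens_mask=True,
    )
    # Per the updated docstring, the mask lands in the "assistant_masks" entry:
    # 1 for assistant-generated tokens, 0 for user/system tokens.
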
torchrl/data/llm/topk.py
ADDED
@@ -0,0 +1,186 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from collections import defaultdict, deque
+from typing import Any
+
+import torch
+from tensordict import NestedKey, TensorDictBase
+from torchrl._utils import logger as torchrl_logger
+from torchrl.envs.transforms import Transform
+
+
+class TopKRewardSelector(Transform):
+    """A replay-buffer transform that selects the top-k rewards for each prompt.
+
+    Args:
+        total_dialog_turns (int): Number of dialog turns to keep in memory for the top-k selection.
+        topk_size (int): Number of top-k rewards to select. Must be smaller than or equal to total_dialog_turns.
+        prompt_key (NestedKey): Key to the prompt in the tensordict. Defaults to "text".
+        rewards_key (NestedKey): Key to the rewards in the tensordict. Defaults to ("next", "reward").
+        done_key (NestedKey): Key to the done state in the tensordict. Defaults to ("next", "done").
+        verbose (bool): Whether to print verbose information. Defaults to `False`.
+
+    Example:
+        >>> from torchrl.data import ReplayBuffer, LazyStackStorage, SamplerWithoutReplacement
+        >>> from tensordict import TensorDict, lazy_stack
+        >>> import torch
+        >>> from torchrl.data.llm.topk import TopKRewardSelector
+        >>> # Create a replay buffer with 50 items, a sampler that samples without replacement, and a batch size of 5
+        >>> rb = ReplayBuffer(
+        ...     storage=LazyStackStorage(50),
+        ...     sampler=SamplerWithoutReplacement,
+        ...     batch_size=5,
+        ... )
+        >>> # Create a tensordict with 50 items, each with 10 dialog turns
+        >>> td = lazy_stack(
+        ...     [
+        ...         TensorDict(
+        ...             {
+        ...                 ("next", "done"): torch.full((1, 1), True),
+        ...                 # Reward for i+5 tokens
+        ...                 ("next", "reward"): torch.full((i + 5, 1), i),
+        ...                 # total of 10 dialogs per prompt
+        ...                 "text": f"Prompt {i // 5}",
+        ...             }
+        ...         )
+        ...         for i in range(50)
+        ...     ]
+        ... )
+        >>> # Create a top-k reward selector with 5 dialog turns and a top-k size of 3
+        >>> topk = TopKRewardSelector(total_dialog_turns=5, topk_size=3)
+        >>> rb.append_transform(topk)
+        >>> for _td in td.chunk(25):
+        ...     rb.extend(_td)
+        >>> # Only wrote top3 of 50 items in 10 groups of 5
+        >>> assert rb.write_count == 30
+        >>> assert len(rb) == 30
+        >>> r3 = rb[:3].get(("next", "reward"), as_padded_tensor=True).squeeze()
+        >>> # 0 and 1 are missing because they're not part of the top-k
+        >>> assert (
+        ...     r3 == torch.tensor(
+        ...         [
+        ...             [4, 4, 4, 4, 4, 4, 4, 4, 4],
+        ...             [3, 3, 3, 3, 3, 3, 3, 3, 0],
+        ...             [2, 2, 2, 2, 2, 2, 2, 0, 0],
+        ...         ]
+        ...     )
+        ... ).all()
+    """
+
+    def __init__(
+        self,
+        total_dialog_turns: int,
+        topk_size: int,
+        prompt_key: NestedKey = "text",
+        rewards_key: NestedKey = ("next", "reward"),
+        done_key: NestedKey = ("next", "done"),
+        verbose: bool = True,
+    ):
+        super().__init__()
+        self.in_keys = [prompt_key, rewards_key, done_key]
+        self.prompt_key = prompt_key
+        self.rewards_key = rewards_key
+        self.done_key = done_key
+        self.queues = defaultdict(lambda: deque(maxlen=total_dialog_turns))
+        self.total_dialog_turns = total_dialog_turns
+        self.topk_size = topk_size
+        if topk_size > total_dialog_turns:
+            raise ValueError(
+                f"topk_size must be smaller than or equal to total_dialog_turns, got {topk_size=} and {total_dialog_turns=}"
+            )
+        self.verbose = verbose
+
+    def forward(self, tensordict: TensorDictBase) -> Any:
+        return tensordict
+
+    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
+        # Tensordict can be any number of dims, but it must contain entire trajectories
+        if tensordict.ndim == 1:
+            # Check how many done states we have
+            num_done = tensordict[self.done_key].sum()
+            if num_done > 1:
+                done_idx = tensordict[self.done_key].nonzero(as_tuple=True)[0] + 1
+                splits = torch.cat([done_idx.new_zeros((1,)), done_idx], dim=0).diff()
+                tensordicts = tensordict.split(splits)
+                tensordicts = [self._inv_call(td) for td in tensordicts]
+                tensordicts = [td for td in tensordicts if td is not None]
+                return torch.cat(tensordicts) if tensordicts else None
+            # Then we have a single trajectory
+            if not tensordict[-1][self.done_key].all():
+                raise RuntimeError("Expected the trajectory to be done.")
+            prompt = tensordict[0][self.prompt_key]
+            if not isinstance(prompt, str):
+                raise TypeError(f"Expected a string as prompt, got {type(prompt)=}")
+            self.queues[prompt].append(tensordict)
+            if len(self.queues[prompt]) == self.total_dialog_turns:
+                if self.verbose:
+                    torchrl_logger.info(f"Getting top-k rewards for {prompt=}")
+                # Cat is the most robust way to combine the trajs
+                tds = torch.cat(list(self.queues[prompt]), -1)
+                # Collect rewards
+                reward = tds.get(self.rewards_key, as_nested_tensor=True)
+                reward = self._aggregate_rewards(reward)
+                # Check if all rewards are equal
+                if (reward == reward[0]).all():
+                    # If all rewards are equal, we can't select top-k
+                    if self.verbose:
+                        torchrl_logger.warning(
+                            f"All rewards are equal ({reward.unique()=})"
+                        )
+                    return
+                # Filter out rewards below median
+                median_reward = reward.median(dim=-1, keepdim=True)[0]
+                mask = reward > median_reward
+                filtered_reward = reward[mask]
+                filtered_indices = mask.nonzero(as_tuple=True)[0]
+                # Get top-k from filtered rewards
+                topk_reward = filtered_reward.topk(
+                    k=min(self.topk_size, len(filtered_indices)), dim=-1
+                )
+                if not topk_reward.indices.numel():
+                    if self.verbose:
+                        torchrl_logger.warning(
+                            f"No top-{self.topk_size} rewards found ({reward=})"
+                        )
+                    return
+                # Map back to original indices
+                selected_indices = filtered_indices[topk_reward.indices]
+                tds = tds[selected_indices]
+                if self.verbose:
+                    torchrl_logger.info(
+                        f"Selected top-{self.topk_size} rewards, with reward {topk_reward.values=}"
+                    )
+                return tds
+            return
+        elif tensordict.ndim > 2:
+            # keep the time dim at the end
+            tensordict = tensordict.flatten(0, -2)
+        trajs = tensordict.unbind(-1)
+        # Iterate over the trajectories
+        result = []
+        for traj in trajs:
+            td_out = self._inv_call(traj)
+            if td_out is None:
+                continue
+            result.append(td_out)
+        if result:
+            return torch.cat(result, -1)
+        return
+
+    def _aggregate_rewards(self, reward: torch.Tensor) -> torch.Tensor:
+        """Aggregate the rewards across the dialog turns.
+
+        `reward` is expected to be a nested tensor.
+
+        The default implementation is to take the mean of the rewards across the dialog turns.
+        """
+        # reward = reward.to_padded_tensor(padding=0.0)
+        if reward.ndim < 2 or reward.ndim > 3:
+            raise ValueError(
+                f"Expected reward to be a 2D or 3D tensor, got {reward.ndim}D tensor"
+            )
+        return reward.mean(dim=-2).squeeze(-1)
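
The docstring above already demonstrates the buffer-side usage; as a complementary sketch (an assumption, not part of the diff), the ranking criterion can be customized by overriding `_aggregate_rewards`, which receives a nested tensor of per-trajectory rewards and must return one score per trajectory:

    import torch
    from torchrl.data.llm.topk import TopKRewardSelector

    class LastStepTopK(TopKRewardSelector):
        # Hypothetical variant: rank trajectories by their final reward
        # instead of the default per-token mean.
        def _aggregate_rewards(self, reward: torch.Tensor) -> torch.Tensor:
            return torch.stack([r[-1].squeeze(-1) for r in reward.unbind(0)])

    topk = LastStepTopK(total_dialog_turns=8, topk_size=4)
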
torchrl/data/replay_buffers/ray_buffer.py
CHANGED
@@ -54,9 +54,12 @@ class RayReplayBuffer(ReplayBuffer):
     """A Ray implementation of the Replay Buffer that can be extended and sampled remotely.

     Keyword Args:
+        replay_buffer_cls (type[ReplayBuffer], optional): the class to use for the replay buffer.
+            Defaults to :class:`~torchrl.data.ReplayBuffer`.
         ray_init_config (dict[str, Any], optiona): keyword arguments to pass to `ray.init()`.
         remote_config (dict[str, Any], optiona): keyword arguments to pass to `cls.as_remote()`.
             Defaults to `torchrl.collectors.distributed.ray.DEFAULT_REMOTE_CLASS_CONFIG`.
+        **kwargs: keyword arguments to pass to the replay buffer class.

     .. seealso:: :class:`~torchrl.data.ReplayBuffer` for a list of other keyword arguments.

@@ -119,6 +122,7 @@ class RayReplayBuffer(ReplayBuffer):
     def __init__(
         self,
         *args,
+        replay_buffer_cls: type[ReplayBuffer] | None = ReplayBuffer,
         ray_init_config: dict[str, Any] | None = None,
         remote_config: dict[str, Any] | None = None,
         **kwargs,
@@ -134,7 +138,13 @@ class RayReplayBuffer(ReplayBuffer):
                 ray_init_config = DEFAULT_RAY_INIT_CONFIG
             ray.init(**ray_init_config)

-        remote_cls =
+        remote_cls = replay_buffer_cls.as_remote(remote_config).remote
+        # We can detect if the buffer has a GPU allocated, if not
+        # we'll make sure that the data is sent to CPU when needed.
+        if remote_config is not None:
+            self.has_gpu = remote_config.get("num_gpus", 0) > 0
+        else:
+            self.has_gpu = False
         self._rb = remote_cls(*args, **kwargs)

     def close(self):
@@ -158,6 +168,10 @@ class RayReplayBuffer(ReplayBuffer):
         return ray.get(pending_task)

     def extend(self, *args, **kwargs):
+        if not self.has_gpu:
+            # Move the data to GPU
+            args = [arg.to("cpu") for arg in args if hasattr(arg, "to")]
+            kwargs = {k: v.to("cpu") for k, v in kwargs.items() if hasattr(v, "to")}
         pending_task = self._rb.extend.remote(*args, **kwargs)
         return ray.get(pending_task)

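
A short usage sketch of the new `replay_buffer_cls` keyword and the CPU hand-off in `extend` (assumed usage, not from the diff; it requires a running Ray runtime):

    import torch
    from tensordict import TensorDict
    from torchrl.data import ReplayBuffer
    from torchrl.data.replay_buffers.ray_buffer import RayReplayBuffer

    rb = RayReplayBuffer(
        batch_size=32,
        replay_buffer_cls=ReplayBuffer,  # the default; any buffer class exposing as_remote() should work
        remote_config={"num_cpus": 1, "num_gpus": 0},  # no GPU declared -> has_gpu is False
    )
    data = TensorDict({"obs": torch.randn(8, 4)}, batch_size=[8])
    rb.extend(data)       # tensor arguments are moved to CPU before the remote call
    sample = rb.sample()  # samples 32 items (with replacement by default)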