torchrl-nightly 2025.6.20__cp39-cp39-macosx_10_9_universal2.whl → 2025.6.21__cp39-cp39-macosx_10_9_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrl/_torchrl.cpython-39-darwin.so +0 -0
- torchrl/collectors/collectors.py +8 -5
- torchrl/collectors/llm/base.py +13 -6
- torchrl/collectors/llm/ray_collector.py +3 -0
- torchrl/data/__init__.py +2 -0
- torchrl/data/llm/__init__.py +2 -0
- torchrl/data/llm/chat.py +59 -8
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/replay_buffers/ray_buffer.py +15 -1
- torchrl/data/replay_buffers/replay_buffers.py +50 -11
- torchrl/data/replay_buffers/samplers.py +98 -21
- torchrl/data/replay_buffers/storages.py +29 -2
- torchrl/envs/llm/__init__.py +2 -0
- torchrl/envs/llm/chat.py +4 -1
- torchrl/envs/llm/reward/gsm8k.py +15 -8
- torchrl/envs/llm/transforms/__init__.py +2 -1
- torchrl/envs/llm/transforms/kl.py +240 -4
- torchrl/modules/llm/policies/transformers_wrapper.py +71 -15
- torchrl/modules/llm/policies/vllm_wrapper.py +38 -5
- torchrl/objectives/llm/__init__.py +2 -1
- torchrl/objectives/llm/sft.py +465 -0
- torchrl/version.py +2 -2
- {torchrl_nightly-2025.6.20.dist-info → torchrl_nightly-2025.6.21.dist-info}/METADATA +1 -1
- {torchrl_nightly-2025.6.20.dist-info → torchrl_nightly-2025.6.21.dist-info}/RECORD +27 -25
- {torchrl_nightly-2025.6.20.dist-info → torchrl_nightly-2025.6.21.dist-info}/LICENSE +0 -0
- {torchrl_nightly-2025.6.20.dist-info → torchrl_nightly-2025.6.21.dist-info}/WHEEL +0 -0
- {torchrl_nightly-2025.6.20.dist-info → torchrl_nightly-2025.6.21.dist-info}/top_level.txt +0 -0
torchrl/_torchrl.cpython-39-darwin.so
Binary file (no textual diff)
torchrl/collectors/collectors.py
CHANGED
@@ -352,8 +352,8 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):
                self._iterator = iter(self)
                out = next(self._iterator)
                # if any, we don't want the device ref to be passed in distributed settings
-               if out is not None:
-                   out.clear_device_()
+               if out is not None and (out.device != "cpu"):
+                   out = out.copy().clear_device_()
                return out
            except StopIteration:
                return None
@@ -892,7 +892,10 @@ class SyncDataCollector(DataCollectorBase):
            and hasattr(self.postproc, "to")
            and self.storing_device
        ):
-            self.postproc.to(self.storing_device)
+            postproc = self.postproc.to(self.storing_device)
+            if postproc is not self.postproc and postproc is not None:
+                self.postproc = postproc
+
        if frames_per_batch % self.n_env != 0 and RL_WARNINGS:
            warnings.warn(
                f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), "
@@ -1253,9 +1256,9 @@ class SyncDataCollector(DataCollectorBase):
                yield
                continue
            self._increment_frames(tensordict_out.numel())
-            if self.verbose:
-                torchrl_logger.info("Collector: postproc.")
            tensordict_out = self._postproc(tensordict_out)
+            if self.verbose:
+                torchrl_logger.info("Collector: postproc done.")
            if self.return_same_td:
                # This is used with multiprocessed collectors to use the buffers
                # stored in the tensordict.
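The iterator change above now strips the device reference only when the yielded batch actually carries a non-CPU device, and does so on a copy rather than in place. A minimal sketch of that pattern with a plain TensorDict (illustrative, not part of the diff):

import torch
from tensordict import TensorDict

# Assume `out` is the batch yielded by the collector's iterator.
out = TensorDict({"obs": torch.zeros(4)}, batch_size=[], device="cpu")
if out is not None and out.device != "cpu":
    # clear_device_() mutates in place, so clearing a shallow copy leaves the
    # collector's internal buffer untouched; a CPU batch is returned unchanged.
    out = out.copy().clear_device_()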
torchrl/collectors/llm/base.py
CHANGED
@@ -242,6 +242,11 @@ class LLMCollector(SyncDataCollector):
        else:
            self.policy_version_tracker = None

+    def set_postproc(self, postproc: Callable[[TensorDictBase], TensorDictBase]):
+        if self.postproc is not None:
+            raise RuntimeError("Postproc already set")
+        self.postproc = postproc
+
    def increment_version(self):
        """Increment the policy version."""
        if self.policy_version_tracker is not None:
@@ -361,9 +366,10 @@ class LLMCollector(SyncDataCollector):
                )
                self._yield_queues[idx].clear()
            result = self._trajectory_queue.popleft()
-
-
-
+            if self.verbose:
+                torchrl_logger.info(
+                    f"LLMCollector: Yielding completed trajectory with shape {result.shape}."
+                )
            return result

        started = False
@@ -422,9 +428,10 @@ class LLMCollector(SyncDataCollector):
                self.env.async_step_and_maybe_reset_send(env_input)

            result = self._trajectory_queue.popleft()
-
-
-
+            if self.verbose:
+                torchrl_logger.info(
+                    f"LLMCollector: Yielding completed trajectory with shape {result.shape}."
+                )
            return result

    as_remote = as_remote
torchrl/collectors/llm/ray_collector.py
CHANGED
@@ -134,6 +134,9 @@ class RayLLMCollector(LLMCollector):
            verbose=verbose,
        )

+    def set_postproc(self, postproc: Callable[[TensorDictBase], TensorDictBase]):
+        return ray.get(self._collector.set_postproc.remote(postproc))
+
    def _next_remote(self) -> None:
        return self._collector.next.remote()
torchrl/data/__init__.py
CHANGED
@@ -17,6 +17,7 @@ from .llm import (
    RolloutFromModel,
    TensorDictTokenizer,
    TokenizedDatasetLoader,
+    TopKRewardSelector,
)
from .map import (
    BinaryToDecimal,
@@ -116,6 +117,7 @@ __all__ = [
    "Categorical",
    "Choice",
    "ContentBase",
+    "TopKRewardSelector",
    "Composite",
    "CompositeSpec",
    "ConstantKLController",
torchrl/data/llm/__init__.py
CHANGED
@@ -13,6 +13,7 @@ from .dataset import (
)
from .prompt import PromptData, PromptTensorDictTokenizer
from .reward import PairwiseDataset, RewardData
+from .topk import TopKRewardSelector
from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel

__all__ = [
@@ -30,4 +31,5 @@ __all__ = [
    "TokenizedDatasetLoader",
    "create_infinite_iterator",
    "get_dataloader",
+    "TopKRewardSelector",
]
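With these two __init__.py changes the new transform is importable from either namespace. A quick hedged check (not part of the diff):

import torchrl.data
import torchrl.data.llm

# Both packages re-export the same class after this release.
assert torchrl.data.TopKRewardSelector is torchrl.data.llm.TopKRewardSelector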
torchrl/data/llm/chat.py
CHANGED
@@ -11,18 +11,27 @@ from typing import Literal

import torch

-from tensordict import
+from tensordict import (
+    lazy_stack,
+    LazyStackedTensorDict,
+    list_to_stack,
+    TensorClass,
+    TensorDict,
+)
from tensordict.utils import _maybe_correct_neg_dim
-
from torchrl._utils import logger as torchrl_logger


_CHAT_TEMPLATES = {
    "chatml_format": """{% for message in messages %}
+    {%- if message['role'] == 'assistant' %}
+    {% generation %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endgeneration %}
+    {%- else %}
    {{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
+    {%- endif %}
{% endfor %}
{%- if add_generation_prompt %}
-    {{- '<|im_start|>assistant\n' }}
+    {% generation %}{{- '<|im_start|>assistant\n' }}{% endgeneration %}
{%- endif %}
""",
    "qwen": """
@@ -282,7 +291,7 @@ class History(TensorClass["nocast"]):

    Keyword Args:
        tokenizer (transformers.PreTrainedTokenizer | transformers.AutoProcessor): The tokenizer to use.
-        add_generation_prompt (bool, optional): Whether to add a generation prompt. Defaults to `True`.
+        add_generation_prompt (bool, optional): Whether to add a generation prompt (e.g. `"<|im_start|>assistant"`). Defaults to `True`.
        chat_template (str, optional): The chat template to use. Defaults to the tokenizer's default template.
        chat_template_name (Literal["chatml_format", "qwen"], optional): The name of the chat template to use.
            Prevalent over `tokenizer.chat_template`. Defaults to `None`.
@@ -293,6 +302,7 @@ class History(TensorClass["nocast"]):
        return_tensors (str | None, optional): The type of tensors to return. Defaults to "pt".
        return_dict (bool, optional): Whether to return a dictionary. Defaults to `False`.
        return_assistant_tokens_mask (bool, optional): Whether to return a mask of the assistant generated tokens.
+            If `True`, the mask will be written to the `assistant_masks` key.
            For tokens generated by the assistant, the mask will contain `1`.
            For user and system tokens, the mask will contain `0`.
            This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
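The template and docstring changes above enable assistant-token masking through the `{% generation %}` markers. A hedged usage sketch (the model name and chat content are illustrative; it relies on the `from_chats` helper added further down in this file and requires downloading a tokenizer):

from transformers import AutoTokenizer
from torchrl.data.llm.chat import History

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative model
history = History.from_chats(
    [[{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]]
)
out = history.apply_chat_template(
    tokenizer=tokenizer,
    chat_template_name="chatml_format",  # built-in template carrying {% generation %} markers
    add_generation_prompt=False,
    return_dict=True,
    return_assistant_tokens_mask=True,
)
# Per the docstring above, out["assistant_masks"] should contain 1 on
# assistant-generated tokens and 0 on user/system tokens.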
@@ -315,6 +325,11 @@ class History(TensorClass["nocast"]):
            raise RuntimeError(
                "You must specify a tokenizer to use when chat_template is not specified."
            )
+        elif "qwen" in getattr(tokenizer, "name_or_path", "").lower():
+            # We prefer our implementation of the Qwen template,
+            # since it accounts for the assistant's masking.
+            chat_template = _CHAT_TEMPLATES["qwen"]
+            chat_template_name = None
        else:
            chat_template = tokenizer.chat_template
            if chat_template is None:
@@ -333,7 +348,7 @@ class History(TensorClass["nocast"]):
            return_dict = False

        if self.ndim > 1:
-
+            result = [
                self[i].apply_chat_template(
                    tokenizer=tokenizer,
                    add_generation_prompt=add_generation_prompt,
@@ -350,12 +365,16 @@ class History(TensorClass["nocast"]):
                )
                for i in range(self.batch_size[0])
            ]
+            if return_dict:
+                return lazy_stack(result)
+            else:
+                return result
        self_flat = self.view(-1)
        # tolist_first=True is needed to avoid having a list of dict of dicts, but a list of dicts of lists of dicts
        self_flat = self_flat.tolist(tolist_first=True)
        # Remove the "<none>" role
        self_flat = [item for item in self_flat if item["role"] != "<none>"]
-
+        result = tokenizer.apply_chat_template(
            conversation=self_flat,
            add_generation_prompt=add_generation_prompt,
            chat_template=chat_template,
@@ -368,6 +387,16 @@ class History(TensorClass["nocast"]):
            return_assistant_tokens_mask=return_assistant_tokens_mask,
            **kwargs,
        )
+        if not isinstance(result, (torch.Tensor, list, str)):
+            result = TensorDict.from_dict(result, auto_batch_size=True, batch_dims=1)
+            # If self has a batch_dims of 1, we have just the time dimension, so we need to remove the batch dim from the result
+            if self.batch_dims == 1:
+                if result.batch_size[0] != 1:
+                    raise RuntimeError(
+                        f"Expected a batch size of 1, got {result.batch_size[0]}."
+                    )
+                result = result.squeeze(0)
+        return result

    @classmethod
    def from_text(
@@ -375,10 +404,20 @@ class History(TensorClass["nocast"]):
        text: str | list[str],
        chat_template_name: Literal["chatml_format", "qwen"] | None = None,
        chat_template: str | None = None,
+        tokenizer: transformers.AutoTokenizer  # noqa: F821
+        | transformers.AutoProcessor  # noqa: F821
+        | None = None,
    ) -> History:
-        if chat_template_name
+        if chat_template_name is None and chat_template is None:
+            if "qwen" in getattr(tokenizer, "name_or_path", "").lower():
+                # We can automatically detect the template name from the tokenizer
+                # and use the precoded parser.
+                chat_template_name = "qwen"
+            else:
+                chat_template_name = "chatml_format"
+        elif chat_template_name in ("chatml_format",):
            func = cls._inv_chatml
-        elif chat_template_name
+        elif chat_template_name in ("qwen",):
            func = cls._inv_qwen
        else:
            raise NotImplementedError(
@@ -735,3 +774,15 @@ class History(TensorClass["nocast"]):
        }

        return Composite(defaults, shape=shape[:-1], data_cls=cls)
+
+    @classmethod
+    def from_chats(cls, chats: list[list[dict]]) -> History:
+        """Create a History object from a list of chats.
+
+        Args:
+            chats (list[list[dict]]): A list of chats, where each chat is a list of dictionaries.
+        """
+        if isinstance(chats[0], dict):
+            return lazy_stack([cls(**chat) for chat in chats])
+        else:
+            return lazy_stack([cls.from_chats(chat) for chat in chats])
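The new `from_chats` classmethod turns raw message dictionaries into a lazily stacked History. A hedged usage sketch (chat contents are made up):

from torchrl.data.llm.chat import History

chats = [
    [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ],
    [
        {"role": "user", "content": "Name a prime number."},
        {"role": "assistant", "content": "7"},
    ],
]
history = History.from_chats(chats)
# Each inner list is stacked into a History of turns, and the outer list is
# stacked on top, giving two leading dimensions (chats x turns).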
torchrl/data/llm/topk.py
ADDED
@@ -0,0 +1,186 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from collections import defaultdict, deque
+from typing import Any
+
+import torch
+from tensordict import NestedKey, TensorDictBase
+from torchrl._utils import logger as torchrl_logger
+from torchrl.envs.transforms import Transform
+
+
+class TopKRewardSelector(Transform):
+    """A replay-buffer transform that selects the top-k rewards for each prompt.
+
+    Args:
+        total_dialog_turns (int): Number of dialog turns to keep in memory for the top-k selection.
+        topk_size (int): Number of top-k rewards to select. Must be smaller than or equal to total_dialog_turns.
+        prompt_key (NestedKey): Key to the prompt in the tensordict. Defaults to "text".
+        rewards_key (NestedKey): Key to the rewards in the tensordict. Defaults to ("next", "reward").
+        done_key (NestedKey): Key to the done state in the tensordict. Defaults to ("next", "done").
+        verbose (bool): Whether to print verbose information. Defaults to `False`.
+
+    Example:
+        >>> from torchrl.data import ReplayBuffer, LazyStackStorage, SamplerWithoutReplacement
+        >>> from tensordict import TensorDict, lazy_stack
+        >>> import torch
+        >>> from torchrl.data.llm.topk import TopKRewardSelector
+        >>> # Create a replay buffer with 50 items, a sampler that samples without replacement, and a batch size of 5
+        >>> rb = ReplayBuffer(
+        ...     storage=LazyStackStorage(50),
+        ...     sampler=SamplerWithoutReplacement,
+        ...     batch_size=5,
+        ... )
+        >>> # Create a tensordict with 50 items, each with 10 dialog turns
+        >>> td = lazy_stack(
+        ...     [
+        ...         TensorDict(
+        ...             {
+        ...                 ("next", "done"): torch.full((1, 1), True),
+        ...                 # Reward for i+5 tokens
+        ...                 ("next", "reward"): torch.full((i + 5, 1), i),
+        ...                 # total of 10 dialogs per prompt
+        ...                 "text": f"Prompt {i // 5}",
+        ...             }
+        ...         )
+        ...         for i in range(50)
+        ...     ]
+        ... )
+        >>> # Create a top-k reward selector with 5 dialog turns and a top-k size of 3
+        >>> topk = TopKRewardSelector(total_dialog_turns=5, topk_size=3)
+        >>> rb.append_transform(topk)
+        >>> for _td in td.chunk(25):
+        ...     rb.extend(_td)
+        >>> # Only wrote top3 of 50 items in 10 groups of 5
+        >>> assert rb.write_count == 30
+        >>> assert len(rb) == 30
+        >>> r3 = rb[:3].get(("next", "reward"), as_padded_tensor=True).squeeze()
+        >>> # 0 and 1 are missing because they're not part of the top-k
+        >>> assert (
+        ...     r3 == torch.tensor(
+        ...         [
+        ...             [4, 4, 4, 4, 4, 4, 4, 4, 4],
+        ...             [3, 3, 3, 3, 3, 3, 3, 3, 0],
+        ...             [2, 2, 2, 2, 2, 2, 2, 0, 0],
+        ...         ]
+        ...     )
+        ... ).all()
+    """
+
+    def __init__(
+        self,
+        total_dialog_turns: int,
+        topk_size: int,
+        prompt_key: NestedKey = "text",
+        rewards_key: NestedKey = ("next", "reward"),
+        done_key: NestedKey = ("next", "done"),
+        verbose: bool = True,
+    ):
+        super().__init__()
+        self.in_keys = [prompt_key, rewards_key, done_key]
+        self.prompt_key = prompt_key
+        self.rewards_key = rewards_key
+        self.done_key = done_key
+        self.queues = defaultdict(lambda: deque(maxlen=total_dialog_turns))
+        self.total_dialog_turns = total_dialog_turns
+        self.topk_size = topk_size
+        if topk_size > total_dialog_turns:
+            raise ValueError(
+                f"topk_size must be smaller than or equal to total_dialog_turns, got {topk_size=} and {total_dialog_turns=}"
+            )
+        self.verbose = verbose
+
+    def forward(self, tensordict: TensorDictBase) -> Any:
+        return tensordict
+
+    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
+        # Tensordict can be any number of dims, but it must contain entire trajectories
+        if tensordict.ndim == 1:
+            # Check how many done states we have
+            num_done = tensordict[self.done_key].sum()
+            if num_done > 1:
+                done_idx = tensordict[self.done_key].nonzero(as_tuple=True)[0] + 1
+                splits = torch.cat([done_idx.new_zeros((1,)), done_idx], dim=0).diff()
+                tensordicts = tensordict.split(splits)
+                tensordicts = [self._inv_call(td) for td in tensordicts]
+                tensordicts = [td for td in tensordicts if td is not None]
+                return torch.cat(tensordicts) if tensordicts else None
+            # Then we have a single trajectory
+            if not tensordict[-1][self.done_key].all():
+                raise RuntimeError("Expected the trajectory to be done.")
+            prompt = tensordict[0][self.prompt_key]
+            if not isinstance(prompt, str):
+                raise TypeError(f"Expected a string as prompt, got {type(prompt)=}")
+            self.queues[prompt].append(tensordict)
+            if len(self.queues[prompt]) == self.total_dialog_turns:
+                if self.verbose:
+                    torchrl_logger.info(f"Getting top-k rewards for {prompt=}")
+                # Cat is the most robust way to combine the trajs
+                tds = torch.cat(list(self.queues[prompt]), -1)
+                # Collect rewards
+                reward = tds.get(self.rewards_key, as_nested_tensor=True)
+                reward = self._aggregate_rewards(reward)
+                # Check if all rewards are equal
+                if (reward == reward[0]).all():
+                    # If all rewards are equal, we can't select top-k
+                    if self.verbose:
+                        torchrl_logger.warning(
+                            f"All rewards are equal ({reward.unique()=})"
+                        )
+                    return
+                # Filter out rewards below median
+                median_reward = reward.median(dim=-1, keepdim=True)[0]
+                mask = reward > median_reward
+                filtered_reward = reward[mask]
+                filtered_indices = mask.nonzero(as_tuple=True)[0]
+                # Get top-k from filtered rewards
+                topk_reward = filtered_reward.topk(
+                    k=min(self.topk_size, len(filtered_indices)), dim=-1
+                )
+                if not topk_reward.indices.numel():
+                    if self.verbose:
+                        torchrl_logger.warning(
+                            f"No top-{self.topk_size} rewards found ({reward=})"
+                        )
+                    return
+                # Map back to original indices
+                selected_indices = filtered_indices[topk_reward.indices]
+                tds = tds[selected_indices]
+                if self.verbose:
+                    torchrl_logger.info(
+                        f"Selected top-{self.topk_size} rewards, with reward {topk_reward.values=}"
+                    )
+                return tds
+            return
+        elif tensordict.ndim > 2:
+            # keep the time dim at the end
+            tensordict = tensordict.flatten(0, -2)
+        trajs = tensordict.unbind(-1)
+        # Iterate over the trajectories
+        result = []
+        for traj in trajs:
+            td_out = self._inv_call(traj)
+            if td_out is None:
+                continue
+            result.append(td_out)
+        if result:
+            return torch.cat(result, -1)
+        return
+
+    def _aggregate_rewards(self, reward: torch.Tensor) -> torch.Tensor:
+        """Aggregate the rewards across the dialog turns.
+
+        `reward` is expected to be a nested tensor.
+
+        The default implementation is to take the mean of the rewards across the dialog turns.
+        """
+        # reward = reward.to_padded_tensor(padding=0.0)
+        if reward.ndim < 2 or reward.ndim > 3:
+            raise ValueError(
+                f"Expected reward to be a 2D or 3D tensor, got {reward.ndim}D tensor"
+            )
+        return reward.mean(dim=-2).squeeze(-1)
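Because `_aggregate_rewards` is an ordinary method, the per-trajectory score used for the top-k selection can be swapped by subclassing. A hedged sketch (the subclass is illustrative, not part of the package, and assumes the same nested-tensor layout the default implementation receives):

import torch
from torchrl.data.llm.topk import TopKRewardSelector

class MedianRewardTopK(TopKRewardSelector):
    """Ranks trajectories by their median per-token reward instead of the mean."""

    def _aggregate_rewards(self, reward: torch.Tensor) -> torch.Tensor:
        # Pad the ragged per-token rewards before reducing; padding with 0.0 is a
        # simplification that biases the statistic for short trajectories.
        reward = reward.to_padded_tensor(padding=0.0)
        return reward.median(dim=-2).values.squeeze(-1)

selector = MedianRewardTopK(total_dialog_turns=8, topk_size=4)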
torchrl/data/replay_buffers/ray_buffer.py
CHANGED
@@ -54,9 +54,12 @@ class RayReplayBuffer(ReplayBuffer):
    """A Ray implementation of the Replay Buffer that can be extended and sampled remotely.

    Keyword Args:
+        replay_buffer_cls (type[ReplayBuffer], optional): the class to use for the replay buffer.
+            Defaults to :class:`~torchrl.data.ReplayBuffer`.
        ray_init_config (dict[str, Any], optiona): keyword arguments to pass to `ray.init()`.
        remote_config (dict[str, Any], optiona): keyword arguments to pass to `cls.as_remote()`.
            Defaults to `torchrl.collectors.distributed.ray.DEFAULT_REMOTE_CLASS_CONFIG`.
+        **kwargs: keyword arguments to pass to the replay buffer class.

    .. seealso:: :class:`~torchrl.data.ReplayBuffer` for a list of other keyword arguments.

@@ -119,6 +122,7 @@ class RayReplayBuffer(ReplayBuffer):
    def __init__(
        self,
        *args,
+        replay_buffer_cls: type[ReplayBuffer] | None = ReplayBuffer,
        ray_init_config: dict[str, Any] | None = None,
        remote_config: dict[str, Any] | None = None,
        **kwargs,
@@ -134,7 +138,13 @@ class RayReplayBuffer(ReplayBuffer):
            ray_init_config = DEFAULT_RAY_INIT_CONFIG
            ray.init(**ray_init_config)

-        remote_cls =
+        remote_cls = replay_buffer_cls.as_remote(remote_config).remote
+        # We can detect if the buffer has a GPU allocated, if not
+        # we'll make sure that the data is sent to CPU when needed.
+        if remote_config is not None:
+            self.has_gpu = remote_config.get("num_gpus", 0) > 0
+        else:
+            self.has_gpu = False
        self._rb = remote_cls(*args, **kwargs)

    def close(self):
@@ -158,6 +168,10 @@ class RayReplayBuffer(ReplayBuffer):
        return ray.get(pending_task)

    def extend(self, *args, **kwargs):
+        if not self.has_gpu:
+            # Move the data to GPU
+            args = [arg.to("cpu") for arg in args if hasattr(arg, "to")]
+            kwargs = {k: v.to("cpu") for k, v in kwargs.items() if hasattr(v, "to")}
        pending_task = self._rb.extend.remote(*args, **kwargs)
        return ray.get(pending_task)
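The new `replay_buffer_cls` keyword lets the remote actor host a ReplayBuffer subclass instead of the base class, with the remaining keyword arguments forwarded to that class. A hedged construction sketch (storage and batch sizes are illustrative; a running Ray installation is assumed):

from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.ray_buffer import RayReplayBuffer

rb = RayReplayBuffer(
    replay_buffer_cls=TensorDictReplayBuffer,  # remote buffer type, new in this release
    storage=LazyTensorStorage(10_000),         # forwarded to TensorDictReplayBuffer
    batch_size=32,
)
# Without a remote_config requesting GPUs, has_gpu is False and extend() moves
# incoming tensordicts to CPU before shipping them to the remote actor.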
torchrl/data/replay_buffers/replay_buffers.py
CHANGED
@@ -702,7 +702,7 @@ class ReplayBuffer:
            self._sampler.add(index)
        return index

-    def _extend(self, data: Sequence) -> torch.Tensor:
+    def _extend(self, data: Sequence, *, update_priority: bool = True) -> torch.Tensor:
        is_comp = is_compiling()
        nc = contextlib.nullcontext()
        with self._replay_lock if not is_comp else nc, self._write_lock if not is_comp else nc:
@@ -712,7 +712,9 @@ class ReplayBuffer:
            self._sampler.extend(index)
            return index

-    def extend(
+    def extend(
+        self, data: Sequence, *, update_priority: bool | None = None
+    ) -> torch.Tensor:
        """Extends the replay buffer with one or more elements contained in an iterable.

        If present, the inverse transforms will be called.`
@@ -721,6 +723,10 @@ class ReplayBuffer:
            data (iterable): collection of data to be added to the replay
                buffer.

+        Keyword Args:
+            update_priority (bool, optional): Whether to update the priority of the data. Defaults to True.
+                Without effect in this class. See :meth:`~torchrl.data.TensorDictReplayBuffer.extend` for more details.
+
        Returns:
            Indices of the data added to the replay buffer.

@@ -735,12 +741,16 @@ class ReplayBuffer:
            unbound elements can be provided (no PyTrees).

        """
+        if update_priority is not None:
+            raise NotImplementedError(
+                "update_priority is not supported in this class. See :meth:`~torchrl.data.TensorDictReplayBuffer.extend` for more details."
+            )
        if self._transform is not None and len(self._transform):
            with _set_dispatch_td_nn_modules(is_tensor_collection(data)):
                data = self._transform.inv(data)
        if data is None:
            return torch.zeros((0, self._storage.ndim), dtype=torch.long)
-        return self._extend(data)
+        return self._extend(data, update_priority=update_priority)

    def update_priority(
        self,
@@ -914,8 +924,8 @@ class ReplayBuffer:
            self._iterator = iter(self)
            out = next(self._iterator)
            # if any, we don't want the device ref to be passed in distributed settings
-            if out is not None:
-                out.clear_device_()
+            if out is not None and (out.device != "cpu"):
+                out = out.copy().clear_device_()
            return out
        except StopIteration:
            self._iterator = None
@@ -1015,6 +1025,9 @@ class PrioritizedReplayBuffer(ReplayBuffer):
        storage (Storage, optional): the storage to be used. If none is provided
            a default :class:`~torchrl.data.replay_buffers.ListStorage` with
            ``max_size`` of ``1_000`` will be created.
+        sampler (Sampler, optional): the sampler to be used. If none is provided,
+            a default :class:`~torchrl.data.replay_buffers.PrioritizedSampler` with
+            ``alpha``, ``beta``, and ``eps`` will be created.
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s)/outputs. Used when using batched
            loading from a map-style dataset. The default value will be decided
@@ -1107,6 +1120,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
        eps: float = 1e-8,
        dtype: torch.dtype = torch.float,
        storage: Storage | None = None,
+        sampler: Sampler | None = None,
        collate_fn: Callable | None = None,
        pin_memory: bool = False,
        prefetch: int | None = None,
@@ -1116,7 +1130,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
    ) -> None:
        if storage is None:
            storage = ListStorage(max_size=1_000)
-        sampler
+        if sampler is None:
+            sampler = PrioritizedSampler(storage.max_size, alpha, beta, eps, dtype)
        super().__init__(
            storage=storage,
            sampler=sampler,
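With the new `sampler` argument, PrioritizedReplayBuffer only builds its default PrioritizedSampler when none is supplied. A hedged sketch passing an explicitly configured sampler (values illustrative):

from torchrl.data import PrioritizedReplayBuffer
from torchrl.data.replay_buffers import ListStorage, PrioritizedSampler

storage = ListStorage(max_size=1_000)
sampler = PrioritizedSampler(max_capacity=1_000, alpha=0.8, beta=1.0, eps=1e-6)
rb = PrioritizedReplayBuffer(
    alpha=0.7,        # only used to build the default sampler, unused here
    beta=0.9,
    storage=storage,
    sampler=sampler,  # new in this release; replaces the default PrioritizedSampler
)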
@@ -1347,7 +1362,20 @@ class TensorDictReplayBuffer(ReplayBuffer):
            self.update_tensordict_priority(data)
        return index

-    def extend(
+    def extend(
+        self, tensordicts: TensorDictBase, *, update_priority: bool | None = None
+    ) -> torch.Tensor:
+        """Extends the replay buffer with a batch of data.
+
+        Args:
+            tensordicts (TensorDictBase): The data to extend the replay buffer with.
+
+        Keyword Args:
+            update_priority (bool, optional): Whether to update the priority of the data. Defaults to True.
+
+        Returns:
+            The indices of the data that were added to the replay buffer.
+        """
        if not isinstance(tensordicts, TensorDictBase):
            raise ValueError(
                f"{self.__class__.__name__} only accepts TensorDictBase subclasses. tensorclasses "
@@ -1365,8 +1393,17 @@ class TensorDictReplayBuffer(ReplayBuffer):
        # is that just doing this results in indices that are not sorted like the original data
        # so the actually indices will have to be used on the _storage directly (not on the buffer)
        self._set_index_in_td(tensordicts, index)
-
-
+        if update_priority is None:
+            update_priority = True
+        if update_priority:
+            try:
+                vector = tensordicts.get(self.priority_key)
+                if vector is not None:
+                    self.update_priority(index, vector)
+            except Exception as e:
+                raise RuntimeError(
+                    "Failed to update priority of extended data. You can try to set update_priority=False in the extend method and update the priority manually."
+                ) from e
        return index

    def _set_index_in_td(self, tensordict, index):
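The keyword added to TensorDictReplayBuffer.extend makes the automatic priority write optional, which is the fallback the error message above suggests. A hedged sketch of both paths (keys and sizes illustrative):

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers import PrioritizedSampler

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(100),
    sampler=PrioritizedSampler(max_capacity=100, alpha=0.7, beta=0.9),
    priority_key="td_error",
    batch_size=8,
)
data = TensorDict(
    {"obs": torch.randn(10, 4), "td_error": torch.rand(10)},
    batch_size=[10],
)
# Default behaviour (update_priority=None resolves to True): priorities are read
# from priority_key and written during extend().
idx = rb.extend(data)
# Opt out and update manually afterwards, e.g. when the automatic write fails.
idx = rb.extend(data, update_priority=False)
rb.update_priority(idx, data["td_error"])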
@@ -1685,8 +1722,10 @@ class RemoteTensorDictReplayBuffer(TensorDictReplayBuffer):
    def add(self, data: TensorDictBase) -> int:
        return super().add(data)

-    def extend(
-
+    def extend(
+        self, tensordicts: list | TensorDictBase, *, update_priority: bool | None = None
+    ) -> torch.Tensor:
+        return super().extend(tensordicts, update_priority=update_priority)

    def update_priority(
        self, index: int | torch.Tensor, priority: int | torch.Tensor
|