torchrl-nightly 2025.6.20__cp310-cp310-macosx_10_9_universal2.whl → 2025.6.21__cp310-cp310-macosx_10_9_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -352,8 +352,8 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):
  self._iterator = iter(self)
  out = next(self._iterator)
  # if any, we don't want the device ref to be passed in distributed settings
- if out is not None:
- out.clear_device_()
+ if out is not None and (out.device != "cpu"):
+ out = out.copy().clear_device_()
  return out
  except StopIteration:
  return None
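
Note on the hunk above: `clear_device_()` operates in place, so clearing the device on the yielded tensordict would also clear it on any internal buffer sharing that tensordict; the new code copies the structure first and only does so when the output is not already on CPU. A minimal sketch of the pattern with the tensordict API alone (collector internals elided):

    import torch
    from tensordict import TensorDict

    device = "cuda" if torch.cuda.is_available() else "cpu"
    buf = TensorDict({"obs": torch.zeros(4, 3)}, batch_size=[4], device=device)

    out = buf
    if out is not None and out.device is not None and out.device != torch.device("cpu"):
        # copy() duplicates the tensordict structure without cloning the tensors,
        # so clearing the device on the copy leaves buf.device untouched
        out = out.copy().clear_device_()
    print(buf.device, out.device)
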
@@ -892,7 +892,10 @@ class SyncDataCollector(DataCollectorBase):
  and hasattr(self.postproc, "to")
  and self.storing_device
  ):
- self.postproc.to(self.storing_device)
+ postproc = self.postproc.to(self.storing_device)
+ if postproc is not self.postproc and postproc is not None:
+ self.postproc = postproc
+
  if frames_per_batch % self.n_env != 0 and RL_WARNINGS:
  warnings.warn(
  f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), "
@@ -1253,9 +1256,9 @@ class SyncDataCollector(DataCollectorBase):
  yield
  continue
  self._increment_frames(tensordict_out.numel())
- if self.verbose:
- torchrl_logger.info("Collector: postproc.")
  tensordict_out = self._postproc(tensordict_out)
+ if self.verbose:
+ torchrl_logger.info("Collector: postproc done.")
  if self.return_same_td:
  # This is used with multiprocessed collectors to use the buffers
  # stored in the tensordict.
@@ -242,6 +242,11 @@ class LLMCollector(SyncDataCollector):
  else:
  self.policy_version_tracker = None

+ def set_postproc(self, postproc: Callable[[TensorDictBase], TensorDictBase]):
+ if self.postproc is not None:
+ raise RuntimeError("Postproc already set")
+ self.postproc = postproc
+
  def increment_version(self):
  """Increment the policy version."""
  if self.policy_version_tracker is not None:
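
A hedged usage sketch for the new `set_postproc` hook (collector construction is elided because its arguments depend on the environment and policy at hand; the calls below are illustrative only):

    import torch
    from tensordict import TensorDictBase

    def tag_batch(td: TensorDictBase) -> TensorDictBase:
        # illustrative postproc: stamp each collected batch with a boolean flag
        return td.set("postprocessed", torch.ones(td.shape, dtype=torch.bool))

    # collector = LLMCollector(...)               # hypothetical construction
    # collector.set_postproc(tag_batch)           # raises RuntimeError if a postproc already exists
    # ray_collector.set_postproc(tag_batch)       # RayLLMCollector forwards the call to the remote collector
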
@@ -361,9 +366,10 @@ class LLMCollector(SyncDataCollector):
  )
  self._yield_queues[idx].clear()
  result = self._trajectory_queue.popleft()
- torchrl_logger.info(
- f"LLMCollector: Yielding completed trajectory with shape {result.shape}."
- )
+ if self.verbose:
+ torchrl_logger.info(
+ f"LLMCollector: Yielding completed trajectory with shape {result.shape}."
+ )
  return result

  started = False
@@ -422,9 +428,10 @@ class LLMCollector(SyncDataCollector):
  self.env.async_step_and_maybe_reset_send(env_input)

  result = self._trajectory_queue.popleft()
- torchrl_logger.info(
- f"LLMCollector: Yielding completed trajectory with shape {result.shape}."
- )
+ if self.verbose:
+ torchrl_logger.info(
+ f"LLMCollector: Yielding completed trajectory with shape {result.shape}."
+ )
  return result

  as_remote = as_remote
@@ -134,6 +134,9 @@ class RayLLMCollector(LLMCollector):
  verbose=verbose,
  )

+ def set_postproc(self, postproc: Callable[[TensorDictBase], TensorDictBase]):
+ return ray.get(self._collector.set_postproc.remote(postproc))
+
  def _next_remote(self) -> None:
  return self._collector.next.remote()

torchrl/data/__init__.py CHANGED
@@ -17,6 +17,7 @@ from .llm import (
  RolloutFromModel,
  TensorDictTokenizer,
  TokenizedDatasetLoader,
+ TopKRewardSelector,
  )
  from .map import (
  BinaryToDecimal,
@@ -116,6 +117,7 @@ __all__ = [
  "Categorical",
  "Choice",
  "ContentBase",
+ "TopKRewardSelector",
  "Composite",
  "CompositeSpec",
  "ConstantKLController",
torchrl/data/llm/__init__.py CHANGED
@@ -13,6 +13,7 @@ from .dataset import (
  )
  from .prompt import PromptData, PromptTensorDictTokenizer
  from .reward import PairwiseDataset, RewardData
+ from .topk import TopKRewardSelector
  from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel

  __all__ = [
@@ -30,4 +31,5 @@ __all__ = [
  "TokenizedDatasetLoader",
  "create_infinite_iterator",
  "get_dataloader",
+ "TopKRewardSelector",
  ]
torchrl/data/llm/chat.py CHANGED
@@ -11,18 +11,27 @@ from typing import Literal

  import torch

- from tensordict import lazy_stack, LazyStackedTensorDict, list_to_stack, TensorClass
+ from tensordict import (
+ lazy_stack,
+ LazyStackedTensorDict,
+ list_to_stack,
+ TensorClass,
+ TensorDict,
+ )
  from tensordict.utils import _maybe_correct_neg_dim
-
  from torchrl._utils import logger as torchrl_logger


  _CHAT_TEMPLATES = {
  "chatml_format": """{% for message in messages %}
+ {%- if message['role'] == 'assistant' %}
+ {% generation %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endgeneration %}
+ {%- else %}
  {{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
+ {%- endif %}
  {% endfor %}
  {%- if add_generation_prompt %}
- {{- '<|im_start|>assistant\n' }}
+ {% generation %}{{- '<|im_start|>assistant\n' }}{% endgeneration %}
  {%- endif %}
  """,
  "qwen": """
@@ -282,7 +291,7 @@ class History(TensorClass["nocast"]):

  Keyword Args:
  tokenizer (transformers.PreTrainedTokenizer | transformers.AutoProcessor): The tokenizer to use.
- add_generation_prompt (bool, optional): Whether to add a generation prompt. Defaults to `True`.
+ add_generation_prompt (bool, optional): Whether to add a generation prompt (e.g. `"<|im_start|>assistant"`). Defaults to `True`.
  chat_template (str, optional): The chat template to use. Defaults to the tokenizer's default template.
  chat_template_name (Literal["chatml_format", "qwen"], optional): The name of the chat template to use.
  Prevalent over `tokenizer.chat_template`. Defaults to `None`.
@@ -293,6 +302,7 @@ class History(TensorClass["nocast"]):
  return_tensors (str | None, optional): The type of tensors to return. Defaults to "pt".
  return_dict (bool, optional): Whether to return a dictionary. Defaults to `False`.
  return_assistant_tokens_mask (bool, optional): Whether to return a mask of the assistant generated tokens.
+ If `True`, the mask will be written to the `assistant_masks` key.
  For tokens generated by the assistant, the mask will contain `1`.
  For user and system tokens, the mask will contain `0`.
  This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
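
With the `{% generation %}` blocks added to the chatml template above, `return_assistant_tokens_mask=True` can flag which tokens were produced by the assistant. A hedged, commented-out sketch (it needs a downloaded tokenizer; the model name is illustrative and the keys follow the docstring above):

    # from transformers import AutoTokenizer
    # from torchrl.data.llm.chat import History
    #
    # tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative checkpoint
    # history = History.from_chats([[
    #     {"role": "user", "content": "Hi!"},
    #     {"role": "assistant", "content": "Hello!"},
    # ]])
    # out = history.apply_chat_template(
    #     tokenizer=tokenizer,
    #     chat_template_name="chatml_format",
    #     add_generation_prompt=False,
    #     return_dict=True,
    #     return_assistant_tokens_mask=True,
    # )
    # # per the docstring: assistant tokens are flagged with 1, user/system tokens with 0
    # assert out["assistant_masks"].any()
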
@@ -315,6 +325,11 @@ class History(TensorClass["nocast"]):
  raise RuntimeError(
  "You must specify a tokenizer to use when chat_template is not specified."
  )
+ elif "qwen" in getattr(tokenizer, "name_or_path", "").lower():
+ # We prefer our implementation of the Qwen template,
+ # since it accounts for the assistant's masking.
+ chat_template = _CHAT_TEMPLATES["qwen"]
+ chat_template_name = None
  else:
  chat_template = tokenizer.chat_template
  if chat_template is None:
@@ -333,7 +348,7 @@ class History(TensorClass["nocast"]):
  return_dict = False

  if self.ndim > 1:
- return [
+ result = [
  self[i].apply_chat_template(
  tokenizer=tokenizer,
  add_generation_prompt=add_generation_prompt,
@@ -350,12 +365,16 @@ class History(TensorClass["nocast"]):
  )
  for i in range(self.batch_size[0])
  ]
+ if return_dict:
+ return lazy_stack(result)
+ else:
+ return result
  self_flat = self.view(-1)
  # tolist_first=True is needed to avoid having a list of dict of dicts, but a list of dicts of lists of dicts
  self_flat = self_flat.tolist(tolist_first=True)
  # Remove the "<none>" role
  self_flat = [item for item in self_flat if item["role"] != "<none>"]
- return tokenizer.apply_chat_template(
+ result = tokenizer.apply_chat_template(
  conversation=self_flat,
  add_generation_prompt=add_generation_prompt,
  chat_template=chat_template,
@@ -368,6 +387,16 @@ class History(TensorClass["nocast"]):
  return_assistant_tokens_mask=return_assistant_tokens_mask,
  **kwargs,
  )
+ if not isinstance(result, (torch.Tensor, list, str)):
+ result = TensorDict.from_dict(result, auto_batch_size=True, batch_dims=1)
+ # If self has a batch_dims of 1, we have just the time dimension, so we need to remove the batch dim from the result
+ if self.batch_dims == 1:
+ if result.batch_size[0] != 1:
+ raise RuntimeError(
+ f"Expected a batch size of 1, got {result.batch_size[0]}."
+ )
+ result = result.squeeze(0)
+ return result

  @classmethod
  def from_text(
@@ -375,10 +404,20 @@ class History(TensorClass["nocast"]):
  text: str | list[str],
  chat_template_name: Literal["chatml_format", "qwen"] | None = None,
  chat_template: str | None = None,
+ tokenizer: transformers.AutoTokenizer # noqa: F821
+ | transformers.AutoProcessor # noqa: F821
+ | None = None,
  ) -> History:
- if chat_template_name in ("chatml_format", None):
+ if chat_template_name is None and chat_template is None:
+ if "qwen" in getattr(tokenizer, "name_or_path", "").lower():
+ # We can automatically detect the template name from the tokenizer
+ # and use the precoded parser.
+ chat_template_name = "qwen"
+ else:
+ chat_template_name = "chatml_format"
+ elif chat_template_name in ("chatml_format",):
  func = cls._inv_chatml
- elif chat_template_name == "qwen":
+ elif chat_template_name in ("qwen",):
  func = cls._inv_qwen
  else:
  raise NotImplementedError(
@@ -735,3 +774,15 @@ class History(TensorClass["nocast"]):
  }

  return Composite(defaults, shape=shape[:-1], data_cls=cls)
+
+ @classmethod
+ def from_chats(cls, chats: list[list[dict]]) -> History:
+ """Create a History object from a list of chats.
+
+ Args:
+ chats (list[list[dict]]): A list of chats, where each chat is a list of dictionaries.
+ """
+ if isinstance(chats[0], dict):
+ return lazy_stack([cls(**chat) for chat in chats])
+ else:
+ return lazy_stack([cls.from_chats(chat) for chat in chats])
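
A small sketch of the new `from_chats` classmethod (import path taken from the file headers above):

    from torchrl.data.llm.chat import History

    chats = [
        [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi!"}],
        [{"role": "user", "content": "Bye"}, {"role": "assistant", "content": "See you!"}],
    ]
    # nested lists are stacked recursively: a batch of 2 conversations, 2 messages each
    history = History.from_chats(chats)
    print(history.shape)
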
torchrl/data/llm/topk.py ADDED
@@ -0,0 +1,186 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ from collections import defaultdict, deque
+ from typing import Any
+
+ import torch
+ from tensordict import NestedKey, TensorDictBase
+ from torchrl._utils import logger as torchrl_logger
+ from torchrl.envs.transforms import Transform
+
+
+ class TopKRewardSelector(Transform):
+ """A replay-buffer transform that selects the top-k rewards for each prompt.
+
+ Args:
+ total_dialog_turns (int): Number of dialog turns to keep in memory for the top-k selection.
+ topk_size (int): Number of top-k rewards to select. Must be smaller than or equal to total_dialog_turns.
+ prompt_key (NestedKey): Key to the prompt in the tensordict. Defaults to "text".
+ rewards_key (NestedKey): Key to the rewards in the tensordict. Defaults to ("next", "reward").
+ done_key (NestedKey): Key to the done state in the tensordict. Defaults to ("next", "done").
+ verbose (bool): Whether to print verbose information. Defaults to `False`.
+
+ Example:
+ >>> from torchrl.data import ReplayBuffer, LazyStackStorage, SamplerWithoutReplacement
+ >>> from tensordict import TensorDict, lazy_stack
+ >>> import torch
+ >>> from torchrl.data.llm.topk import TopKRewardSelector
+ >>> # Create a replay buffer with 50 items, a sampler that samples without replacement, and a batch size of 5
+ >>> rb = ReplayBuffer(
+ ... storage=LazyStackStorage(50),
+ ... sampler=SamplerWithoutReplacement,
+ ... batch_size=5,
+ ... )
+ >>> # Create a tensordict with 50 items, each with 10 dialog turns
+ >>> td = lazy_stack(
+ ... [
+ ... TensorDict(
+ ... {
+ ... ("next", "done"): torch.full((1, 1), True),
+ ... # Reward for i+5 tokens
+ ... ("next", "reward"): torch.full((i + 5, 1), i),
+ ... # total of 10 dialogs per prompt
+ ... "text": f"Prompt {i // 5}",
+ ... }
+ ... )
+ ... for i in range(50)
+ ... ]
+ ... )
+ >>> # Create a top-k reward selector with 5 dialog turns and a top-k size of 3
+ >>> topk = TopKRewardSelector(total_dialog_turns=5, topk_size=3)
+ >>> rb.append_transform(topk)
+ >>> for _td in td.chunk(25):
+ ... rb.extend(_td)
+ >>> # Only wrote top3 of 50 items in 10 groups of 5
+ >>> assert rb.write_count == 30
+ >>> assert len(rb) == 30
+ >>> r3 = rb[:3].get(("next", "reward"), as_padded_tensor=True).squeeze()
+ >>> # 0 and 1 are missing because they're not part of the top-k
+ >>> assert (
+ ... r3 == torch.tensor(
+ ... [
+ ... [4, 4, 4, 4, 4, 4, 4, 4, 4],
+ ... [3, 3, 3, 3, 3, 3, 3, 3, 0],
+ ... [2, 2, 2, 2, 2, 2, 2, 0, 0],
+ ... ]
+ ... )
+ ... ).all()
+ """
+
+ def __init__(
+ self,
+ total_dialog_turns: int,
+ topk_size: int,
+ prompt_key: NestedKey = "text",
+ rewards_key: NestedKey = ("next", "reward"),
+ done_key: NestedKey = ("next", "done"),
+ verbose: bool = True,
+ ):
+ super().__init__()
+ self.in_keys = [prompt_key, rewards_key, done_key]
+ self.prompt_key = prompt_key
+ self.rewards_key = rewards_key
+ self.done_key = done_key
+ self.queues = defaultdict(lambda: deque(maxlen=total_dialog_turns))
+ self.total_dialog_turns = total_dialog_turns
+ self.topk_size = topk_size
+ if topk_size > total_dialog_turns:
+ raise ValueError(
+ f"topk_size must be smaller than or equal to total_dialog_turns, got {topk_size=} and {total_dialog_turns=}"
+ )
+ self.verbose = verbose
+
+ def forward(self, tensordict: TensorDictBase) -> Any:
+ return tensordict
+
+ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
+ # Tensordict can be any number of dims, but it must contain entire trajectories
+ if tensordict.ndim == 1:
+ # Check how many done states we have
+ num_done = tensordict[self.done_key].sum()
+ if num_done > 1:
+ done_idx = tensordict[self.done_key].nonzero(as_tuple=True)[0] + 1
+ splits = torch.cat([done_idx.new_zeros((1,)), done_idx], dim=0).diff()
+ tensordicts = tensordict.split(splits)
+ tensordicts = [self._inv_call(td) for td in tensordicts]
+ tensordicts = [td for td in tensordicts if td is not None]
+ return torch.cat(tensordicts) if tensordicts else None
+ # Then we have a single trajectory
+ if not tensordict[-1][self.done_key].all():
+ raise RuntimeError("Expected the trajectory to be done.")
+ prompt = tensordict[0][self.prompt_key]
+ if not isinstance(prompt, str):
+ raise TypeError(f"Expected a string as prompt, got {type(prompt)=}")
+ self.queues[prompt].append(tensordict)
+ if len(self.queues[prompt]) == self.total_dialog_turns:
+ if self.verbose:
+ torchrl_logger.info(f"Getting top-k rewards for {prompt=}")
+ # Cat is the most robust way to combine the trajs
+ tds = torch.cat(list(self.queues[prompt]), -1)
+ # Collect rewards
+ reward = tds.get(self.rewards_key, as_nested_tensor=True)
+ reward = self._aggregate_rewards(reward)
+ # Check if all rewards are equal
+ if (reward == reward[0]).all():
+ # If all rewards are equal, we can't select top-k
+ if self.verbose:
+ torchrl_logger.warning(
+ f"All rewards are equal ({reward.unique()=})"
+ )
+ return
+ # Filter out rewards below median
+ median_reward = reward.median(dim=-1, keepdim=True)[0]
+ mask = reward > median_reward
+ filtered_reward = reward[mask]
+ filtered_indices = mask.nonzero(as_tuple=True)[0]
+ # Get top-k from filtered rewards
+ topk_reward = filtered_reward.topk(
+ k=min(self.topk_size, len(filtered_indices)), dim=-1
+ )
+ if not topk_reward.indices.numel():
+ if self.verbose:
+ torchrl_logger.warning(
+ f"No top-{self.topk_size} rewards found ({reward=})"
+ )
+ return
+ # Map back to original indices
+ selected_indices = filtered_indices[topk_reward.indices]
+ tds = tds[selected_indices]
+ if self.verbose:
+ torchrl_logger.info(
+ f"Selected top-{self.topk_size} rewards, with reward {topk_reward.values=}"
+ )
+ return tds
+ return
+ elif tensordict.ndim > 2:
+ # keep the time dim at the end
+ tensordict = tensordict.flatten(0, -2)
+ trajs = tensordict.unbind(-1)
+ # Iterate over the trajectories
+ result = []
+ for traj in trajs:
+ td_out = self._inv_call(traj)
+ if td_out is None:
+ continue
+ result.append(td_out)
+ if result:
+ return torch.cat(result, -1)
+ return
+
+ def _aggregate_rewards(self, reward: torch.Tensor) -> torch.Tensor:
+ """Aggregate the rewards across the dialog turns.
+
+ `reward` is expected to be a nested tensor.
+
+ The default implementation is to take the mean of the rewards across the dialog turns.
+ """
+ # reward = reward.to_padded_tensor(padding=0.0)
+ if reward.ndim < 2 or reward.ndim > 3:
+ raise ValueError(
+ f"Expected reward to be a 2D or 3D tensor, got {reward.ndim}D tensor"
+ )
+ return reward.mean(dim=-2).squeeze(-1)
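
Because the aggregation is isolated in `_aggregate_rewards`, a variant of the transform only needs to override that single method. A hedged sketch, assuming the same nested-tensor reward layout ([dialogs, tokens, 1]) that the default implementation documents:

    import torch
    from torchrl.data.llm.topk import TopKRewardSelector

    class TopKSummedRewardSelector(TopKRewardSelector):
        """Rank dialogs by their total reward instead of the per-token mean."""

        def _aggregate_rewards(self, reward: torch.Tensor) -> torch.Tensor:
            # mirrors the default implementation, with sum in place of mean
            return reward.sum(dim=-2).squeeze(-1)
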
@@ -54,9 +54,12 @@ class RayReplayBuffer(ReplayBuffer):
  """A Ray implementation of the Replay Buffer that can be extended and sampled remotely.

  Keyword Args:
+ replay_buffer_cls (type[ReplayBuffer], optional): the class to use for the replay buffer.
+ Defaults to :class:`~torchrl.data.ReplayBuffer`.
  ray_init_config (dict[str, Any], optiona): keyword arguments to pass to `ray.init()`.
  remote_config (dict[str, Any], optiona): keyword arguments to pass to `cls.as_remote()`.
  Defaults to `torchrl.collectors.distributed.ray.DEFAULT_REMOTE_CLASS_CONFIG`.
+ **kwargs: keyword arguments to pass to the replay buffer class.

  .. seealso:: :class:`~torchrl.data.ReplayBuffer` for a list of other keyword arguments.

@@ -119,6 +122,7 @@ class RayReplayBuffer(ReplayBuffer):
  def __init__(
  self,
  *args,
+ replay_buffer_cls: type[ReplayBuffer] | None = ReplayBuffer,
  ray_init_config: dict[str, Any] | None = None,
  remote_config: dict[str, Any] | None = None,
  **kwargs,
@@ -134,7 +138,13 @@ class RayReplayBuffer(ReplayBuffer):
  ray_init_config = DEFAULT_RAY_INIT_CONFIG
  ray.init(**ray_init_config)

- remote_cls = ReplayBuffer.as_remote(remote_config).remote
+ remote_cls = replay_buffer_cls.as_remote(remote_config).remote
+ # We can detect if the buffer has a GPU allocated, if not
+ # we'll make sure that the data is sent to CPU when needed.
+ if remote_config is not None:
+ self.has_gpu = remote_config.get("num_gpus", 0) > 0
+ else:
+ self.has_gpu = False
  self._rb = remote_cls(*args, **kwargs)

  def close(self):
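
A hedged construction sketch for the new `replay_buffer_cls` argument (commented out because it needs a running Ray cluster; the storage and remote settings are illustrative):

    # from torchrl.data import LazyStackStorage, RayReplayBuffer, TensorDictReplayBuffer
    #
    # rb = RayReplayBuffer(
    #     storage=LazyStackStorage(10_000),
    #     replay_buffer_cls=TensorDictReplayBuffer,      # remote buffer class; previously fixed to ReplayBuffer
    #     remote_config={"num_cpus": 1, "num_gpus": 0},  # num_gpus=0, so extend() will move data to CPU first
    # )
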
@@ -158,6 +168,10 @@ class RayReplayBuffer(ReplayBuffer):
  return ray.get(pending_task)

  def extend(self, *args, **kwargs):
+ if not self.has_gpu:
+ # Move the data to GPU
+ args = [arg.to("cpu") for arg in args if hasattr(arg, "to")]
+ kwargs = {k: v.to("cpu") for k, v in kwargs.items() if hasattr(v, "to")}
  pending_task = self._rb.extend.remote(*args, **kwargs)
  return ray.get(pending_task)

@@ -702,7 +702,7 @@ class ReplayBuffer:
  self._sampler.add(index)
  return index

- def _extend(self, data: Sequence) -> torch.Tensor:
+ def _extend(self, data: Sequence, *, update_priority: bool = True) -> torch.Tensor:
  is_comp = is_compiling()
  nc = contextlib.nullcontext()
  with self._replay_lock if not is_comp else nc, self._write_lock if not is_comp else nc:
@@ -712,7 +712,9 @@ class ReplayBuffer:
  self._sampler.extend(index)
  return index

- def extend(self, data: Sequence) -> torch.Tensor:
+ def extend(
+ self, data: Sequence, *, update_priority: bool | None = None
+ ) -> torch.Tensor:
  """Extends the replay buffer with one or more elements contained in an iterable.

  If present, the inverse transforms will be called.`
@@ -721,6 +723,10 @@ class ReplayBuffer:
  data (iterable): collection of data to be added to the replay
  buffer.

+ Keyword Args:
+ update_priority (bool, optional): Whether to update the priority of the data. Defaults to True.
+ Without effect in this class. See :meth:`~torchrl.data.TensorDictReplayBuffer.extend` for more details.
+
  Returns:
  Indices of the data added to the replay buffer.

@@ -735,12 +741,16 @@ class ReplayBuffer:
  unbound elements can be provided (no PyTrees).

  """
+ if update_priority is not None:
+ raise NotImplementedError(
+ "update_priority is not supported in this class. See :meth:`~torchrl.data.TensorDictReplayBuffer.extend` for more details."
+ )
  if self._transform is not None and len(self._transform):
  with _set_dispatch_td_nn_modules(is_tensor_collection(data)):
  data = self._transform.inv(data)
  if data is None:
  return torch.zeros((0, self._storage.ndim), dtype=torch.long)
- return self._extend(data)
+ return self._extend(data, update_priority=update_priority)

  def update_priority(
  self,
@@ -914,8 +924,8 @@ class ReplayBuffer:
  self._iterator = iter(self)
  out = next(self._iterator)
  # if any, we don't want the device ref to be passed in distributed settings
- if out is not None:
- out.clear_device_()
+ if out is not None and (out.device != "cpu"):
+ out = out.copy().clear_device_()
  return out
  except StopIteration:
  self._iterator = None
@@ -1015,6 +1025,9 @@ class PrioritizedReplayBuffer(ReplayBuffer):
  storage (Storage, optional): the storage to be used. If none is provided
  a default :class:`~torchrl.data.replay_buffers.ListStorage` with
  ``max_size`` of ``1_000`` will be created.
+ sampler (Sampler, optional): the sampler to be used. If none is provided,
+ a default :class:`~torchrl.data.replay_buffers.PrioritizedSampler` with
+ ``alpha``, ``beta``, and ``eps`` will be created.
  collate_fn (callable, optional): merges a list of samples to form a
  mini-batch of Tensor(s)/outputs. Used when using batched
  loading from a map-style dataset. The default value will be decided
@@ -1107,6 +1120,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
  eps: float = 1e-8,
  dtype: torch.dtype = torch.float,
  storage: Storage | None = None,
+ sampler: Sampler | None = None,
  collate_fn: Callable | None = None,
  pin_memory: bool = False,
  prefetch: int | None = None,
@@ -1116,7 +1130,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
  ) -> None:
  if storage is None:
  storage = ListStorage(max_size=1_000)
- sampler = PrioritizedSampler(storage.max_size, alpha, beta, eps, dtype)
+ if sampler is None:
+ sampler = PrioritizedSampler(storage.max_size, alpha, beta, eps, dtype)
  super().__init__(
  storage=storage,
  sampler=sampler,
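
A hedged sketch of the new `sampler` argument (the slice sampler and its settings below are one possible choice, not the only one):

    from torchrl.data import PrioritizedReplayBuffer
    from torchrl.data.replay_buffers import ListStorage, PrioritizedSliceSampler

    storage = ListStorage(max_size=1_000)
    # a custom prioritized sampler replaces the default PrioritizedSampler
    sampler = PrioritizedSliceSampler(
        max_capacity=storage.max_size, alpha=0.7, beta=0.9, num_slices=4
    )
    rb = PrioritizedReplayBuffer(alpha=0.7, beta=0.9, storage=storage, sampler=sampler)
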
@@ -1347,7 +1362,20 @@ class TensorDictReplayBuffer(ReplayBuffer):
  self.update_tensordict_priority(data)
  return index

- def extend(self, tensordicts: TensorDictBase) -> torch.Tensor:
+ def extend(
+ self, tensordicts: TensorDictBase, *, update_priority: bool | None = None
+ ) -> torch.Tensor:
+ """Extends the replay buffer with a batch of data.
+
+ Args:
+ tensordicts (TensorDictBase): The data to extend the replay buffer with.
+
+ Keyword Args:
+ update_priority (bool, optional): Whether to update the priority of the data. Defaults to True.
+
+ Returns:
+ The indices of the data that were added to the replay buffer.
+ """
  if not isinstance(tensordicts, TensorDictBase):
  raise ValueError(
  f"{self.__class__.__name__} only accepts TensorDictBase subclasses. tensorclasses "
@@ -1365,8 +1393,17 @@ class TensorDictReplayBuffer(ReplayBuffer):
  # is that just doing this results in indices that are not sorted like the original data
  # so the actually indices will have to be used on the _storage directly (not on the buffer)
  self._set_index_in_td(tensordicts, index)
- # TODO: in principle this is a good idea but currently it doesn't work + it re-writes a priority that has just been written
- # self.update_tensordict_priority(tensordicts)
+ if update_priority is None:
+ update_priority = True
+ if update_priority:
+ try:
+ vector = tensordicts.get(self.priority_key)
+ if vector is not None:
+ self.update_priority(index, vector)
+ except Exception as e:
+ raise RuntimeError(
+ "Failed to update priority of extended data. You can try to set update_priority=False in the extend method and update the priority manually."
+ ) from e
  return index

  def _set_index_in_td(self, tensordict, index):
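
A hedged usage sketch of the new behaviour (buffer class and keys chosen for illustration; `priority_key` defaults to "td_error"):

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, TensorDictPrioritizedReplayBuffer

    rb = TensorDictPrioritizedReplayBuffer(
        alpha=0.7, beta=0.9, storage=LazyTensorStorage(100), priority_key="td_error"
    )
    batch = TensorDict(
        {"obs": torch.randn(8, 4), "td_error": torch.rand(8)}, batch_size=[8]
    )
    idx = rb.extend(batch)  # priorities under "td_error" are written in the same call
    # rb.extend(batch, update_priority=False)  # opt out, then call rb.update_priority(...) manually
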
@@ -1685,8 +1722,10 @@ class RemoteTensorDictReplayBuffer(TensorDictReplayBuffer):
  def add(self, data: TensorDictBase) -> int:
  return super().add(data)

- def extend(self, tensordicts: list | TensorDictBase) -> torch.Tensor:
- return super().extend(tensordicts)
+ def extend(
+ self, tensordicts: list | TensorDictBase, *, update_priority: bool | None = None
+ ) -> torch.Tensor:
+ return super().extend(tensordicts, update_priority=update_priority)

  def update_priority(
  self, index: int | torch.Tensor, priority: int | torch.Tensor