torchrl-nightly 2025.6.19__cp311-cp311-macosx_10_9_universal2.whl → 2025.6.21__cp311-cp311-macosx_10_9_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. torchrl/_torchrl.cpython-311-darwin.so +0 -0
  2. torchrl/collectors/collectors.py +49 -24
  3. torchrl/collectors/llm/base.py +13 -6
  4. torchrl/collectors/llm/ray_collector.py +3 -0
  5. torchrl/data/__init__.py +2 -0
  6. torchrl/data/datasets/minari_data.py +1 -1
  7. torchrl/data/llm/__init__.py +2 -0
  8. torchrl/data/llm/chat.py +59 -9
  9. torchrl/data/llm/topk.py +186 -0
  10. torchrl/data/replay_buffers/ray_buffer.py +15 -1
  11. torchrl/data/replay_buffers/replay_buffers.py +50 -11
  12. torchrl/data/replay_buffers/samplers.py +98 -21
  13. torchrl/data/replay_buffers/storages.py +29 -2
  14. torchrl/envs/llm/__init__.py +2 -0
  15. torchrl/envs/llm/chat.py +4 -1
  16. torchrl/envs/llm/reward/gsm8k.py +15 -8
  17. torchrl/envs/llm/transforms/__init__.py +2 -1
  18. torchrl/envs/llm/transforms/kl.py +240 -4
  19. torchrl/envs/transforms/transforms.py +11 -27
  20. torchrl/modules/llm/policies/transformers_wrapper.py +71 -15
  21. torchrl/modules/llm/policies/vllm_wrapper.py +38 -5
  22. torchrl/objectives/llm/__init__.py +2 -1
  23. torchrl/objectives/llm/sft.py +465 -0
  24. torchrl/objectives/ppo.py +35 -12
  25. torchrl/version.py +2 -2
  26. {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/METADATA +1 -1
  27. {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/RECORD +30 -28
  28. {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/LICENSE +0 -0
  29. {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/WHEEL +0 -0
  30. {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/top_level.txt +0 -0
torchrl/_torchrl.cpython-311-darwin.so CHANGED (binary file, contents not shown)
torchrl/collectors/collectors.py CHANGED
@@ -352,8 +352,8 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):
             self._iterator = iter(self)
             out = next(self._iterator)
             # if any, we don't want the device ref to be passed in distributed settings
-            if out is not None:
-                out.clear_device_()
+            if out is not None and (out.device != "cpu"):
+                out = out.copy().clear_device_()
             return out
         except StopIteration:
             return None
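The new guard clears the device reference on a copy only, so the tensordict cached by the collector keeps its own device. A minimal sketch of the tensordict calls relied on here (assumes the `tensordict` package; keys and shapes are illustrative):

    import torch
    from tensordict import TensorDict

    td = TensorDict({"obs": torch.zeros(3)}, batch_size=[3], device="cpu")
    out = td.copy().clear_device_()  # shallow copy, then drop the device ref in place
    print(td.device)   # cpu  -> the original keeps its device
    print(out.device)  # None -> safe to ship across process/distributed boundaries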
@@ -892,7 +892,10 @@ class SyncDataCollector(DataCollectorBase):
             and hasattr(self.postproc, "to")
             and self.storing_device
         ):
-            self.postproc.to(self.storing_device)
+            postproc = self.postproc.to(self.storing_device)
+            if postproc is not self.postproc and postproc is not None:
+                self.postproc = postproc
+
         if frames_per_batch % self.n_env != 0 and RL_WARNINGS:
             warnings.warn(
                 f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), "
@@ -1253,9 +1256,9 @@ class SyncDataCollector(DataCollectorBase):
                 yield
                 continue
             self._increment_frames(tensordict_out.numel())
-            if self.verbose:
-                torchrl_logger.info("Collector: postproc.")
             tensordict_out = self._postproc(tensordict_out)
+            if self.verbose:
+                torchrl_logger.info("Collector: postproc done.")
             if self.return_same_td:
                 # This is used with multiprocessed collectors to use the buffers
                 # stored in the tensordict.
@@ -1765,8 +1768,9 @@ class _MultiDataCollector(DataCollectorBase):
         .. warning:: `policy_factory` is currently not compatible with multiprocessed data
             collectors.
 
-        frames_per_batch (int): A keyword-only argument representing the
-            total number of elements in a batch.
+        frames_per_batch (int, Sequence[int]): A keyword-only argument representing the
+            total number of elements in a batch. If a sequence is provided, represents the number of elements in a
+            batch per worker. Total number of elements in a batch is then the sum over the sequence.
         total_frames (int, optional): A keyword-only argument representing the
             total number of frames returned by the collector
             during its lifespan. If the ``total_frames`` is not divisible by
@@ -1923,7 +1927,7 @@ class _MultiDataCollector(DataCollectorBase):
         policy_factory: Callable[[], Callable]
         | list[Callable[[], Callable]]
         | None = None,
-        frames_per_batch: int,
+        frames_per_batch: int | Sequence[int],
         total_frames: int | None = -1,
         device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
         storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
@@ -1959,6 +1963,22 @@ class _MultiDataCollector(DataCollectorBase):
         self.closed = True
         self.num_workers = len(create_env_fn)
 
+        if (
+            isinstance(frames_per_batch, Sequence)
+            and len(frames_per_batch) != self.num_workers
+        ):
+            raise ValueError(
+                "If `frames_per_batch` is provided as a sequence, it should contain exactly one value per worker."
+                f"Got {len(frames_per_batch)} values for {self.num_workers} workers."
+            )
+
+        self._frames_per_batch = frames_per_batch
+        total_frames_per_batch = (
+            sum(frames_per_batch)
+            if isinstance(frames_per_batch, Sequence)
+            else frames_per_batch
+        )
+
         self.set_truncated = set_truncated
         self.num_sub_threads = num_sub_threads
         self.num_threads = num_threads
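`frames_per_batch` may now be a per-worker sequence whose sum defines the batch size. A hedged usage sketch (assumes a Gym environment is installed; the environment and numbers are illustrative):

    from torchrl.collectors import MultiSyncDataCollector
    from torchrl.envs import GymEnv

    collector = MultiSyncDataCollector(
        [lambda: GymEnv("CartPole-v1"), lambda: GymEnv("CartPole-v1")],
        policy=None,                 # None falls back to a random policy
        frames_per_batch=[64, 192],  # one budget per worker; the batch holds their sum
        total_frames=1024,
    )
    for batch in collector:
        print(batch.numel())         # 64 + 192 = 256 frames per yielded batch
        break
    collector.shutdown()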
@@ -2076,11 +2096,11 @@ class _MultiDataCollector(DataCollectorBase):
         if total_frames is None or total_frames < 0:
             total_frames = float("inf")
         else:
-            remainder = total_frames % frames_per_batch
+            remainder = total_frames % total_frames_per_batch
             if remainder != 0 and RL_WARNINGS:
                 warnings.warn(
-                    f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). "
-                    f"This means {frames_per_batch - remainder} additional frames will be collected. "
+                    f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({total_frames_per_batch}). "
+                    f"This means {total_frames_per_batch - remainder} additional frames will be collected. "
                     "To silence this message, set the environment variable RL_WARNINGS to False."
                 )
         self.total_frames = (
@@ -2091,7 +2111,8 @@ class _MultiDataCollector(DataCollectorBase):
         self.max_frames_per_traj = (
             int(max_frames_per_traj) if max_frames_per_traj is not None else 0
         )
-        self.requested_frames_per_batch = int(frames_per_batch)
+
+        self.requested_frames_per_batch = total_frames_per_batch
         self.reset_when_done = reset_when_done
         if split_trajs is None:
             split_trajs = False
@@ -2221,8 +2242,7 @@ class _MultiDataCollector(DataCollectorBase):
         )
         return storing_device, policy_device, env_device
 
-    @property
-    def frames_per_batch_worker(self):
+    def frames_per_batch_worker(self, worker_idx: int | None = None) -> int:
         raise NotImplementedError
 
     @property
@@ -2281,7 +2301,7 @@ class _MultiDataCollector(DataCollectorBase):
                 "create_env_kwargs": env_fun_kwargs,
                 "policy": policy,
                 "max_frames_per_traj": self.max_frames_per_traj,
-                "frames_per_batch": self.frames_per_batch_worker,
+                "frames_per_batch": self.frames_per_batch_worker(worker_idx=i),
                 "reset_at_each_iter": self.reset_at_each_iter,
                 "policy_device": policy_device,
                 "storing_device": storing_device,
@@ -2773,8 +2793,9 @@ class MultiSyncDataCollector(_MultiDataCollector):
             policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
         )
 
-    @property
-    def frames_per_batch_worker(self):
+    def frames_per_batch_worker(self, worker_idx: int | None) -> int:
+        if worker_idx is not None and isinstance(self._frames_per_batch, Sequence):
+            return self._frames_per_batch[worker_idx]
         if self.requested_frames_per_batch % self.num_workers != 0 and RL_WARNINGS:
             warnings.warn(
                 f"frames_per_batch {self.requested_frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers},"
@@ -2855,9 +2876,9 @@ class MultiSyncDataCollector(_MultiDataCollector):
                 use_buffers = self._use_buffers
                 if self.replay_buffer is not None:
                     idx = new_data
-                    workers_frames[idx] = (
-                        workers_frames[idx] + self.frames_per_batch_worker
-                    )
+                    workers_frames[idx] = workers_frames[
+                        idx
+                    ] + self.frames_per_batch_worker(worker_idx=idx)
                     continue
                 elif j == 0 or not use_buffers:
                     try:
@@ -2903,7 +2924,12 @@ class MultiSyncDataCollector(_MultiDataCollector):
 
             if self.replay_buffer is not None:
                 yield
-                self._frames += self.frames_per_batch_worker * self.num_workers
+                self._frames += sum(
+                    [
+                        self.frames_per_batch_worker(worker_idx)
+                        for worker_idx in range(self.num_workers)
+                    ]
+                )
                 continue
 
             # we have to correct the traj_ids to make sure that they don't overlap
@@ -3156,8 +3182,7 @@ class MultiaSyncDataCollector(_MultiDataCollector):
            policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
        )
 
-    @property
-    def frames_per_batch_worker(self):
+    def frames_per_batch_worker(self, worker_idx: int | None = None) -> int:
        return self.requested_frames_per_batch
 
    def _get_from_queue(self, timeout=None) -> tuple[int, int, TensorDictBase]:
@@ -3221,7 +3246,7 @@ class MultiaSyncDataCollector(_MultiDataCollector):
        if self.split_trajs:
            out = split_trajectories(out, prefix="collector")
        else:
-            worker_frames = self.frames_per_batch_worker
+            worker_frames = self.frames_per_batch_worker()
        self._frames += worker_frames
        workers_frames[idx] = workers_frames[idx] + worker_frames
        if self.postprocs:
torchrl/collectors/llm/base.py CHANGED
@@ -242,6 +242,11 @@ class LLMCollector(SyncDataCollector):
         else:
             self.policy_version_tracker = None
 
+    def set_postproc(self, postproc: Callable[[TensorDictBase], TensorDictBase]):
+        if self.postproc is not None:
+            raise RuntimeError("Postproc already set")
+        self.postproc = postproc
+
     def increment_version(self):
         """Increment the policy version."""
         if self.policy_version_tracker is not None:
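`set_postproc` lets a post-processing hook be attached after the collector has been built, and refuses to overwrite an existing one. A hedged sketch, where `collector` stands for an already-constructed LLMCollector and the callable is a placeholder:

    import torch

    def tag_batch(td):
        # any TensorDictBase -> TensorDictBase callable is accepted
        td.set("postprocessed", torch.ones(td.shape, dtype=torch.bool))
        return td

    collector.set_postproc(tag_batch)  # raises RuntimeError if a postproc is already set

The RayLLMCollector change further below simply forwards the same call to the remote collector through `ray.get`.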
@@ -361,9 +366,10 @@ class LLMCollector(SyncDataCollector):
             )
             self._yield_queues[idx].clear()
         result = self._trajectory_queue.popleft()
-        torchrl_logger.info(
-            f"LLMCollector: Yielding completed trajectory with shape {result.shape}."
-        )
+        if self.verbose:
+            torchrl_logger.info(
+                f"LLMCollector: Yielding completed trajectory with shape {result.shape}."
+            )
         return result
 
     started = False
@@ -422,9 +428,10 @@ class LLMCollector(SyncDataCollector):
                 self.env.async_step_and_maybe_reset_send(env_input)
 
         result = self._trajectory_queue.popleft()
-        torchrl_logger.info(
-            f"LLMCollector: Yielding completed trajectory with shape {result.shape}."
-        )
+        if self.verbose:
+            torchrl_logger.info(
+                f"LLMCollector: Yielding completed trajectory with shape {result.shape}."
+            )
         return result
 
     as_remote = as_remote
torchrl/collectors/llm/ray_collector.py CHANGED
@@ -134,6 +134,9 @@ class RayLLMCollector(LLMCollector):
             verbose=verbose,
         )
 
+    def set_postproc(self, postproc: Callable[[TensorDictBase], TensorDictBase]):
+        return ray.get(self._collector.set_postproc.remote(postproc))
+
     def _next_remote(self) -> None:
         return self._collector.next.remote()
 
torchrl/data/__init__.py CHANGED
@@ -17,6 +17,7 @@ from .llm import (
     RolloutFromModel,
     TensorDictTokenizer,
     TokenizedDatasetLoader,
+    TopKRewardSelector,
 )
 from .map import (
     BinaryToDecimal,
@@ -116,6 +117,7 @@ __all__ = [
     "Categorical",
     "Choice",
     "ContentBase",
+    "TopKRewardSelector",
     "Composite",
     "CompositeSpec",
     "ConstantKLController",
torchrl/data/datasets/minari_data.py CHANGED
@@ -350,7 +350,7 @@ class MinariExperienceReplay(BaseDatasetExperienceReplay):
             # Add a "done" entry
             if self.split_trajs:
                 with td_data.unlock_():
-                    from torchrl.objectives.utils import split_trajectories
+                    from torchrl.collectors.utils import split_trajectories
 
                     td_data = split_trajectories(td_data).memmap_(self.data_path)
             with open(self.metadata_path, "w") as metadata_file:
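The helper now lives under the collectors utilities. A minimal sketch of the corrected import (hedged: `td` is assumed to hold collected frames carrying the `("collector", "traj_ids")` entries that collectors normally write):

    from torchrl.collectors.utils import split_trajectories

    padded = split_trajectories(td, prefix="collector")  # [num_trajs, max_traj_len] padded layout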
torchrl/data/llm/__init__.py CHANGED
@@ -13,6 +13,7 @@ from .dataset import (
 )
 from .prompt import PromptData, PromptTensorDictTokenizer
 from .reward import PairwiseDataset, RewardData
+from .topk import TopKRewardSelector
 from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel
 
 __all__ = [
@@ -30,4 +31,5 @@ __all__ = [
     "TokenizedDatasetLoader",
     "create_infinite_iterator",
     "get_dataloader",
+    "TopKRewardSelector",
 ]
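With both `__all__` lists updated, the selector is importable from either namespace; a quick check (nothing beyond the exports shown above is assumed):

    from torchrl.data import TopKRewardSelector
    from torchrl.data.llm import TopKRewardSelector as TopKRewardSelectorLLM

    assert TopKRewardSelector is TopKRewardSelectorLLM  # same class, re-exported at both levels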
torchrl/data/llm/chat.py CHANGED
@@ -11,19 +11,27 @@ from typing import Literal
 
 import torch
 
-
-from tensordict import lazy_stack, LazyStackedTensorDict, list_to_stack, TensorClass
+from tensordict import (
+    lazy_stack,
+    LazyStackedTensorDict,
+    list_to_stack,
+    TensorClass,
+    TensorDict,
+)
 from tensordict.utils import _maybe_correct_neg_dim
-
 from torchrl._utils import logger as torchrl_logger
 
 
 _CHAT_TEMPLATES = {
     "chatml_format": """{% for message in messages %}
+    {%- if message['role'] == 'assistant' %}
+    {% generation %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endgeneration %}
+    {%- else %}
     {{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
+    {%- endif %}
 {% endfor %}
 {%- if add_generation_prompt %}
-    {{- '<|im_start|>assistant\n' }}
+    {% generation %}{{- '<|im_start|>assistant\n' }}{% endgeneration %}
 {%- endif %}
 """,
     "qwen": """
@@ -283,7 +291,7 @@ class History(TensorClass["nocast"]):
 
         Keyword Args:
             tokenizer (transformers.PreTrainedTokenizer | transformers.AutoProcessor): The tokenizer to use.
-            add_generation_prompt (bool, optional): Whether to add a generation prompt. Defaults to `True`.
+            add_generation_prompt (bool, optional): Whether to add a generation prompt (e.g. `"<|im_start|>assistant"`). Defaults to `True`.
             chat_template (str, optional): The chat template to use. Defaults to the tokenizer's default template.
             chat_template_name (Literal["chatml_format", "qwen"], optional): The name of the chat template to use.
                 Prevalent over `tokenizer.chat_template`. Defaults to `None`.
@@ -294,6 +302,7 @@ class History(TensorClass["nocast"]):
             return_tensors (str | None, optional): The type of tensors to return. Defaults to "pt".
             return_dict (bool, optional): Whether to return a dictionary. Defaults to `False`.
             return_assistant_tokens_mask (bool, optional): Whether to return a mask of the assistant generated tokens.
+                If `True`, the mask will be written to the `assistant_masks` key.
                 For tokens generated by the assistant, the mask will contain `1`.
                 For user and system tokens, the mask will contain `0`.
                 This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
@@ -316,6 +325,11 @@ class History(TensorClass["nocast"]):
                 raise RuntimeError(
                     "You must specify a tokenizer to use when chat_template is not specified."
                 )
+            elif "qwen" in getattr(tokenizer, "name_or_path", "").lower():
+                # We prefer our implementation of the Qwen template,
+                # since it accounts for the assistant's masking.
+                chat_template = _CHAT_TEMPLATES["qwen"]
+                chat_template_name = None
             else:
                 chat_template = tokenizer.chat_template
                 if chat_template is None:
@@ -334,7 +348,7 @@ class History(TensorClass["nocast"]):
             return_dict = False
 
         if self.ndim > 1:
-            return [
+            result = [
                 self[i].apply_chat_template(
                     tokenizer=tokenizer,
                     add_generation_prompt=add_generation_prompt,
@@ -351,12 +365,16 @@ class History(TensorClass["nocast"]):
                 )
                 for i in range(self.batch_size[0])
             ]
+            if return_dict:
+                return lazy_stack(result)
+            else:
+                return result
         self_flat = self.view(-1)
         # tolist_first=True is needed to avoid having a list of dict of dicts, but a list of dicts of lists of dicts
         self_flat = self_flat.tolist(tolist_first=True)
         # Remove the "<none>" role
         self_flat = [item for item in self_flat if item["role"] != "<none>"]
-        return tokenizer.apply_chat_template(
+        result = tokenizer.apply_chat_template(
             conversation=self_flat,
             add_generation_prompt=add_generation_prompt,
             chat_template=chat_template,
@@ -369,6 +387,16 @@ class History(TensorClass["nocast"]):
             return_assistant_tokens_mask=return_assistant_tokens_mask,
             **kwargs,
         )
+        if not isinstance(result, (torch.Tensor, list, str)):
+            result = TensorDict.from_dict(result, auto_batch_size=True, batch_dims=1)
+            # If self has a batch_dims of 1, we have just the time dimension, so we need to remove the batch dim from the result
+            if self.batch_dims == 1:
+                if result.batch_size[0] != 1:
+                    raise RuntimeError(
+                        f"Expected a batch size of 1, got {result.batch_size[0]}."
+                    )
+                result = result.squeeze(0)
+        return result
 
     @classmethod
     def from_text(
@@ -376,10 +404,20 @@ class History(TensorClass["nocast"]):
         text: str | list[str],
         chat_template_name: Literal["chatml_format", "qwen"] | None = None,
         chat_template: str | None = None,
+        tokenizer: transformers.AutoTokenizer  # noqa: F821
+        | transformers.AutoProcessor  # noqa: F821
+        | None = None,
     ) -> History:
-        if chat_template_name in ("chatml_format", None):
+        if chat_template_name is None and chat_template is None:
+            if "qwen" in getattr(tokenizer, "name_or_path", "").lower():
+                # We can automatically detect the template name from the tokenizer
+                # and use the precoded parser.
+                chat_template_name = "qwen"
+            else:
+                chat_template_name = "chatml_format"
+        elif chat_template_name in ("chatml_format",):
             func = cls._inv_chatml
-        elif chat_template_name == "qwen":
+        elif chat_template_name in ("qwen",):
             func = cls._inv_qwen
         else:
             raise NotImplementedError(
@@ -736,3 +774,15 @@ class History(TensorClass["nocast"]):
         }
 
         return Composite(defaults, shape=shape[:-1], data_cls=cls)
+
+    @classmethod
+    def from_chats(cls, chats: list[list[dict]]) -> History:
+        """Create a History object from a list of chats.
+
+        Args:
+            chats (list[list[dict]]): A list of chats, where each chat is a list of dictionaries.
+        """
+        if isinstance(chats[0], dict):
+            return lazy_stack([cls(**chat) for chat in chats])
+        else:
+            return lazy_stack([cls.from_chats(chat) for chat in chats])
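Taken together, the chat changes let a History be built straight from raw chat dictionaries and tokenized with an assistant-token mask. A hedged sketch (assumes the `transformers` package; the Qwen checkpoint is illustrative, and with a Qwen tokenizer the bundled template supporting `{% generation %}` is picked automatically, per the change above):

    from transformers import AutoTokenizer
    from torchrl.data.llm.chat import History

    chats = [
        [
            {"role": "user", "content": "Hi!"},
            {"role": "assistant", "content": "Hello, how can I help?"},
        ]
    ]
    history = History.from_chats(chats)  # lazy-stacked History with shape [1, 2]

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    out = history.apply_chat_template(
        tokenizer=tokenizer,
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_assistant_tokens_mask=True,
    )
    # "assistant_masks" flags assistant-generated tokens (1) vs. user/system tokens (0)
    print(out.get("assistant_masks"))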
torchrl/data/llm/topk.py ADDED
@@ -0,0 +1,186 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from collections import defaultdict, deque
+from typing import Any
+
+import torch
+from tensordict import NestedKey, TensorDictBase
+from torchrl._utils import logger as torchrl_logger
+from torchrl.envs.transforms import Transform
+
+
+class TopKRewardSelector(Transform):
+    """A replay-buffer transform that selects the top-k rewards for each prompt.
+
+    Args:
+        total_dialog_turns (int): Number of dialog turns to keep in memory for the top-k selection.
+        topk_size (int): Number of top-k rewards to select. Must be smaller than or equal to total_dialog_turns.
+        prompt_key (NestedKey): Key to the prompt in the tensordict. Defaults to "text".
+        rewards_key (NestedKey): Key to the rewards in the tensordict. Defaults to ("next", "reward").
+        done_key (NestedKey): Key to the done state in the tensordict. Defaults to ("next", "done").
+        verbose (bool): Whether to print verbose information. Defaults to `False`.
+
+    Example:
+        >>> from torchrl.data import ReplayBuffer, LazyStackStorage, SamplerWithoutReplacement
+        >>> from tensordict import TensorDict, lazy_stack
+        >>> import torch
+        >>> from torchrl.data.llm.topk import TopKRewardSelector
+        >>> # Create a replay buffer with 50 items, a sampler that samples without replacement, and a batch size of 5
+        >>> rb = ReplayBuffer(
+        ...     storage=LazyStackStorage(50),
+        ...     sampler=SamplerWithoutReplacement,
+        ...     batch_size=5,
+        ... )
+        >>> # Create a tensordict with 50 items, each with 10 dialog turns
+        >>> td = lazy_stack(
+        ...     [
+        ...         TensorDict(
+        ...             {
+        ...                 ("next", "done"): torch.full((1, 1), True),
+        ...                 # Reward for i+5 tokens
+        ...                 ("next", "reward"): torch.full((i + 5, 1), i),
+        ...                 # total of 10 dialogs per prompt
+        ...                 "text": f"Prompt {i // 5}",
+        ...             }
+        ...         )
+        ...         for i in range(50)
+        ...     ]
+        ... )
+        >>> # Create a top-k reward selector with 5 dialog turns and a top-k size of 3
+        >>> topk = TopKRewardSelector(total_dialog_turns=5, topk_size=3)
+        >>> rb.append_transform(topk)
+        >>> for _td in td.chunk(25):
+        ...     rb.extend(_td)
+        >>> # Only wrote top3 of 50 items in 10 groups of 5
+        >>> assert rb.write_count == 30
+        >>> assert len(rb) == 30
+        >>> r3 = rb[:3].get(("next", "reward"), as_padded_tensor=True).squeeze()
+        >>> # 0 and 1 are missing because they're not part of the top-k
+        >>> assert (
+        ...     r3 == torch.tensor(
+        ...         [
+        ...             [4, 4, 4, 4, 4, 4, 4, 4, 4],
+        ...             [3, 3, 3, 3, 3, 3, 3, 3, 0],
+        ...             [2, 2, 2, 2, 2, 2, 2, 0, 0],
+        ...         ]
+        ...     )
+        ... ).all()
+    """
+
+    def __init__(
+        self,
+        total_dialog_turns: int,
+        topk_size: int,
+        prompt_key: NestedKey = "text",
+        rewards_key: NestedKey = ("next", "reward"),
+        done_key: NestedKey = ("next", "done"),
+        verbose: bool = True,
+    ):
+        super().__init__()
+        self.in_keys = [prompt_key, rewards_key, done_key]
+        self.prompt_key = prompt_key
+        self.rewards_key = rewards_key
+        self.done_key = done_key
+        self.queues = defaultdict(lambda: deque(maxlen=total_dialog_turns))
+        self.total_dialog_turns = total_dialog_turns
+        self.topk_size = topk_size
+        if topk_size > total_dialog_turns:
+            raise ValueError(
+                f"topk_size must be smaller than or equal to total_dialog_turns, got {topk_size=} and {total_dialog_turns=}"
+            )
+        self.verbose = verbose
+
+    def forward(self, tensordict: TensorDictBase) -> Any:
+        return tensordict
+
+    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
+        # Tensordict can be any number of dims, but it must contain entire trajectories
+        if tensordict.ndim == 1:
+            # Check how many done states we have
+            num_done = tensordict[self.done_key].sum()
+            if num_done > 1:
+                done_idx = tensordict[self.done_key].nonzero(as_tuple=True)[0] + 1
+                splits = torch.cat([done_idx.new_zeros((1,)), done_idx], dim=0).diff()
+                tensordicts = tensordict.split(splits)
+                tensordicts = [self._inv_call(td) for td in tensordicts]
+                tensordicts = [td for td in tensordicts if td is not None]
+                return torch.cat(tensordicts) if tensordicts else None
+            # Then we have a single trajectory
+            if not tensordict[-1][self.done_key].all():
+                raise RuntimeError("Expected the trajectory to be done.")
+            prompt = tensordict[0][self.prompt_key]
+            if not isinstance(prompt, str):
+                raise TypeError(f"Expected a string as prompt, got {type(prompt)=}")
+            self.queues[prompt].append(tensordict)
+            if len(self.queues[prompt]) == self.total_dialog_turns:
+                if self.verbose:
+                    torchrl_logger.info(f"Getting top-k rewards for {prompt=}")
+                # Cat is the most robust way to combine the trajs
+                tds = torch.cat(list(self.queues[prompt]), -1)
+                # Collect rewards
+                reward = tds.get(self.rewards_key, as_nested_tensor=True)
+                reward = self._aggregate_rewards(reward)
+                # Check if all rewards are equal
+                if (reward == reward[0]).all():
+                    # If all rewards are equal, we can't select top-k
+                    if self.verbose:
+                        torchrl_logger.warning(
+                            f"All rewards are equal ({reward.unique()=})"
+                        )
+                    return
+                # Filter out rewards below median
+                median_reward = reward.median(dim=-1, keepdim=True)[0]
+                mask = reward > median_reward
+                filtered_reward = reward[mask]
+                filtered_indices = mask.nonzero(as_tuple=True)[0]
+                # Get top-k from filtered rewards
+                topk_reward = filtered_reward.topk(
+                    k=min(self.topk_size, len(filtered_indices)), dim=-1
+                )
+                if not topk_reward.indices.numel():
+                    if self.verbose:
+                        torchrl_logger.warning(
+                            f"No top-{self.topk_size} rewards found ({reward=})"
+                        )
+                    return
+                # Map back to original indices
+                selected_indices = filtered_indices[topk_reward.indices]
+                tds = tds[selected_indices]
+                if self.verbose:
+                    torchrl_logger.info(
+                        f"Selected top-{self.topk_size} rewards, with reward {topk_reward.values=}"
+                    )
+                return tds
+            return
+        elif tensordict.ndim > 2:
+            # keep the time dim at the end
+            tensordict = tensordict.flatten(0, -2)
+        trajs = tensordict.unbind(-1)
+        # Iterate over the trajectories
+        result = []
+        for traj in trajs:
+            td_out = self._inv_call(traj)
+            if td_out is None:
+                continue
+            result.append(td_out)
+        if result:
+            return torch.cat(result, -1)
+        return
+
+    def _aggregate_rewards(self, reward: torch.Tensor) -> torch.Tensor:
+        """Aggregate the rewards across the dialog turns.
+
+        `reward` is expected to be a nested tensor.
+
+        The default implementation is to take the mean of the rewards across the dialog turns.
+        """
+        # reward = reward.to_padded_tensor(padding=0.0)
+        if reward.ndim < 2 or reward.ndim > 3:
+            raise ValueError(
+                f"Expected reward to be a 2D or 3D tensor, got {reward.ndim}D tensor"
+            )
+        return reward.mean(dim=-2).squeeze(-1)
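`_aggregate_rewards` is the customization point for how per-token rewards are reduced before the top-k selection. A hedged sketch of an override (the subclass name is invented; summing instead of averaging is only an example):

    import torch
    from torchrl.data.llm.topk import TopKRewardSelector

    class SumRewardTopK(TopKRewardSelector):
        def _aggregate_rewards(self, reward: torch.Tensor) -> torch.Tensor:
            # `reward` arrives as a (possibly nested) [traj, time, 1] tensor; reduce over time
            return reward.sum(dim=-2).squeeze(-1)

    topk = SumRewardTopK(total_dialog_turns=8, topk_size=4)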
torchrl/data/replay_buffers/ray_buffer.py CHANGED
@@ -54,9 +54,12 @@ class RayReplayBuffer(ReplayBuffer):
     """A Ray implementation of the Replay Buffer that can be extended and sampled remotely.
 
     Keyword Args:
+        replay_buffer_cls (type[ReplayBuffer], optional): the class to use for the replay buffer.
+            Defaults to :class:`~torchrl.data.ReplayBuffer`.
         ray_init_config (dict[str, Any], optiona): keyword arguments to pass to `ray.init()`.
         remote_config (dict[str, Any], optiona): keyword arguments to pass to `cls.as_remote()`.
             Defaults to `torchrl.collectors.distributed.ray.DEFAULT_REMOTE_CLASS_CONFIG`.
+        **kwargs: keyword arguments to pass to the replay buffer class.
 
     .. seealso:: :class:`~torchrl.data.ReplayBuffer` for a list of other keyword arguments.
 
@@ -119,6 +122,7 @@ class RayReplayBuffer(ReplayBuffer):
     def __init__(
         self,
         *args,
+        replay_buffer_cls: type[ReplayBuffer] | None = ReplayBuffer,
         ray_init_config: dict[str, Any] | None = None,
         remote_config: dict[str, Any] | None = None,
         **kwargs,
@@ -134,7 +138,13 @@ class RayReplayBuffer(ReplayBuffer):
             ray_init_config = DEFAULT_RAY_INIT_CONFIG
             ray.init(**ray_init_config)
 
-        remote_cls = ReplayBuffer.as_remote(remote_config).remote
+        remote_cls = replay_buffer_cls.as_remote(remote_config).remote
+        # We can detect if the buffer has a GPU allocated, if not
+        # we'll make sure that the data is sent to CPU when needed.
+        if remote_config is not None:
+            self.has_gpu = remote_config.get("num_gpus", 0) > 0
+        else:
+            self.has_gpu = False
         self._rb = remote_cls(*args, **kwargs)
 
     def close(self):
@@ -158,6 +168,10 @@ class RayReplayBuffer(ReplayBuffer):
         return ray.get(pending_task)
 
     def extend(self, *args, **kwargs):
+        if not self.has_gpu:
+            # Move the data to GPU
+            args = [arg.to("cpu") for arg in args if hasattr(arg, "to")]
+            kwargs = {k: v.to("cpu") for k, v in kwargs.items() if hasattr(v, "to")}
         pending_task = self._rb.extend.remote(*args, **kwargs)
         return ray.get(pending_task)
 
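The new `replay_buffer_cls` argument, together with the CPU guard in `extend`, lets the Ray actor wrap any buffer class. A hedged construction sketch (sizes are illustrative; extra keyword arguments are forwarded to the wrapped buffer):

    import torch
    from tensordict import TensorDict
    from torchrl.data import ReplayBuffer
    from torchrl.data.replay_buffers.ray_buffer import RayReplayBuffer

    rb = RayReplayBuffer(replay_buffer_cls=ReplayBuffer, batch_size=32)
    data = TensorDict({"obs": torch.randn(4, 3)}, batch_size=[4])
    rb.extend(data)       # moved to CPU first when the remote actor has no GPU
    sample = rb.sample()  # sampled remotely, returned through ray.get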