torchrl-nightly 2025.6.19__cp39-cp39-win_amd64.whl → 2025.6.21__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. torchrl/_torchrl.cp39-win_amd64.pyd +0 -0
  2. torchrl/collectors/collectors.py +49 -24
  3. torchrl/collectors/llm/base.py +13 -6
  4. torchrl/collectors/llm/ray_collector.py +3 -0
  5. torchrl/data/__init__.py +2 -0
  6. torchrl/data/datasets/minari_data.py +1 -1
  7. torchrl/data/llm/__init__.py +2 -0
  8. torchrl/data/llm/chat.py +59 -9
  9. torchrl/data/llm/topk.py +186 -0
  10. torchrl/data/replay_buffers/ray_buffer.py +15 -1
  11. torchrl/data/replay_buffers/replay_buffers.py +50 -11
  12. torchrl/data/replay_buffers/samplers.py +98 -21
  13. torchrl/data/replay_buffers/storages.py +29 -2
  14. torchrl/envs/llm/__init__.py +2 -0
  15. torchrl/envs/llm/chat.py +4 -1
  16. torchrl/envs/llm/reward/gsm8k.py +15 -8
  17. torchrl/envs/llm/transforms/__init__.py +2 -1
  18. torchrl/envs/llm/transforms/kl.py +240 -4
  19. torchrl/envs/transforms/transforms.py +11 -27
  20. torchrl/modules/llm/policies/transformers_wrapper.py +71 -15
  21. torchrl/modules/llm/policies/vllm_wrapper.py +38 -5
  22. torchrl/objectives/llm/__init__.py +2 -1
  23. torchrl/objectives/llm/sft.py +465 -0
  24. torchrl/objectives/ppo.py +35 -12
  25. torchrl/version.py +2 -2
  26. {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/METADATA +1 -1
  27. {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/RECORD +30 -28
  28. {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/LICENSE +0 -0
  29. {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/WHEEL +0 -0
  30. {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/top_level.txt +0 -0
torchrl/data/replay_buffers/replay_buffers.py CHANGED
@@ -702,7 +702,7 @@ class ReplayBuffer:
             self._sampler.add(index)
         return index
 
-    def _extend(self, data: Sequence) -> torch.Tensor:
+    def _extend(self, data: Sequence, *, update_priority: bool = True) -> torch.Tensor:
         is_comp = is_compiling()
         nc = contextlib.nullcontext()
         with self._replay_lock if not is_comp else nc, self._write_lock if not is_comp else nc:
@@ -712,7 +712,9 @@ class ReplayBuffer:
             self._sampler.extend(index)
         return index
 
-    def extend(self, data: Sequence) -> torch.Tensor:
+    def extend(
+        self, data: Sequence, *, update_priority: bool | None = None
+    ) -> torch.Tensor:
         """Extends the replay buffer with one or more elements contained in an iterable.
 
         If present, the inverse transforms will be called.`
@@ -721,6 +723,10 @@ class ReplayBuffer:
             data (iterable): collection of data to be added to the replay
                 buffer.
 
+        Keyword Args:
+            update_priority (bool, optional): Whether to update the priority of the data. Defaults to True.
+                Without effect in this class. See :meth:`~torchrl.data.TensorDictReplayBuffer.extend` for more details.
+
         Returns:
             Indices of the data added to the replay buffer.
 
@@ -735,12 +741,16 @@
             unbound elements can be provided (no PyTrees).
 
         """
+        if update_priority is not None:
+            raise NotImplementedError(
+                "update_priority is not supported in this class. See :meth:`~torchrl.data.TensorDictReplayBuffer.extend` for more details."
+            )
         if self._transform is not None and len(self._transform):
             with _set_dispatch_td_nn_modules(is_tensor_collection(data)):
                 data = self._transform.inv(data)
         if data is None:
             return torch.zeros((0, self._storage.ndim), dtype=torch.long)
-        return self._extend(data)
+        return self._extend(data, update_priority=update_priority)
 
     def update_priority(
         self,
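
A note on the new keyword: on the base `ReplayBuffer` it is accepted in the signature but has no effect, and passing it explicitly raises. A minimal usage sketch, assuming the 2025.6.21 nightly above is installed:

import torch
from torchrl.data import ReplayBuffer, LazyTensorStorage

rb = ReplayBuffer(storage=LazyTensorStorage(100))
rb.extend(torch.arange(10))  # update_priority defaults to None and is ignored here
try:
    rb.extend(torch.arange(10), update_priority=True)
except NotImplementedError:
    # the plain ReplayBuffer tracks no priorities; use TensorDictReplayBuffer instead
    pass
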
@@ -914,8 +924,8 @@ class ReplayBuffer:
                 self._iterator = iter(self)
             out = next(self._iterator)
             # if any, we don't want the device ref to be passed in distributed settings
-            if out is not None:
-                out.clear_device_()
+            if out is not None and (out.device != "cpu"):
+                out = out.copy().clear_device_()
             return out
         except StopIteration:
             self._iterator = None
@@ -1015,6 +1025,9 @@ class PrioritizedReplayBuffer(ReplayBuffer):
         storage (Storage, optional): the storage to be used. If none is provided
             a default :class:`~torchrl.data.replay_buffers.ListStorage` with
             ``max_size`` of ``1_000`` will be created.
+        sampler (Sampler, optional): the sampler to be used. If none is provided,
+            a default :class:`~torchrl.data.replay_buffers.PrioritizedSampler` with
+            ``alpha``, ``beta``, and ``eps`` will be created.
         collate_fn (callable, optional): merges a list of samples to form a
             mini-batch of Tensor(s)/outputs. Used when using batched
             loading from a map-style dataset. The default value will be decided
@@ -1107,6 +1120,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
         eps: float = 1e-8,
         dtype: torch.dtype = torch.float,
         storage: Storage | None = None,
+        sampler: Sampler | None = None,
         collate_fn: Callable | None = None,
         pin_memory: bool = False,
         prefetch: int | None = None,
@@ -1116,7 +1130,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
     ) -> None:
         if storage is None:
             storage = ListStorage(max_size=1_000)
-        sampler = PrioritizedSampler(storage.max_size, alpha, beta, eps, dtype)
+        if sampler is None:
+            sampler = PrioritizedSampler(storage.max_size, alpha, beta, eps, dtype)
         super().__init__(
             storage=storage,
             sampler=sampler,
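
The new `sampler` argument lets a pre-built sampler be injected instead of the `PrioritizedSampler` the buffer would otherwise construct from `alpha`, `beta` and `eps`. A hedged sketch (argument values are illustrative):

from torchrl.data import PrioritizedReplayBuffer
from torchrl.data.replay_buffers import ListStorage, PrioritizedSampler

storage = ListStorage(max_size=1_000)
# build the sampler explicitly, e.g. to share it across buffers or tweak its dtype
sampler = PrioritizedSampler(storage.max_size, alpha=0.7, beta=0.9, eps=1e-8)
rb = PrioritizedReplayBuffer(alpha=0.7, beta=0.9, storage=storage, sampler=sampler)
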
@@ -1347,7 +1362,20 @@ class TensorDictReplayBuffer(ReplayBuffer):
             self.update_tensordict_priority(data)
         return index
 
-    def extend(self, tensordicts: TensorDictBase) -> torch.Tensor:
+    def extend(
+        self, tensordicts: TensorDictBase, *, update_priority: bool | None = None
+    ) -> torch.Tensor:
+        """Extends the replay buffer with a batch of data.
+
+        Args:
+            tensordicts (TensorDictBase): The data to extend the replay buffer with.
+
+        Keyword Args:
+            update_priority (bool, optional): Whether to update the priority of the data. Defaults to True.
+
+        Returns:
+            The indices of the data that were added to the replay buffer.
+        """
         if not isinstance(tensordicts, TensorDictBase):
             raise ValueError(
                 f"{self.__class__.__name__} only accepts TensorDictBase subclasses. tensorclasses "
@@ -1365,8 +1393,17 @@ class TensorDictReplayBuffer(ReplayBuffer):
         # is that just doing this results in indices that are not sorted like the original data
         # so the actually indices will have to be used on the _storage directly (not on the buffer)
         self._set_index_in_td(tensordicts, index)
-        # TODO: in principle this is a good idea but currently it doesn't work + it re-writes a priority that has just been written
-        # self.update_tensordict_priority(tensordicts)
+        if update_priority is None:
+            update_priority = True
+        if update_priority:
+            try:
+                vector = tensordicts.get(self.priority_key)
+                if vector is not None:
+                    self.update_priority(index, vector)
+            except Exception as e:
+                raise RuntimeError(
+                    "Failed to update priority of extended data. You can try to set update_priority=False in the extend method and update the priority manually."
+                ) from e
         return index
 
     def _set_index_in_td(self, tensordict, index):
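
`TensorDictReplayBuffer.extend` now pushes the priorities found under the buffer's priority key into the sampler, unless `update_priority=False`. A short sketch, assuming the default ``"td_error"`` priority key:

import torch
from tensordict import TensorDict
from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers import PrioritizedSampler

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(100),
    sampler=PrioritizedSampler(100, alpha=0.7, beta=0.9),
    priority_key="td_error",
)
data = TensorDict({"obs": torch.randn(10, 4), "td_error": torch.rand(10)}, batch_size=[10])
idx = rb.extend(data)                   # priorities read from data["td_error"]
rb.extend(data, update_priority=False)  # defer; call rb.update_priority(idx, ...) manually
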
@@ -1685,8 +1722,10 @@ class RemoteTensorDictReplayBuffer(TensorDictReplayBuffer):
     def add(self, data: TensorDictBase) -> int:
         return super().add(data)
 
-    def extend(self, tensordicts: list | TensorDictBase) -> torch.Tensor:
-        return super().extend(tensordicts)
+    def extend(
+        self, tensordicts: list | TensorDictBase, *, update_priority: bool | None = None
+    ) -> torch.Tensor:
+        return super().extend(tensordicts, update_priority=update_priority)
 
     def update_priority(
         self, index: int | torch.Tensor, priority: int | torch.Tensor
torchrl/data/replay_buffers/samplers.py CHANGED
@@ -291,17 +291,38 @@ class SamplerWithoutReplacement(Sampler):
 
 
 class PrioritizedSampler(Sampler):
-    """Prioritized sampler for replay buffer.
+    r"""Prioritized sampler for replay buffer.
 
-    Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay." (https://arxiv.org/abs/1511.05952)
+    This sampler implements Prioritized Experience Replay (PER) as presented in
+    "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay."
+    (https://arxiv.org/abs/1511.05952)
+
+    **Core Idea**: Instead of sampling experiences uniformly from the replay buffer,
+    PER samples experiences with probability proportional to their "importance" - typically
+    measured by the magnitude of their temporal-difference (TD) error. This prioritization
+    can lead to faster learning by focusing on experiences that are most informative.
+
+    **How it works**:
+    1. Each experience is assigned a priority based on its TD error: :math:`p_i = |\delta_i| + \epsilon`
+    2. Sampling probability is computed as: :math:`P(i) = \frac{p_i^\alpha}{\sum_j p_j^\alpha}`
+    3. Importance sampling weights correct for the bias: :math:`w_i = (N \cdot P(i))^{-\beta}`
 
     Args:
         max_capacity (int): maximum capacity of the buffer.
-        alpha (:obj:`float`): exponent α determines how much prioritization is used,
-            with α = 0 corresponding to the uniform case.
-        beta (:obj:`float`): importance sampling negative exponent.
-        eps (:obj:`float`, optional): delta added to the priorities to ensure that the buffer
-            does not contain null priorities. Defaults to 1e-8.
+        alpha (:obj:`float`): exponent :math:`\alpha` determines how much prioritization is used.
+            - :math:`\alpha = 0`: uniform sampling (no prioritization)
+            - :math:`\alpha = 1`: full prioritization based on TD error magnitude
+            - Typical values: 0.4-0.7 for balanced prioritization
+            - Higher :math:`\alpha` means more aggressive prioritization of high-error experiences
+        beta (:obj:`float`): importance sampling negative exponent :math:`\beta`.
+            - :math:`\beta` controls the correction for the bias introduced by prioritization
+            - :math:`\beta = 0`: no correction (biased towards high-priority samples)
+            - :math:`\beta = 1`: full correction (unbiased but potentially unstable)
+            - Typical values: start at 0.4-0.6 and anneal to 1.0 during training
+            - Lower :math:`\beta` early in training provides stability, higher :math:`\beta` later reduces bias
+        eps (:obj:`float`, optional): small constant added to priorities to ensure
+            no experience has zero priority. This prevents experiences from never
+            being sampled. Defaults to 1e-8.
         reduction (str, optional): the reduction method for multidimensional
             tensordicts (ie stored trajectory). Can be one of "max", "min",
             "median" or "mean".
@@ -309,6 +330,23 @@ class PrioritizedSampler(Sampler):
            is tracked within the buffer. When ``False``, the max-priority tracks
            the maximum value since the instantiation of the sampler.
 
+    **Parameter Guidelines**:
+    - **:math:`\alpha` (alpha)**: Controls how much to prioritize high-error experiences
+      - 0.4-0.7: Good balance between learning speed and stability
+      - 1.0: Maximum prioritization (may be unstable)
+      - 0.0: Uniform sampling (no prioritization benefit)
+
+    - **:math:`\beta` (beta)**: Controls importance sampling correction
+      - Start at 0.4-0.6 for training stability
+      - Anneal to 1.0 over training to reduce bias
+      - Lower values = more stable but biased
+      - Higher values = less biased but potentially unstable
+
+    - **:math:`\epsilon`**: Small constant to prevent zero priorities
+      - 1e-8: Good default value
+      - Too small: may cause numerical issues
+      - Too large: reduces prioritization effect
+
     Examples:
         >>> from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler
         >>> from tensordict import TensorDict
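
The formulas added to the docstring above can be checked numerically; a small standalone sketch with three hypothetical TD errors:

import torch

alpha, beta, eps = 0.7, 0.9, 1e-8
td_error = torch.tensor([0.5, 2.0, 0.1])
p = td_error.abs() + eps                 # p_i = |delta_i| + eps
probs = p**alpha / (p**alpha).sum()      # P(i) = p_i^alpha / sum_j p_j^alpha
w = (len(p) * probs) ** (-beta)          # w_i = (N * P(i))^(-beta)
w = w / w.max()                          # normalization commonly used in PER
print(probs, w)
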
@@ -412,7 +450,7 @@
             )
         return super().__getstate__()
 
-    def _init(self):
+    def _init(self) -> None:
         if self.dtype in (torch.float, torch.FloatType, torch.float32):
             self._sum_tree = SumSegmentTreeFp32(self._max_capacity)
             self._min_tree = MinSegmentTreeFp32(self._max_capacity)
@@ -425,21 +463,23 @@ class PrioritizedSampler(Sampler):
             )
         self._max_priority = None
 
-    def _empty(self):
+    def _empty(self) -> None:
         self._init()
 
     @property
-    def _max_priority(self):
+    def _max_priority(self) -> tuple[float | None, int | None]:
         max_priority_index = self.__dict__.get("_max_priority")
         if max_priority_index is None:
             return (None, None)
         return max_priority_index
 
     @_max_priority.setter
-    def _max_priority(self, value):
+    def _max_priority(self, value: tuple[float | None, int | None]) -> None:
         self.__dict__["_max_priority"] = value
 
-    def _maybe_erase_max_priority(self, index):
+    def _maybe_erase_max_priority(
+        self, index: torch.Tensor | int | slice | tuple
+    ) -> None:
         if not self._max_priority_within_buffer:
             return
         max_priority_index = self._max_priority[1]
@@ -1839,11 +1879,21 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
 
 
 class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
-    """Samples slices of data along the first dimension, given start and stop signals, using prioritized sampling.
+    r"""Samples slices of data along the first dimension, given start and stop signals, using prioritized sampling.
+
+    This class combines trajectory sampling with Prioritized Experience Replay (PER) as presented in
+    "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay."
+    (https://arxiv.org/abs/1511.05952)
+
+    **Core Idea**: Instead of sampling trajectory slices uniformly, this sampler prioritizes
+    trajectory start points based on the importance of the transitions at those positions.
+    This allows focusing learning on the most informative parts of trajectories.
 
-    This class samples sub-trajectories with replacement following a priority weighting presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015.
-    Prioritized experience replay."
-    (https://arxiv.org/abs/1511.05952)
+    **How it works**:
+    1. Each transition is assigned a priority based on its TD error: :math:`p_i = |\\delta_i| + \\epsilon`
+    2. Trajectory start points are sampled with probability: :math:`P(i) = \frac{p_i^\alpha}{\\sum_j p_j^\alpha}`
+    3. Importance sampling weights correct for bias: :math:`w_i = (N \\cdot P(i))^{-\beta}`
+    4. Complete trajectory slices are extracted from the sampled start points
 
     For more info see :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` and :class:`~torchrl.data.replay_buffers.samplers.PrioritizedSampler`.
 
@@ -1855,15 +1905,42 @@ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
         :meth:`update_priority`.
 
     Args:
-        alpha (:obj:`float`): exponent α determines how much prioritization is used,
-            with α = 0 corresponding to the uniform case.
-        beta (:obj:`float`): importance sampling negative exponent.
-        eps (:obj:`float`, optional): delta added to the priorities to ensure that the buffer
-            does not contain null priorities. Defaults to 1e-8.
+        max_capacity (int): maximum capacity of the buffer.
+        alpha (:obj:`float`): exponent :math:`\alpha` determines how much prioritization is used.
+            - :math:`\alpha = 0`: uniform sampling of trajectory start points
+            - :math:`\alpha = 1`: full prioritization based on TD error magnitude at start points
+            - Typical values: 0.4-0.7 for balanced prioritization
+            - Higher :math:`\alpha` means more aggressive prioritization of high-error trajectory regions
+        beta (:obj:`float`): importance sampling negative exponent :math:`\beta`.
+            - :math:`\beta` controls the correction for the bias introduced by prioritization
+            - :math:`\beta = 0`: no correction (biased towards high-priority trajectory regions)
+            - :math:`\beta = 1`: full correction (unbiased but potentially unstable)
+            - Typical values: start at 0.4-0.6 and anneal to 1.0 during training
+            - Lower :math:`\beta` early in training provides stability, higher :math:`\beta` later reduces bias
+        eps (:obj:`float`, optional): small constant added to priorities to ensure
+            no transition has zero priority. This prevents trajectory regions from never
+            being sampled. Defaults to 1e-8.
         reduction (str, optional): the reduction method for multidimensional
             tensordicts (i.e., stored trajectory). Can be one of "max", "min",
             "median" or "mean".
 
+    **Parameter Guidelines**:
+    - **:math:`\alpha` (alpha)**: Controls how much to prioritize high-error trajectory regions
+      - 0.4-0.7: Good balance between learning speed and stability
+      - 1.0: Maximum prioritization (may be unstable)
+      - 0.0: Uniform sampling (no prioritization benefit)
+
+    - **:math:`\beta` (beta)**: Controls importance sampling correction
+      - Start at 0.4-0.6 for training stability
+      - Anneal to 1.0 over training to reduce bias
+      - Lower values = more stable but biased
+      - Higher values = less biased but potentially unstable
+
+    - **:math:`\\epsilon`**: Small constant to prevent zero priorities
+      - 1e-8: Good default value
+      - Too small: may cause numerical issues
+      - Too large: reduces prioritization effect
+
     Keyword Args:
         num_slices (int): the number of slices to be sampled. The batch-size
             must be greater or equal to the ``num_slices`` argument. Exclusive
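
A hedged usage sketch of the sampler documented above; the `num_slices` keyword and the default ``"episode"`` trajectory key follow the `SliceSampler` conventions referenced in the docstring and are assumptions here, not part of this diff:

import torch
from tensordict import TensorDict
from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers import PrioritizedSliceSampler

sampler = PrioritizedSliceSampler(max_capacity=100, alpha=0.7, beta=0.9, num_slices=4)
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(100), sampler=sampler, batch_size=32)
data = TensorDict(
    {"obs": torch.randn(100, 4), "episode": torch.arange(100) // 25},  # 4 trajectories of 25 steps
    batch_size=[100],
)
rb.extend(data)
sample = rb.sample()  # 4 slices of 8 contiguous steps each (batch_size / num_slices)
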
torchrl/data/replay_buffers/storages.py CHANGED
@@ -230,15 +230,38 @@ class ListStorage(Storage):
         max_size (int, optional): the maximum number of elements stored in the storage.
             If not provided, an unlimited storage is created.
 
+    Keyword Args:
+        compilable (bool, optional): if ``True``, the storage will be made compatible with :func:`~torch.compile` at
+            the cost of being executable in multiprocessed settings.
+        device (str, optional): the device to use for the storage. Defaults to `None` (inputs are not moved to the device).
+
     """
 
     _default_checkpointer = ListStorageCheckpointer
 
-    def __init__(self, max_size: int | None = None, compilable: bool = False):
+    def __init__(
+        self,
+        max_size: int | None = None,
+        *,
+        compilable: bool = False,
+        device: torch.device | str | int | None = None,
+    ):
         if max_size is None:
             max_size = torch.iinfo(torch.int64).max
         super().__init__(max_size, compilable=compilable)
         self._storage = []
+        self.device = device
+
+    def _to_device(self, data: Any) -> Any:
+        """Utility method to move data to the device."""
+        if self.device is not None:
+            if hasattr(data, "to"):
+                data = data.to(self.device)
+            else:
+                data = tree_map(
+                    lambda x: x.to(self.device) if hasattr(x, "to") else x, data
+                )
+        return data
 
     def set(
         self,
@@ -254,6 +277,7 @@ class ListStorage(Storage):
             self.set(int(cursor), data, set_cursor=set_cursor)
             return
         if isinstance(cursor, slice):
+            data = self._to_device(data)
             self._storage[cursor] = data
             return
         if isinstance(
@@ -290,6 +314,7 @@ class ListStorage(Storage):
                 f"maximum capacity is {self.max_size} "
                 f"and the index of the item to be set is {cursor}."
             )
+        data = self._to_device(data)
         if cursor == len(self._storage):
             self._storage.append(data)
         else:
@@ -387,6 +412,7 @@ class LazyStackStorage(ListStorage):
         compilable (bool, optional): if ``True``, the storage will be made compatible with :func:`~torch.compile` at
             the cost of being executable in multiprocessed settings.
         stack_dim (int, optional): the stack dimension in terms of TensorDict batch sizes. Defaults to `0`.
+        device (str, optional): the device to use for the storage. Defaults to `None` (inputs are not moved to the device).
 
     Examples:
         >>> import torch
@@ -421,8 +447,9 @@ class LazyStackStorage(ListStorage):
         *,
         compilable: bool = False,
         stack_dim: int = 0,
+        device: torch.device | str | int | None = None,
     ):
-        super().__init__(max_size=max_size, compilable=compilable)
+        super().__init__(max_size=max_size, compilable=compilable, device=device)
         self.stack_dim = stack_dim
 
     def get(self, index: int | Sequence[int] | slice) -> Any:
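
The new `device` keyword on `ListStorage` (forwarded by `LazyStackStorage`) moves written items onto the target device via the `_to_device` helper shown above. A minimal sketch, assuming the 2025.6.21 nightly:

import torch
from tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, ListStorage

device = "cuda:0" if torch.cuda.is_available() else "cpu"
rb = ReplayBuffer(storage=ListStorage(max_size=100, device=device))
rb.add(TensorDict({"obs": torch.randn(4)}, batch_size=[]))
print(rb[0]["obs"].device)  # items were moved at write time
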
torchrl/envs/llm/__init__.py CHANGED
@@ -22,12 +22,14 @@ from .transforms import (
     KLRewardTransform,
     MCPToolTransform,
     PythonInterpreter,
+    RetrieveLogProb,
     TemplateTransform,
     Tokenizer,
 )
 
 __all__ = [
     "BrowserTransform",
+    "RetrieveLogProb",
     "ChatEnv",
     "DataLoadingPrimer",
     "DatasetChatEnv",
torchrl/envs/llm/chat.py CHANGED
@@ -206,7 +206,10 @@ class ChatEnv(EnvBase):
             if lh.role != self.policy_role:
                 raise ValueError(
                     "The role received in the last block parsed from the policy "
-                    f"output does not match the expected policy role: received {lh.role} but expected {self.policy_role}."
+                    f"output does not match the expected policy role: received {lh.role} but expected {self.policy_role}.\n"
+                    f"Parsed input: {text=}\n"
+                    f"Parsed history: {parsed_history=}\n"
+                    f"Final element: {local_history=}"
                 )
             # Append history item
             history = history.append(local_history, inplace=False)
torchrl/envs/llm/reward/gsm8k.py CHANGED
@@ -145,25 +145,32 @@ class GSM8KRewardParser(Transform):
             potential_answer = [potential_answer]
         if isinstance(cot, str):
             cot = [cot]
-        reward_answer = 5.0 * (len(potential_answer) == 1)
 
+        # Format quality rewards (always applied)
+        reward_answer = 5.0 * (len(potential_answer) == 1)
         reward_think = 5.0 * (len(cot) == 1)
 
-        # One of the answer tags has the right answer
+        # Answer correctness rewards
         reward_right = 20.0 * (
             any(attempt == true_answer for attempt in potential_answer)
         )
-
-        # One of the answer tags contains the right answer (might be e.g. $20 instead of 20)
         reward_contained = 10.0 * (
             any((true_answer in attempt) for attempt in potential_answer)
         )
 
         success = len(potential_answer) > 0 and potential_answer[-1] == true_answer
-        # Compose the rewards
-        reward = 100.0 * float(success) + (
-            reward_answer + reward_think + reward_contained + reward_right
-        ) * (1 - float(success))
+
+        # Base success reward (lower than before to make format quality more important)
+        base_success_reward = 60.0 if success else 0.0
+
+        # Compose the rewards - always include format quality, even when successful
+        reward = (
+            base_success_reward
+            + reward_answer
+            + reward_think
+            + reward_contained
+            + reward_right
+        )
         rewards = TensorDict(
             reward_answer=reward_answer,
             reward_contained=reward_contained,
torchrl/envs/llm/transforms/__init__.py CHANGED
@@ -6,7 +6,7 @@
 from .browser import BrowserTransform
 from .dataloading import as_nested_tensor, as_padded_tensor, DataLoadingPrimer
 from .format import TemplateTransform
-from .kl import KLRewardTransform
+from .kl import KLRewardTransform, RetrieveLogProb
 from .policy_version import PolicyVersion
 from .tokenizer import Tokenizer
 from .tools import MCPToolTransform, PythonInterpreter
@@ -15,6 +15,7 @@ __all__ = [
     "BrowserTransform",
     "DataLoadingPrimer",
     "KLRewardTransform",
+    "RetrieveLogProb",
     "MCPToolTransform",
     "PolicyVersion",
     "PythonInterpreter",