torchrl-nightly 2025.6.19__cp39-cp39-win_amd64.whl → 2025.6.21__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrl/_torchrl.cp39-win_amd64.pyd +0 -0
- torchrl/collectors/collectors.py +49 -24
- torchrl/collectors/llm/base.py +13 -6
- torchrl/collectors/llm/ray_collector.py +3 -0
- torchrl/data/__init__.py +2 -0
- torchrl/data/datasets/minari_data.py +1 -1
- torchrl/data/llm/__init__.py +2 -0
- torchrl/data/llm/chat.py +59 -9
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/replay_buffers/ray_buffer.py +15 -1
- torchrl/data/replay_buffers/replay_buffers.py +50 -11
- torchrl/data/replay_buffers/samplers.py +98 -21
- torchrl/data/replay_buffers/storages.py +29 -2
- torchrl/envs/llm/__init__.py +2 -0
- torchrl/envs/llm/chat.py +4 -1
- torchrl/envs/llm/reward/gsm8k.py +15 -8
- torchrl/envs/llm/transforms/__init__.py +2 -1
- torchrl/envs/llm/transforms/kl.py +240 -4
- torchrl/envs/transforms/transforms.py +11 -27
- torchrl/modules/llm/policies/transformers_wrapper.py +71 -15
- torchrl/modules/llm/policies/vllm_wrapper.py +38 -5
- torchrl/objectives/llm/__init__.py +2 -1
- torchrl/objectives/llm/sft.py +465 -0
- torchrl/objectives/ppo.py +35 -12
- torchrl/version.py +2 -2
- {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/METADATA +1 -1
- {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/RECORD +30 -28
- {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/LICENSE +0 -0
- {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/WHEEL +0 -0
- {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/top_level.txt +0 -0
torchrl/data/replay_buffers/replay_buffers.py
CHANGED
@@ -702,7 +702,7 @@ class ReplayBuffer:
             self._sampler.add(index)
         return index

-    def _extend(self, data: Sequence) -> torch.Tensor:
+    def _extend(self, data: Sequence, *, update_priority: bool = True) -> torch.Tensor:
         is_comp = is_compiling()
         nc = contextlib.nullcontext()
         with self._replay_lock if not is_comp else nc, self._write_lock if not is_comp else nc:
@@ -712,7 +712,9 @@ class ReplayBuffer:
             self._sampler.extend(index)
         return index

-    def extend(self, data: Sequence) -> torch.Tensor:
+    def extend(
+        self, data: Sequence, *, update_priority: bool | None = None
+    ) -> torch.Tensor:
         """Extends the replay buffer with one or more elements contained in an iterable.

         If present, the inverse transforms will be called.`
@@ -721,6 +723,10 @@ class ReplayBuffer:
             data (iterable): collection of data to be added to the replay
                 buffer.

+        Keyword Args:
+            update_priority (bool, optional): Whether to update the priority of the data. Defaults to True.
+                Without effect in this class. See :meth:`~torchrl.data.TensorDictReplayBuffer.extend` for more details.
+
         Returns:
             Indices of the data added to the replay buffer.

@@ -735,12 +741,16 @@ class ReplayBuffer:
             unbound elements can be provided (no PyTrees).

         """
+        if update_priority is not None:
+            raise NotImplementedError(
+                "update_priority is not supported in this class. See :meth:`~torchrl.data.TensorDictReplayBuffer.extend` for more details."
+            )
         if self._transform is not None and len(self._transform):
             with _set_dispatch_td_nn_modules(is_tensor_collection(data)):
                 data = self._transform.inv(data)
         if data is None:
             return torch.zeros((0, self._storage.ndim), dtype=torch.long)
-        return self._extend(data)
+        return self._extend(data, update_priority=update_priority)

     def update_priority(
         self,
@@ -914,8 +924,8 @@ class ReplayBuffer:
                 self._iterator = iter(self)
             out = next(self._iterator)
             # if any, we don't want the device ref to be passed in distributed settings
-            if out is not None:
-                out.clear_device_()
+            if out is not None and (out.device != "cpu"):
+                out = out.copy().clear_device_()
             return out
         except StopIteration:
             self._iterator = None
@@ -1015,6 +1025,9 @@ class PrioritizedReplayBuffer(ReplayBuffer):
         storage (Storage, optional): the storage to be used. If none is provided
             a default :class:`~torchrl.data.replay_buffers.ListStorage` with
             ``max_size`` of ``1_000`` will be created.
+        sampler (Sampler, optional): the sampler to be used. If none is provided,
+            a default :class:`~torchrl.data.replay_buffers.PrioritizedSampler` with
+            ``alpha``, ``beta``, and ``eps`` will be created.
         collate_fn (callable, optional): merges a list of samples to form a
             mini-batch of Tensor(s)/outputs. Used when using batched
             loading from a map-style dataset. The default value will be decided
@@ -1107,6 +1120,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
         eps: float = 1e-8,
         dtype: torch.dtype = torch.float,
         storage: Storage | None = None,
+        sampler: Sampler | None = None,
         collate_fn: Callable | None = None,
         pin_memory: bool = False,
         prefetch: int | None = None,
@@ -1116,7 +1130,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
     ) -> None:
         if storage is None:
             storage = ListStorage(max_size=1_000)
-        sampler = PrioritizedSampler(storage.max_size, alpha, beta, eps, dtype)
+        if sampler is None:
+            sampler = PrioritizedSampler(storage.max_size, alpha, beta, eps, dtype)
         super().__init__(
             storage=storage,
             sampler=sampler,
@@ -1347,7 +1362,20 @@ class TensorDictReplayBuffer(ReplayBuffer):
             self.update_tensordict_priority(data)
         return index

-    def extend(self, tensordicts: TensorDictBase) -> torch.Tensor:
+    def extend(
+        self, tensordicts: TensorDictBase, *, update_priority: bool | None = None
+    ) -> torch.Tensor:
+        """Extends the replay buffer with a batch of data.
+
+        Args:
+            tensordicts (TensorDictBase): The data to extend the replay buffer with.
+
+        Keyword Args:
+            update_priority (bool, optional): Whether to update the priority of the data. Defaults to True.
+
+        Returns:
+            The indices of the data that were added to the replay buffer.
+        """
         if not isinstance(tensordicts, TensorDictBase):
             raise ValueError(
                 f"{self.__class__.__name__} only accepts TensorDictBase subclasses. tensorclasses "
@@ -1365,8 +1393,17 @@ class TensorDictReplayBuffer(ReplayBuffer):
         # is that just doing this results in indices that are not sorted like the original data
         # so the actually indices will have to be used on the _storage directly (not on the buffer)
         self._set_index_in_td(tensordicts, index)
-
-
+        if update_priority is None:
+            update_priority = True
+        if update_priority:
+            try:
+                vector = tensordicts.get(self.priority_key)
+                if vector is not None:
+                    self.update_priority(index, vector)
+            except Exception as e:
+                raise RuntimeError(
+                    "Failed to update priority of extended data. You can try to set update_priority=False in the extend method and update the priority manually."
+                ) from e
         return index

     def _set_index_in_td(self, tensordict, index):
@@ -1685,8 +1722,10 @@ class RemoteTensorDictReplayBuffer(TensorDictReplayBuffer):
     def add(self, data: TensorDictBase) -> int:
         return super().add(data)

-    def extend(self, tensordicts: list | TensorDictBase) -> torch.Tensor:
-        return super().extend(tensordicts)
+    def extend(
+        self, tensordicts: list | TensorDictBase, *, update_priority: bool | None = None
+    ) -> torch.Tensor:
+        return super().extend(tensordicts, update_priority=update_priority)

     def update_priority(
         self, index: int | torch.Tensor, priority: int | torch.Tensor
torchrl/data/replay_buffers/samplers.py
CHANGED
@@ -291,17 +291,38 @@ class SamplerWithoutReplacement(Sampler):


 class PrioritizedSampler(Sampler):
-    """Prioritized sampler for replay buffer.
+    r"""Prioritized sampler for replay buffer.

-
+    This sampler implements Prioritized Experience Replay (PER) as presented in
+    "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay."
+    (https://arxiv.org/abs/1511.05952)
+
+    **Core Idea**: Instead of sampling experiences uniformly from the replay buffer,
+    PER samples experiences with probability proportional to their "importance" - typically
+    measured by the magnitude of their temporal-difference (TD) error. This prioritization
+    can lead to faster learning by focusing on experiences that are most informative.
+
+    **How it works**:
+    1. Each experience is assigned a priority based on its TD error: :math:`p_i = |\delta_i| + \epsilon`
+    2. Sampling probability is computed as: :math:`P(i) = \frac{p_i^\alpha}{\sum_j p_j^\alpha}`
+    3. Importance sampling weights correct for the bias: :math:`w_i = (N \cdot P(i))^{-\beta}`

     Args:
         max_capacity (int): maximum capacity of the buffer.
-        alpha (:obj:`float`): exponent
-
-
-
-
+        alpha (:obj:`float`): exponent :math:`\alpha` determines how much prioritization is used.
+            - :math:`\alpha = 0`: uniform sampling (no prioritization)
+            - :math:`\alpha = 1`: full prioritization based on TD error magnitude
+            - Typical values: 0.4-0.7 for balanced prioritization
+            - Higher :math:`\alpha` means more aggressive prioritization of high-error experiences
+        beta (:obj:`float`): importance sampling negative exponent :math:`\beta`.
+            - :math:`\beta` controls the correction for the bias introduced by prioritization
+            - :math:`\beta = 0`: no correction (biased towards high-priority samples)
+            - :math:`\beta = 1`: full correction (unbiased but potentially unstable)
+            - Typical values: start at 0.4-0.6 and anneal to 1.0 during training
+            - Lower :math:`\beta` early in training provides stability, higher :math:`\beta` later reduces bias
+        eps (:obj:`float`, optional): small constant added to priorities to ensure
+            no experience has zero priority. This prevents experiences from never
+            being sampled. Defaults to 1e-8.
         reduction (str, optional): the reduction method for multidimensional
             tensordicts (ie stored trajectory). Can be one of "max", "min",
             "median" or "mean".
@@ -309,6 +330,23 @@ class PrioritizedSampler(Sampler):
             is tracked within the buffer. When ``False``, the max-priority tracks
             the maximum value since the instantiation of the sampler.

+    **Parameter Guidelines**:
+    - **:math:`\alpha` (alpha)**: Controls how much to prioritize high-error experiences
+        - 0.4-0.7: Good balance between learning speed and stability
+        - 1.0: Maximum prioritization (may be unstable)
+        - 0.0: Uniform sampling (no prioritization benefit)
+
+    - **:math:`\beta` (beta)**: Controls importance sampling correction
+        - Start at 0.4-0.6 for training stability
+        - Anneal to 1.0 over training to reduce bias
+        - Lower values = more stable but biased
+        - Higher values = less biased but potentially unstable
+
+    - **:math:`\epsilon`**: Small constant to prevent zero priorities
+        - 1e-8: Good default value
+        - Too small: may cause numerical issues
+        - Too large: reduces prioritization effect
+
     Examples:
         >>> from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler
        >>> from tensordict import TensorDict
@@ -412,7 +450,7 @@ class PrioritizedSampler(Sampler):
             )
         return super().__getstate__()

-    def _init(self):
+    def _init(self) -> None:
         if self.dtype in (torch.float, torch.FloatType, torch.float32):
             self._sum_tree = SumSegmentTreeFp32(self._max_capacity)
             self._min_tree = MinSegmentTreeFp32(self._max_capacity)
@@ -425,21 +463,23 @@ class PrioritizedSampler(Sampler):
             )
         self._max_priority = None

-    def _empty(self):
+    def _empty(self) -> None:
         self._init()

     @property
-    def _max_priority(self):
+    def _max_priority(self) -> tuple[float | None, int | None]:
         max_priority_index = self.__dict__.get("_max_priority")
         if max_priority_index is None:
             return (None, None)
         return max_priority_index

     @_max_priority.setter
-    def _max_priority(self, value):
+    def _max_priority(self, value: tuple[float | None, int | None]) -> None:
         self.__dict__["_max_priority"] = value

-    def _maybe_erase_max_priority(self, index):
+    def _maybe_erase_max_priority(
+        self, index: torch.Tensor | int | slice | tuple
+    ) -> None:
         if not self._max_priority_within_buffer:
             return
         max_priority_index = self._max_priority[1]
@@ -1839,11 +1879,21 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):


 class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
-    """Samples slices of data along the first dimension, given start and stop signals, using prioritized sampling.
+    r"""Samples slices of data along the first dimension, given start and stop signals, using prioritized sampling.
+
+    This class combines trajectory sampling with Prioritized Experience Replay (PER) as presented in
+    "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay."
+    (https://arxiv.org/abs/1511.05952)
+
+    **Core Idea**: Instead of sampling trajectory slices uniformly, this sampler prioritizes
+    trajectory start points based on the importance of the transitions at those positions.
+    This allows focusing learning on the most informative parts of trajectories.

-
-
-
+    **How it works**:
+    1. Each transition is assigned a priority based on its TD error: :math:`p_i = |\\delta_i| + \\epsilon`
+    2. Trajectory start points are sampled with probability: :math:`P(i) = \frac{p_i^\alpha}{\\sum_j p_j^\alpha}`
+    3. Importance sampling weights correct for bias: :math:`w_i = (N \\cdot P(i))^{-\beta}`
+    4. Complete trajectory slices are extracted from the sampled start points

     For more info see :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` and :class:`~torchrl.data.replay_buffers.samplers.PrioritizedSampler`.

@@ -1855,15 +1905,42 @@ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
         :meth:`update_priority`.

     Args:
-
-
-
-
-
+        max_capacity (int): maximum capacity of the buffer.
+        alpha (:obj:`float`): exponent :math:`\alpha` determines how much prioritization is used.
+            - :math:`\alpha = 0`: uniform sampling of trajectory start points
+            - :math:`\alpha = 1`: full prioritization based on TD error magnitude at start points
+            - Typical values: 0.4-0.7 for balanced prioritization
+            - Higher :math:`\alpha` means more aggressive prioritization of high-error trajectory regions
+        beta (:obj:`float`): importance sampling negative exponent :math:`\beta`.
+            - :math:`\beta` controls the correction for the bias introduced by prioritization
+            - :math:`\beta = 0`: no correction (biased towards high-priority trajectory regions)
+            - :math:`\beta = 1`: full correction (unbiased but potentially unstable)
+            - Typical values: start at 0.4-0.6 and anneal to 1.0 during training
+            - Lower :math:`\beta` early in training provides stability, higher :math:`\beta` later reduces bias
+        eps (:obj:`float`, optional): small constant added to priorities to ensure
+            no transition has zero priority. This prevents trajectory regions from never
+            being sampled. Defaults to 1e-8.
         reduction (str, optional): the reduction method for multidimensional
             tensordicts (i.e., stored trajectory). Can be one of "max", "min",
             "median" or "mean".

+    **Parameter Guidelines**:
+    - **:math:`\alpha` (alpha)**: Controls how much to prioritize high-error trajectory regions
+        - 0.4-0.7: Good balance between learning speed and stability
+        - 1.0: Maximum prioritization (may be unstable)
+        - 0.0: Uniform sampling (no prioritization benefit)
+
+    - **:math:`\beta` (beta)**: Controls importance sampling correction
+        - Start at 0.4-0.6 for training stability
+        - Anneal to 1.0 over training to reduce bias
+        - Lower values = more stable but biased
+        - Higher values = less biased but potentially unstable
+
+    - **:math:`\\epsilon`**: Small constant to prevent zero priorities
+        - 1e-8: Good default value
+        - Too small: may cause numerical issues
+        - Too large: reduces prioritization effect
+
     Keyword Args:
         num_slices (int): the number of slices to be sampled. The batch-size
             must be greater or equal to the ``num_slices`` argument. Exclusive
torchrl/data/replay_buffers/storages.py
CHANGED
@@ -230,15 +230,38 @@ class ListStorage(Storage):
         max_size (int, optional): the maximum number of elements stored in the storage.
             If not provided, an unlimited storage is created.

+    Keyword Args:
+        compilable (bool, optional): if ``True``, the storage will be made compatible with :func:`~torch.compile` at
+            the cost of being executable in multiprocessed settings.
+        device (str, optional): the device to use for the storage. Defaults to `None` (inputs are not moved to the device).
+
     """

     _default_checkpointer = ListStorageCheckpointer

-    def __init__(self, max_size: int | None = None, *, compilable: bool = False):
+    def __init__(
+        self,
+        max_size: int | None = None,
+        *,
+        compilable: bool = False,
+        device: torch.device | str | int | None = None,
+    ):
         if max_size is None:
             max_size = torch.iinfo(torch.int64).max
         super().__init__(max_size, compilable=compilable)
         self._storage = []
+        self.device = device
+
+    def _to_device(self, data: Any) -> Any:
+        """Utility method to move data to the device."""
+        if self.device is not None:
+            if hasattr(data, "to"):
+                data = data.to(self.device)
+            else:
+                data = tree_map(
+                    lambda x: x.to(self.device) if hasattr(x, "to") else x, data
+                )
+        return data

     def set(
         self,
@@ -254,6 +277,7 @@ class ListStorage(Storage):
                 self.set(int(cursor), data, set_cursor=set_cursor)
             return
         if isinstance(cursor, slice):
+            data = self._to_device(data)
             self._storage[cursor] = data
             return
         if isinstance(
@@ -290,6 +314,7 @@ class ListStorage(Storage):
                 f"maximum capacity is {self.max_size} "
                 f"and the index of the item to be set is {cursor}."
             )
+        data = self._to_device(data)
         if cursor == len(self._storage):
             self._storage.append(data)
         else:
@@ -387,6 +412,7 @@ class LazyStackStorage(ListStorage):
         compilable (bool, optional): if ``True``, the storage will be made compatible with :func:`~torch.compile` at
             the cost of being executable in multiprocessed settings.
         stack_dim (int, optional): the stack dimension in terms of TensorDict batch sizes. Defaults to `0`.
+        device (str, optional): the device to use for the storage. Defaults to `None` (inputs are not moved to the device).

     Examples:
         >>> import torch
@@ -421,8 +447,9 @@ class LazyStackStorage(ListStorage):
         *,
         compilable: bool = False,
         stack_dim: int = 0,
+        device: torch.device | str | int | None = None,
     ):
-        super().__init__(max_size=max_size, compilable=compilable)
+        super().__init__(max_size=max_size, compilable=compilable, device=device)
         self.stack_dim = stack_dim

     def get(self, index: int | Sequence[int] | slice) -> Any:
torchrl/envs/llm/__init__.py
CHANGED
@@ -22,12 +22,14 @@ from .transforms import (
     KLRewardTransform,
     MCPToolTransform,
     PythonInterpreter,
+    RetrieveLogProb,
     TemplateTransform,
     Tokenizer,
 )

 __all__ = [
     "BrowserTransform",
+    "RetrieveLogProb",
     "ChatEnv",
     "DataLoadingPrimer",
     "DatasetChatEnv",
torchrl/envs/llm/chat.py
CHANGED
@@ -206,7 +206,10 @@ class ChatEnv(EnvBase):
             if lh.role != self.policy_role:
                 raise ValueError(
                     "The role received in the last block parsed from the policy "
-                    f"output does not match the expected policy role: received {lh.role} but expected {self.policy_role}"
+                    f"output does not match the expected policy role: received {lh.role} but expected {self.policy_role}.\n"
+                    f"Parsed input: {text=}\n"
+                    f"Parsed history: {parsed_history=}\n"
+                    f"Final element: {local_history=}"
                 )
             # Append history item
             history = history.append(local_history, inplace=False)
torchrl/envs/llm/reward/gsm8k.py
CHANGED
@@ -145,25 +145,32 @@ class GSM8KRewardParser(Transform):
             potential_answer = [potential_answer]
         if isinstance(cot, str):
             cot = [cot]
-        reward_answer = 5.0 * (len(potential_answer) == 1)

+        # Format quality rewards (always applied)
+        reward_answer = 5.0 * (len(potential_answer) == 1)
         reward_think = 5.0 * (len(cot) == 1)

-        #
+        # Answer correctness rewards
         reward_right = 20.0 * (
             any(attempt == true_answer for attempt in potential_answer)
         )
-
-        # One of the answer tags contains the right answer (might be e.g. $20 instead of 20)
         reward_contained = 10.0 * (
             any((true_answer in attempt) for attempt in potential_answer)
         )

         success = len(potential_answer) > 0 and potential_answer[-1] == true_answer
-
-
-
-
+
+        # Base success reward (lower than before to make format quality more important)
+        base_success_reward = 60.0 if success else 0.0
+
+        # Compose the rewards - always include format quality, even when successful
+        reward = (
+            base_success_reward
+            + reward_answer
+            + reward_think
+            + reward_contained
+            + reward_right
+        )

         rewards = TensorDict(
             reward_answer=reward_answer,
torchrl/envs/llm/transforms/__init__.py
CHANGED
@@ -6,7 +6,7 @@
 from .browser import BrowserTransform
 from .dataloading import as_nested_tensor, as_padded_tensor, DataLoadingPrimer
 from .format import TemplateTransform
-from .kl import KLRewardTransform
+from .kl import KLRewardTransform, RetrieveLogProb
 from .policy_version import PolicyVersion
 from .tokenizer import Tokenizer
 from .tools import MCPToolTransform, PythonInterpreter
@@ -15,6 +15,7 @@ __all__ = [
     "BrowserTransform",
     "DataLoadingPrimer",
     "KLRewardTransform",
+    "RetrieveLogProb",
     "MCPToolTransform",
     "PolicyVersion",
     "PythonInterpreter",