torchrl-nightly 2025.6.20-cp310-cp310-win_amd64.whl → 2025.6.22-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrl/_torchrl.cp310-win_amd64.pyd +0 -0
- torchrl/collectors/collectors.py +8 -5
- torchrl/collectors/llm/base.py +13 -6
- torchrl/collectors/llm/ray_collector.py +3 -0
- torchrl/data/__init__.py +2 -0
- torchrl/data/llm/__init__.py +2 -0
- torchrl/data/llm/chat.py +59 -8
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/replay_buffers/ray_buffer.py +15 -1
- torchrl/data/replay_buffers/replay_buffers.py +50 -11
- torchrl/data/replay_buffers/samplers.py +98 -21
- torchrl/data/replay_buffers/storages.py +29 -2
- torchrl/envs/llm/__init__.py +2 -0
- torchrl/envs/llm/chat.py +4 -1
- torchrl/envs/llm/reward/gsm8k.py +15 -8
- torchrl/envs/llm/transforms/__init__.py +2 -1
- torchrl/envs/llm/transforms/kl.py +240 -4
- torchrl/modules/llm/policies/transformers_wrapper.py +71 -15
- torchrl/modules/llm/policies/vllm_wrapper.py +38 -5
- torchrl/objectives/llm/__init__.py +2 -1
- torchrl/objectives/llm/sft.py +465 -0
- torchrl/version.py +2 -2
- {torchrl_nightly-2025.6.20.dist-info → torchrl_nightly-2025.6.22.dist-info}/METADATA +1 -1
- {torchrl_nightly-2025.6.20.dist-info → torchrl_nightly-2025.6.22.dist-info}/RECORD +27 -25
- {torchrl_nightly-2025.6.20.dist-info → torchrl_nightly-2025.6.22.dist-info}/LICENSE +0 -0
- {torchrl_nightly-2025.6.20.dist-info → torchrl_nightly-2025.6.22.dist-info}/WHEEL +0 -0
- {torchrl_nightly-2025.6.20.dist-info → torchrl_nightly-2025.6.22.dist-info}/top_level.txt +0 -0
@@ -291,17 +291,38 @@ class SamplerWithoutReplacement(Sampler):


 class PrioritizedSampler(Sampler):
-    """Prioritized sampler for replay buffer.
+    r"""Prioritized sampler for replay buffer.

-
+    This sampler implements Prioritized Experience Replay (PER) as presented in
+    "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay."
+    (https://arxiv.org/abs/1511.05952)
+
+    **Core Idea**: Instead of sampling experiences uniformly from the replay buffer,
+    PER samples experiences with probability proportional to their "importance" - typically
+    measured by the magnitude of their temporal-difference (TD) error. This prioritization
+    can lead to faster learning by focusing on experiences that are most informative.
+
+    **How it works**:
+    1. Each experience is assigned a priority based on its TD error: :math:`p_i = |\delta_i| + \epsilon`
+    2. Sampling probability is computed as: :math:`P(i) = \frac{p_i^\alpha}{\sum_j p_j^\alpha}`
+    3. Importance sampling weights correct for the bias: :math:`w_i = (N \cdot P(i))^{-\beta}`

     Args:
         max_capacity (int): maximum capacity of the buffer.
-        alpha (:obj:`float`): exponent
-
-
-
-
+        alpha (:obj:`float`): exponent :math:`\alpha` determines how much prioritization is used.
+            - :math:`\alpha = 0`: uniform sampling (no prioritization)
+            - :math:`\alpha = 1`: full prioritization based on TD error magnitude
+            - Typical values: 0.4-0.7 for balanced prioritization
+            - Higher :math:`\alpha` means more aggressive prioritization of high-error experiences
+        beta (:obj:`float`): importance sampling negative exponent :math:`\beta`.
+            - :math:`\beta` controls the correction for the bias introduced by prioritization
+            - :math:`\beta = 0`: no correction (biased towards high-priority samples)
+            - :math:`\beta = 1`: full correction (unbiased but potentially unstable)
+            - Typical values: start at 0.4-0.6 and anneal to 1.0 during training
+            - Lower :math:`\beta` early in training provides stability, higher :math:`\beta` later reduces bias
+        eps (:obj:`float`, optional): small constant added to priorities to ensure
+            no experience has zero priority. This prevents experiences from never
+            being sampled. Defaults to 1e-8.
         reduction (str, optional): the reduction method for multidimensional
             tensordicts (ie stored trajectory). Can be one of "max", "min",
             "median" or "mean".
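For orientation, the three numbered steps added to the docstring translate into a few lines of tensor arithmetic. The sketch below is illustrative only: it uses plain PyTorch rather than the sum-tree machinery the sampler actually relies on, and the function name, the `td_errors` input and the default values are assumptions made for the example.

import torch

def per_probs_and_weights(td_errors, alpha=0.6, beta=0.4, eps=1e-8):
    # 1. priorities from TD errors: p_i = |delta_i| + eps
    priorities = td_errors.abs() + eps
    # 2. sampling probabilities: P(i) = p_i^alpha / sum_j p_j^alpha
    probs = priorities.pow(alpha)
    probs = probs / probs.sum()
    # 3. importance-sampling weights: w_i = (N * P(i))^(-beta), normalized by the max for stability
    weights = (td_errors.numel() * probs).pow(-beta)
    weights = weights / weights.max()
    return probs, weights

probs, weights = per_probs_and_weights(torch.randn(1000))
indices = torch.multinomial(probs, num_samples=32, replacement=True)  # prioritized draw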
@@ -309,6 +330,23 @@ class PrioritizedSampler(Sampler):
             is tracked within the buffer. When ``False``, the max-priority tracks
             the maximum value since the instantiation of the sampler.

+    **Parameter Guidelines**:
+    - **:math:`\alpha` (alpha)**: Controls how much to prioritize high-error experiences
+        - 0.4-0.7: Good balance between learning speed and stability
+        - 1.0: Maximum prioritization (may be unstable)
+        - 0.0: Uniform sampling (no prioritization benefit)
+
+    - **:math:`\beta` (beta)**: Controls importance sampling correction
+        - Start at 0.4-0.6 for training stability
+        - Anneal to 1.0 over training to reduce bias
+        - Lower values = more stable but biased
+        - Higher values = less biased but potentially unstable
+
+    - **:math:`\epsilon`**: Small constant to prevent zero priorities
+        - 1e-8: Good default value
+        - Too small: may cause numerical issues
+        - Too large: reduces prioritization effect
+
     Examples:
         >>> from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler
         >>> from tensordict import TensorDict
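The annealing advice in these guidelines (start :math:`\beta` around 0.4-0.6, finish at 1.0) is usually implemented as a linear ramp over the training budget. A minimal sketch of such a schedule follows; how the resulting value is handed back to the sampler is deliberately left out, since the sampler's beta attribute is not documented in this diff.

def beta_schedule(step, total_steps, beta_start=0.4, beta_end=1.0):
    # linear ramp from beta_start to beta_end, clamped once training ends
    frac = min(step / max(total_steps, 1), 1.0)
    return beta_start + frac * (beta_end - beta_start)

print([round(beta_schedule(s, 100), 2) for s in (0, 50, 100)])  # [0.4, 0.7, 1.0]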
@@ -412,7 +450,7 @@ class PrioritizedSampler(Sampler):
         )
         return super().__getstate__()

-    def _init(self):
+    def _init(self) -> None:
         if self.dtype in (torch.float, torch.FloatType, torch.float32):
             self._sum_tree = SumSegmentTreeFp32(self._max_capacity)
             self._min_tree = MinSegmentTreeFp32(self._max_capacity)
@@ -425,21 +463,23 @@ class PrioritizedSampler(Sampler):
             )
         self._max_priority = None

-    def _empty(self):
+    def _empty(self) -> None:
         self._init()

     @property
-    def _max_priority(self):
+    def _max_priority(self) -> tuple[float | None, int | None]:
         max_priority_index = self.__dict__.get("_max_priority")
         if max_priority_index is None:
             return (None, None)
         return max_priority_index

     @_max_priority.setter
-    def _max_priority(self, value):
+    def _max_priority(self, value: tuple[float | None, int | None]) -> None:
         self.__dict__["_max_priority"] = value

-    def _maybe_erase_max_priority(
+    def _maybe_erase_max_priority(
+        self, index: torch.Tensor | int | slice | tuple
+    ) -> None:
         if not self._max_priority_within_buffer:
             return
         max_priority_index = self._max_priority[1]
@@ -1839,11 +1879,21 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):


 class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
-    """Samples slices of data along the first dimension, given start and stop signals, using prioritized sampling.
+    r"""Samples slices of data along the first dimension, given start and stop signals, using prioritized sampling.
+
+    This class combines trajectory sampling with Prioritized Experience Replay (PER) as presented in
+    "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay."
+    (https://arxiv.org/abs/1511.05952)
+
+    **Core Idea**: Instead of sampling trajectory slices uniformly, this sampler prioritizes
+    trajectory start points based on the importance of the transitions at those positions.
+    This allows focusing learning on the most informative parts of trajectories.

-
-
-
+    **How it works**:
+    1. Each transition is assigned a priority based on its TD error: :math:`p_i = |\\delta_i| + \\epsilon`
+    2. Trajectory start points are sampled with probability: :math:`P(i) = \frac{p_i^\alpha}{\\sum_j p_j^\alpha}`
+    3. Importance sampling weights correct for bias: :math:`w_i = (N \\cdot P(i))^{-\beta}`
+    4. Complete trajectory slices are extracted from the sampled start points

     For more info see :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` and :class:`~torchrl.data.replay_buffers.samplers.PrioritizedSampler`.

@@ -1855,15 +1905,42 @@ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
         :meth:`update_priority`.

     Args:
-
-
-
-
-
+        max_capacity (int): maximum capacity of the buffer.
+        alpha (:obj:`float`): exponent :math:`\alpha` determines how much prioritization is used.
+            - :math:`\alpha = 0`: uniform sampling of trajectory start points
+            - :math:`\alpha = 1`: full prioritization based on TD error magnitude at start points
+            - Typical values: 0.4-0.7 for balanced prioritization
+            - Higher :math:`\alpha` means more aggressive prioritization of high-error trajectory regions
+        beta (:obj:`float`): importance sampling negative exponent :math:`\beta`.
+            - :math:`\beta` controls the correction for the bias introduced by prioritization
+            - :math:`\beta = 0`: no correction (biased towards high-priority trajectory regions)
+            - :math:`\beta = 1`: full correction (unbiased but potentially unstable)
+            - Typical values: start at 0.4-0.6 and anneal to 1.0 during training
+            - Lower :math:`\beta` early in training provides stability, higher :math:`\beta` later reduces bias
+        eps (:obj:`float`, optional): small constant added to priorities to ensure
+            no transition has zero priority. This prevents trajectory regions from never
+            being sampled. Defaults to 1e-8.
         reduction (str, optional): the reduction method for multidimensional
             tensordicts (i.e., stored trajectory). Can be one of "max", "min",
             "median" or "mean".

+    **Parameter Guidelines**:
+    - **:math:`\alpha` (alpha)**: Controls how much to prioritize high-error trajectory regions
+        - 0.4-0.7: Good balance between learning speed and stability
+        - 1.0: Maximum prioritization (may be unstable)
+        - 0.0: Uniform sampling (no prioritization benefit)
+
+    - **:math:`\beta` (beta)**: Controls importance sampling correction
+        - Start at 0.4-0.6 for training stability
+        - Anneal to 1.0 over training to reduce bias
+        - Lower values = more stable but biased
+        - Higher values = less biased but potentially unstable
+
+    - **:math:`\\epsilon`**: Small constant to prevent zero priorities
+        - 1e-8: Good default value
+        - Too small: may cause numerical issues
+        - Too large: reduces prioritization effect
+
     Keyword Args:
         num_slices (int): the number of slices to be sampled. The batch-size
             must be greater or equal to the ``num_slices`` argument. Exclusive
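For readers unfamiliar with the class, a hedged usage sketch follows. The constructor arguments mirror the docstring above; the `traj_key`/`num_slices` keywords and the `update_priority` call follow the SliceSampler and PrioritizedSampler conventions referenced earlier, so treat this as a sketch rather than a verified snippet.

import torch
from tensordict import TensorDict
from torchrl.data.replay_buffers import LazyTensorStorage, PrioritizedSliceSampler, ReplayBuffer

sampler = PrioritizedSliceSampler(
    max_capacity=1000, alpha=0.6, beta=0.4, num_slices=4, traj_key="episode"
)
rb = ReplayBuffer(storage=LazyTensorStorage(1000), sampler=sampler, batch_size=32)

data = TensorDict(
    observation=torch.randn(100, 4),
    episode=torch.arange(100) // 25,  # four trajectories of length 25
    batch_size=[100],
)
rb.extend(data)
sample, info = rb.sample(return_info=True)
# after computing TD errors for the sampled transitions, feed them back as priorities
rb.update_priority(info["index"], torch.rand(32) + 1e-8)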
torchrl/data/replay_buffers/storages.py
CHANGED
@@ -230,15 +230,38 @@ class ListStorage(Storage):
         max_size (int, optional): the maximum number of elements stored in the storage.
             If not provided, an unlimited storage is created.

+    Keyword Args:
+        compilable (bool, optional): if ``True``, the storage will be made compatible with :func:`~torch.compile` at
+            the cost of being executable in multiprocessed settings.
+        device (str, optional): the device to use for the storage. Defaults to `None` (inputs are not moved to the device).
+
     """

     _default_checkpointer = ListStorageCheckpointer

-    def __init__(
+    def __init__(
+        self,
+        max_size: int | None = None,
+        *,
+        compilable: bool = False,
+        device: torch.device | str | int | None = None,
+    ):
         if max_size is None:
             max_size = torch.iinfo(torch.int64).max
         super().__init__(max_size, compilable=compilable)
         self._storage = []
+        self.device = device
+
+    def _to_device(self, data: Any) -> Any:
+        """Utility method to move data to the device."""
+        if self.device is not None:
+            if hasattr(data, "to"):
+                data = data.to(self.device)
+            else:
+                data = tree_map(
+                    lambda x: x.to(self.device) if hasattr(x, "to") else x, data
+                )
+        return data

     def set(
         self,
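The `_to_device` helper above falls back to `tree_map` when the payload has no `.to` method of its own. A minimal self-contained sketch of that pattern, using PyTorch's pytree utility on a made-up nested container:

import torch
from torch.utils._pytree import tree_map

nested = {"obs": torch.zeros(3), "info": {"step": torch.tensor(0), "note": "not a tensor"}}
moved = tree_map(lambda x: x.to("cpu") if hasattr(x, "to") else x, nested)
# tensor leaves are moved to the target device; other leaves pass through untouched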
@@ -254,6 +277,7 @@ class ListStorage(Storage):
             self.set(int(cursor), data, set_cursor=set_cursor)
             return
         if isinstance(cursor, slice):
+            data = self._to_device(data)
             self._storage[cursor] = data
             return
         if isinstance(
@@ -290,6 +314,7 @@ class ListStorage(Storage):
                 f"maximum capacity is {self.max_size} "
                 f"and the index of the item to be set is {cursor}."
             )
+        data = self._to_device(data)
         if cursor == len(self._storage):
             self._storage.append(data)
         else:
@@ -387,6 +412,7 @@ class LazyStackStorage(ListStorage):
         compilable (bool, optional): if ``True``, the storage will be made compatible with :func:`~torch.compile` at
             the cost of being executable in multiprocessed settings.
         stack_dim (int, optional): the stack dimension in terms of TensorDict batch sizes. Defaults to `0`.
+        device (str, optional): the device to use for the storage. Defaults to `None` (inputs are not moved to the device).

     Examples:
         >>> import torch
@@ -421,8 +447,9 @@ class LazyStackStorage(ListStorage):
         *,
         compilable: bool = False,
         stack_dim: int = 0,
+        device: torch.device | str | int | None = None,
     ):
-        super().__init__(max_size=max_size, compilable=compilable)
+        super().__init__(max_size=max_size, compilable=compilable, device=device)
         self.stack_dim = stack_dim

     def get(self, index: int | Sequence[int] | slice) -> Any:
torchrl/envs/llm/__init__.py
CHANGED
@@ -22,12 +22,14 @@ from .transforms import (
     KLRewardTransform,
     MCPToolTransform,
     PythonInterpreter,
+    RetrieveLogProb,
     TemplateTransform,
     Tokenizer,
 )

 __all__ = [
     "BrowserTransform",
+    "RetrieveLogProb",
     "ChatEnv",
     "DataLoadingPrimer",
     "DatasetChatEnv",
torchrl/envs/llm/chat.py
CHANGED
@@ -206,7 +206,10 @@ class ChatEnv(EnvBase):
             if lh.role != self.policy_role:
                 raise ValueError(
                     "The role received in the last block parsed from the policy "
-                    f"output does not match the expected policy role: received {lh.role} but expected {self.policy_role}"
+                    f"output does not match the expected policy role: received {lh.role} but expected {self.policy_role}.\n"
+                    f"Parsed input: {text=}\n"
+                    f"Parsed history: {parsed_history=}\n"
+                    f"Final element: {local_history=}"
                 )
             # Append history item
             history = history.append(local_history, inplace=False)
torchrl/envs/llm/reward/gsm8k.py
CHANGED
@@ -145,25 +145,32 @@ class GSM8KRewardParser(Transform):
             potential_answer = [potential_answer]
         if isinstance(cot, str):
             cot = [cot]
-        reward_answer = 5.0 * (len(potential_answer) == 1)

+        # Format quality rewards (always applied)
+        reward_answer = 5.0 * (len(potential_answer) == 1)
         reward_think = 5.0 * (len(cot) == 1)

-        #
+        # Answer correctness rewards
         reward_right = 20.0 * (
             any(attempt == true_answer for attempt in potential_answer)
         )
-
-        # One of the answer tags contains the right answer (might be e.g. $20 instead of 20)
         reward_contained = 10.0 * (
             any((true_answer in attempt) for attempt in potential_answer)
         )

         success = len(potential_answer) > 0 and potential_answer[-1] == true_answer
-
-
-
-
+
+        # Base success reward (lower than before to make format quality more important)
+        base_success_reward = 60.0 if success else 0.0
+
+        # Compose the rewards - always include format quality, even when successful
+        reward = (
+            base_success_reward
+            + reward_answer
+            + reward_think
+            + reward_contained
+            + reward_right
+        )

         rewards = TensorDict(
             reward_answer=reward_answer,
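The reworked composition makes the reward additive instead of all-or-nothing: a correct, well-formatted completion now earns 60 + 5 + 5 + 10 + 20 = 100, while a correct answer with sloppy tags still collects the correctness terms. A small arithmetic sketch of the same composition (coefficients copied from the hunk above, function name invented for illustration):

def gsm8k_reward(one_answer_tag, one_think_tag, contained, exact, final_correct):
    reward = 60.0 if final_correct else 0.0                  # base success reward
    reward += 5.0 * one_answer_tag + 5.0 * one_think_tag     # format quality
    reward += 10.0 * contained + 20.0 * exact                # answer correctness
    return reward

print(gsm8k_reward(True, True, True, True, True))    # 100.0
print(gsm8k_reward(False, True, True, True, False))  # 35.0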
torchrl/envs/llm/transforms/__init__.py
CHANGED
@@ -6,7 +6,7 @@
 from .browser import BrowserTransform
 from .dataloading import as_nested_tensor, as_padded_tensor, DataLoadingPrimer
 from .format import TemplateTransform
-from .kl import KLRewardTransform
+from .kl import KLRewardTransform, RetrieveLogProb
 from .policy_version import PolicyVersion
 from .tokenizer import Tokenizer
 from .tools import MCPToolTransform, PythonInterpreter
@@ -15,6 +15,7 @@ __all__ = [
     "BrowserTransform",
     "DataLoadingPrimer",
     "KLRewardTransform",
+    "RetrieveLogProb",
     "MCPToolTransform",
     "PolicyVersion",
     "PythonInterpreter",
torchrl/envs/llm/transforms/kl.py
CHANGED
@@ -4,15 +4,25 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations

+import contextlib
+import gc
+
 from copy import copy

 import torch
-from tensordict import NestedKey, TensorDictBase, unravel_key
+from tensordict import NestedKey, set_list_to_stack, TensorDictBase, unravel_key
 from tensordict.nn import ProbabilisticTensorDictModule
-from tensordict.utils import is_seq_of_nested_key
+from tensordict.utils import _zip_strict, is_seq_of_nested_key
 from torchrl.data import Composite, Unbounded
+from torchrl.data.llm.chat import History
 from torchrl.envs import EnvBase, Transform
 from torchrl.envs.transforms.utils import _set_missing_tolerance
+from torchrl.modules.llm.policies.common import CategoricalSequential
+
+try:
+    import transformers
+except ImportError:
+    transformers = None


 class KLRewardTransform(Transform):
@@ -141,8 +151,8 @@ class KLRewardTransform(Transform):
                 f"action_key is required. Please set a parent for the {type(self).__name__} to recover the action keys automatically, "
                 f"or pass the action_key argument directly to {type(self).__name__} constructor."
             )
-
-        if
+        response_txt = tensordict.get(action_key, None)
+        if response_txt is None:
             if not self.missing_tolerance:
                 raise RuntimeError(
                     f"Action with key {action_key} not found data {tensordict}"
@@ -269,3 +279,229 @@ class KLRewardTransform(Transform):
         observation_spec[self.out_keys[1]] = reward_spec.clone()

         return output_spec
+
+
+class RetrieveLogProb(Transform):
+    """A transform to retrieve the log-probs of a text given a reference model.
+
+    Args:
+        actor (CategoricalSequential): the reference model.
+
+    Keyword Args:
+        history_key (NestedKey): the key where the history is stored. Defaults to `"history"`.
+        log_prob_key (NestedKey): the key where the log-probs are stored. Defaults to `"ref_log_prob"`.
+        assistant_only (bool): whether to only retrieve the log-probs of the assistant tokens (i.e., steps of history
+            where the role is `"assistant"`). Defaults to `False`.
+
+            .. note:: The template must accommodate the `return_assistant_tokens_mask` keyword argument.
+                This may not be the case for all templates. In this case, you can pass a custom template to the `apply_chat_template` method
+                via the `tokenizer_kwargs` argument: `tokenizer_kwargs = {"chat_template_name": "qwen"}` or `tokenizer_kwargs = {"chat_template": my_template}.
+
+        tokenizer_kwargs (dict): the keyword arguments to pass to the tokenizer to be used to apply the chat template to the history when `assistant_only` is `True`.
+            To control the tokenization in the actor, pass the tokenizer kwargs to the actor constructor.
+            Defaults to `{"return_assistant_tokens_mask": True, "tokenize": True, "return_tensors": "pt", "padding": True, "add_generation_prompt": False}`.
+        tokenizer (transformers.AutoTokenizer): the tokenizer to be used to tokenize the input and compute the assitant mask. If not provided, the tokenizer will be inferred from the `actor`.
+        detach (bool): whether to exclude the log-probs from the gradient computation. Defaults to `True`.
+        device (torch.device): the device to use for tensor creation. Defaults to `None`.
+
+    Examples:
+        >>> from torchrl.data.llm.chat import History, _CHAT_TEMPLATES
+        >>> from torchrl.modules.llm import TransformersWrapper
+        >>> from torchrl.objectives.llm.sft import SFTLoss
+        >>> from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM
+        >>> from tensordict import TensorDict, lazy_stack, set_list_to_stack
+        >>> import torch
+        >>>
+        >>> set_list_to_stack(True).set()
+        >>>
+        >>> # Create chat data
+        >>> chats = [
+        ...     [
+        ...         {"role": "system", "content": "You are a helpful assistant."},
+        ...         {"role": "user", "content": "Hello, how are you?"},
+        ...         {"role": "assistant", "content": "I'm doing well, thank you!"},
+        ...     ],
+        ...     [
+        ...         {"role": "system", "content": "You are a helpful assistant."},
+        ...         {"role": "user", "content": "What's the weather like?"},
+        ...         {"role": "assistant", "content": "I can't check the weather for you."},
+        ...     ],
+        ... ]
+        >>> history = History.from_chats(chats)
+        >>> print(f"Created history with shape: {history.shape}")
+        Created history with shape: torch.Size([2, 3])
+        >>>
+        >>> # Setup tokenizer and model
+        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+        >>> tokenizer.pad_token = tokenizer.eos_token
+        >>> tokenizer.chat_template = _CHAT_TEMPLATES["chatml_format"]
+        >>> model = OPTForCausalLM(OPTConfig()).eval()
+        >>>
+        >>> # Create training and reference policies
+        >>> policy_train = TransformersWrapper(
+        ...     model,
+        ...     tokenizer=tokenizer,
+        ...     generate=False,
+        ...     from_text=True,
+        ...     chat_template_name="qwen",
+        ... )
+        >>> policy_ref = TransformersWrapper(
+        ...     model,
+        ...     tokenizer=tokenizer,
+        ...     generate=False,
+        ...     from_text=True,
+        ...     return_log_probs=True,
+        ...     chat_template_name="qwen",
+        ... )
+        >>>
+        >>> # Create the RetrieveLogProb transform
+        >>> transform = RetrieveLogProb(
+        ...     policy_ref,
+        ...     assistant_only=True,
+        ...     tokenizer_kwargs={"chat_template_name": "qwen"},
+        ...     tokenizer=tokenizer,
+        ... )
+        >>>
+        >>> # Prepare data
+        >>> text = history[:, :-1].apply_chat_template(
+        ...     tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=True
+        ... )
+        >>> text_response = history.apply_chat_template(
+        ...     tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=False
+        ... )
+        >>> text_response = [
+        ...     txt[len(txt_start):] for txt, txt_start in zip(text_response, text)
+        ... ]
+        >>> td = TensorDict(
+        ...     text=text,
+        ...     text_response=text_response,
+        ...     history=history,
+        ...     next=TensorDict(
+        ...         reward=torch.randn(2, 1),
+        ...         done=torch.zeros(2, dtype=torch.bool),
+        ...         history=history,
+        ...     ),
+        ...     batch_size=(2,),
+        ... )
+        >>> data = lazy_stack(list(td.unbind(0)))
+        >>>
+        >>> # Apply the transform to get reference log probabilities
+        >>> data = transform(data)
+        >>> # You can get a padded tensor for batching:
+        >>> ref_log_probs = data.get(("next", "ref_log_prob"), as_padded_tensor=True)
+        >>> print(f"Type: {type(ref_log_probs)}, Length: {len(ref_log_probs)}")
+        Type: <class 'torch.Tensor'>, Length: 2
+        >>> print(f"Example shapes: {[x.shape for x in ref_log_probs]}")
+        Example shapes: [torch.Size([35]), torch.Size([35])]
+        >>> print(ref_log_probs.shape)  # (batch, max_seq_len)
+        torch.Size([2, 35])
+        >>>
+        >>> # Use with SFTLoss for KL regularization
+        >>> loss = SFTLoss(
+        ...     actor_network=policy_train,
+        ...     tokenizer=tokenizer,
+        ...     reduction="mean",
+        ...     normalize_by_seq_length=True,
+        ...     kl_to_ref_coeff=0.1,
+        ...     tokenizer_kwargs={"chat_template_name": "qwen"},
+        ... )
+        >>> loss_vals = loss(data)
+        >>> print(f"SFT Loss: {loss_vals.loss_sft.item():.4f}")
+        SFT Loss: 10.7856
+        >>> print(f"KL to Reference Loss: {loss_vals.loss_kl_to_ref.item():.4f}")
+        KL to Reference Loss: 0.0000
+        >>> print(f"Total Loss: {loss_vals.sum(reduce=True).item():.4f}")
+        Total Loss: 10.7856
+
+    Note:
+        By default, the log-probabilities are stored as a list of tensors (one per sample, with variable length).
+        Use `as_padded_tensor=True` in `.get()` to obtain a batchable tensor (with padding).
+        The reference log probabilities are computed only for assistant tokens when `assistant_only=True`.
+
+    """
+
+    def __init__(
+        self,
+        actor: CategoricalSequential,
+        *,
+        history_key: NestedKey | None = None,
+        log_prob_key: NestedKey = "ref_log_prob",
+        assistant_only: bool = False,
+        tokenizer_kwargs: dict | None = None,
+        detach: bool = True,
+        device: torch.device | None = None,
+        tokenizer: transformers.AutoTokenizer | None = None,
+    ):
+        if history_key is None:
+            history_key = "history"
+        self.history_key = history_key
+        self.log_prob_key = log_prob_key
+        super().__init__(in_keys=[history_key], out_keys=[log_prob_key])
+        self.actor = actor
+        if not getattr(actor, "return_log_probs", True):
+            raise ValueError(
+                "The actor must have `return_log_probs=True` to use the `AssistantLogProb` transform."
+            )
+        if getattr(actor, "generate", True):
+            raise ValueError(
+                "The actor must have `generate=False` to use the `AssistantLogProb` transform."
+            )
+        if not getattr(actor, "from_text", False):
+            raise ValueError(
+                "The actor must have `from_text=True` to use the `AssistantLogProb` transform. If `from_text=False` is required, please file an issue on GitHub."
+            )
+        # if getattr(self.actor, "tokenizer_kwargs", {}).get("add_generation_prompt", True):
+        #     raise ValueError("The actor must have `tokenizer_kwargs['add_generation_prompt']=False` to use the `AssistantLogProb` transform.")
+        self.assistant_only = assistant_only
+        if tokenizer_kwargs is None:
+            tokenizer_kwargs = {}
+        tokenizer_kwargs.setdefault("return_assistant_tokens_mask", True)
+        tokenizer_kwargs.setdefault("tokenize", True)
+        tokenizer_kwargs.setdefault("return_tensors", "pt")
+        tokenizer_kwargs.setdefault("padding", False)
+        tokenizer_kwargs.setdefault("add_generation_prompt", False)
+        self.tokenizer_kwargs = tokenizer_kwargs
+        self.tokenizer = tokenizer
+        self.detach = detach
+        self.device = device
+
+    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        next_td = self._step(tensordict, tensordict.get("next"))
+        return tensordict.set("next", next_td)
+
+    @set_list_to_stack(True)
+    def _step(
+        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+    ) -> TensorDictBase:
+        td = next_tensordict.select(self.history_key)
+        with torch.device(
+            self.device
+        ) if self.device is not None else contextlib.nullcontext(), torch.no_grad() if self.detach else contextlib.nullcontext():
+            result = self.actor(td.select(self.history_key))
+            td.update(result.select(getattr(self.actor, "log_prob_key", "log_probs")))
+            td.rename_key_(
+                getattr(self.actor, "log_prob_key", "log_probs"), self.log_prob_key
+            )
+        if torch.cuda.is_available():
+            gc.collect()
+            torch.cuda.empty_cache()
+        if self.assistant_only:
+            with torch.device(
+                self.device
+            ) if self.device is not None else contextlib.nullcontext():
+                # Get assistant mask
+                history: History = td.get(self.history_key)
+                proc = history.apply_chat_template(
+                    tokenizer=self.actor.tokenizer
+                    if self.tokenizer is None
+                    else self.tokenizer,
+                    **self.tokenizer_kwargs,
+                )
+                assistant_masks = proc.get("assistant_masks", as_list=True)
+                log_probs = td.get(self.log_prob_key, as_list=True)
+                log_probs = [
+                    lp[mask.bool()]
+                    for lp, mask in _zip_strict(log_probs, assistant_masks)
+                ]
+                td = td.set(self.log_prob_key, log_probs)
+        return next_tensordict.update(td)
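The reference log-probs retrieved by this transform are typically consumed as a per-token KL penalty against the current policy. A self-contained sketch of that arithmetic (the tensor values are made up, and the 0.1 coefficient mirrors the `kl_to_ref_coeff` idea from the docstring example):

import torch

log_probs_policy = torch.tensor([-0.9, -1.2, -0.4])  # current policy, per token
log_probs_ref = torch.tensor([-1.0, -1.5, -0.5])     # frozen reference, per token

kl_per_token = log_probs_policy - log_probs_ref  # Monte-Carlo estimate of KL(pi || pi_ref)
kl_penalty = 0.1 * kl_per_token.sum()
print(kl_penalty)  # tensor(0.0500)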