torchrl-nightly 2025.6.20__cp311-cp311-win_amd64.whl → 2025.6.21__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -291,17 +291,38 @@ class SamplerWithoutReplacement(Sampler):
291
291
 
292
292
 
293
293
  class PrioritizedSampler(Sampler):
294
- """Prioritized sampler for replay buffer.
294
+ r"""Prioritized sampler for replay buffer.
295
295
 
296
- Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay." (https://arxiv.org/abs/1511.05952)
296
+ This sampler implements Prioritized Experience Replay (PER) as presented in
297
+ "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay."
298
+ (https://arxiv.org/abs/1511.05952)
299
+
300
+ **Core Idea**: Instead of sampling experiences uniformly from the replay buffer,
301
+ PER samples experiences with probability proportional to their "importance" - typically
302
+ measured by the magnitude of their temporal-difference (TD) error. This prioritization
303
+ can lead to faster learning by focusing on experiences that are most informative.
304
+
305
+ **How it works**:
306
+ 1. Each experience is assigned a priority based on its TD error: :math:`p_i = |\delta_i| + \epsilon`
307
+ 2. Sampling probability is computed as: :math:`P(i) = \frac{p_i^\alpha}{\sum_j p_j^\alpha}`
308
+ 3. Importance sampling weights correct for the bias: :math:`w_i = (N \cdot P(i))^{-\beta}`
297
309
 
298
310
  Args:
299
311
  max_capacity (int): maximum capacity of the buffer.
300
- alpha (:obj:`float`): exponent α determines how much prioritization is used,
301
- with α = 0 corresponding to the uniform case.
302
- beta (:obj:`float`): importance sampling negative exponent.
303
- eps (:obj:`float`, optional): delta added to the priorities to ensure that the buffer
304
- does not contain null priorities. Defaults to 1e-8.
312
+ alpha (:obj:`float`): exponent :math:`\alpha` determines how much prioritization is used.
313
+ - :math:`\alpha = 0`: uniform sampling (no prioritization)
314
+ - :math:`\alpha = 1`: full prioritization based on TD error magnitude
315
+ - Typical values: 0.4-0.7 for balanced prioritization
316
+ - Higher :math:`\alpha` means more aggressive prioritization of high-error experiences
317
+ beta (:obj:`float`): importance sampling negative exponent :math:`\beta`.
318
+ - :math:`\beta` controls the correction for the bias introduced by prioritization
319
+ - :math:`\beta = 0`: no correction (biased towards high-priority samples)
320
+ - :math:`\beta = 1`: full correction (unbiased but potentially unstable)
321
+ - Typical values: start at 0.4-0.6 and anneal to 1.0 during training
322
+ - Lower :math:`\beta` early in training provides stability, higher :math:`\beta` later reduces bias
323
+ eps (:obj:`float`, optional): small constant added to priorities to ensure
324
+ no experience has zero priority. This prevents experiences from never
325
+ being sampled. Defaults to 1e-8.
305
326
  reduction (str, optional): the reduction method for multidimensional
306
327
  tensordicts (ie stored trajectory). Can be one of "max", "min",
307
328
  "median" or "mean".
@@ -309,6 +330,23 @@ class PrioritizedSampler(Sampler):
309
330
  is tracked within the buffer. When ``False``, the max-priority tracks
310
331
  the maximum value since the instantiation of the sampler.
311
332
 
333
+ **Parameter Guidelines**:
334
+ - **:math:`\alpha` (alpha)**: Controls how much to prioritize high-error experiences
335
+ - 0.4-0.7: Good balance between learning speed and stability
336
+ - 1.0: Maximum prioritization (may be unstable)
337
+ - 0.0: Uniform sampling (no prioritization benefit)
338
+
339
+ - **:math:`\beta` (beta)**: Controls importance sampling correction
340
+ - Start at 0.4-0.6 for training stability
341
+ - Anneal to 1.0 over training to reduce bias
342
+ - Lower values = more stable but biased
343
+ - Higher values = less biased but potentially unstable
344
+
345
+ - **:math:`\epsilon`**: Small constant to prevent zero priorities
346
+ - 1e-8: Good default value
347
+ - Too small: may cause numerical issues
348
+ - Too large: reduces prioritization effect
349
+
312
350
  Examples:
313
351
  >>> from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler
314
352
  >>> from tensordict import TensorDict
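
The beta-annealing guideline above ("start at 0.4-0.6 and anneal to 1.0 during training") is usually realized with a simple schedule evaluated once per optimization step. A minimal sketch of a linear ramp; this hunk does not show a setter for beta on the sampler, so only the value computation is illustrated, and `total_steps` is a hypothetical training budget:

    def annealed_beta(step: int, total_steps: int,
                      beta_start: float = 0.4, beta_end: float = 1.0) -> float:
        """Linearly anneal beta from beta_start to beta_end over total_steps."""
        frac = min(step / max(total_steps, 1), 1.0)
        return beta_start + frac * (beta_end - beta_start)

    for step in (0, 25_000, 50_000, 75_000, 100_000):
        print(step, round(annealed_beta(step, total_steps=100_000), 3))
    # beta: 0.4, 0.55, 0.7, 0.85, 1.0

The computed value would then be supplied to whatever mechanism the sampler exposes for updating beta (for instance at construction time), depending on the API.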
@@ -412,7 +450,7 @@ class PrioritizedSampler(Sampler):
412
450
  )
413
451
  return super().__getstate__()
414
452
 
415
- def _init(self):
453
+ def _init(self) -> None:
416
454
  if self.dtype in (torch.float, torch.FloatType, torch.float32):
417
455
  self._sum_tree = SumSegmentTreeFp32(self._max_capacity)
418
456
  self._min_tree = MinSegmentTreeFp32(self._max_capacity)
@@ -425,21 +463,23 @@ class PrioritizedSampler(Sampler):
425
463
  )
426
464
  self._max_priority = None
427
465
 
428
- def _empty(self):
466
+ def _empty(self) -> None:
429
467
  self._init()
430
468
 
431
469
  @property
432
- def _max_priority(self):
470
+ def _max_priority(self) -> tuple[float | None, int | None]:
433
471
  max_priority_index = self.__dict__.get("_max_priority")
434
472
  if max_priority_index is None:
435
473
  return (None, None)
436
474
  return max_priority_index
437
475
 
438
476
  @_max_priority.setter
439
- def _max_priority(self, value):
477
+ def _max_priority(self, value: tuple[float | None, int | None]) -> None:
440
478
  self.__dict__["_max_priority"] = value
441
479
 
442
- def _maybe_erase_max_priority(self, index):
480
+ def _maybe_erase_max_priority(
481
+ self, index: torch.Tensor | int | slice | tuple
482
+ ) -> None:
443
483
  if not self._max_priority_within_buffer:
444
484
  return
445
485
  max_priority_index = self._max_priority[1]
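
The typed `_max_priority` property above advertises a `(value, index)` pair rather than a bare float: remembering which entry holds the current maximum is what lets `_maybe_erase_max_priority` drop the cached maximum when that entry is overwritten (the `max_priority_within_buffer` behaviour described earlier). A toy sketch of the same bookkeeping, with hypothetical names:

    class MaxPriorityTracker:
        """Toy (value, index) max-priority bookkeeping."""

        def __init__(self) -> None:
            self._max: tuple[float | None, int | None] = (None, None)

        def update(self, priority: float, index: int) -> None:
            value, _ = self._max
            if value is None or priority > value:
                self._max = (priority, index)

        def maybe_erase(self, overwritten_index: int) -> None:
            # Forget the cached max if the entry holding it is overwritten,
            # so it can be re-established from subsequent priority updates.
            if self._max[1] == overwritten_index:
                self._max = (None, None)

    tracker = MaxPriorityTracker()
    tracker.update(3.0, index=7)
    tracker.maybe_erase(7)
    assert tracker._max == (None, None)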
@@ -1839,11 +1879,21 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement):
1839
1879
 
1840
1880
 
1841
1881
  class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
1842
- """Samples slices of data along the first dimension, given start and stop signals, using prioritized sampling.
1882
+ r"""Samples slices of data along the first dimension, given start and stop signals, using prioritized sampling.
1883
+
1884
+ This class combines trajectory sampling with Prioritized Experience Replay (PER) as presented in
1885
+ "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. Prioritized experience replay."
1886
+ (https://arxiv.org/abs/1511.05952)
1887
+
1888
+ **Core Idea**: Instead of sampling trajectory slices uniformly, this sampler prioritizes
1889
+ trajectory start points based on the importance of the transitions at those positions.
1890
+ This allows focusing learning on the most informative parts of trajectories.
1843
1891
 
1844
- This class samples sub-trajectories with replacement following a priority weighting presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015.
1845
- Prioritized experience replay."
1846
- (https://arxiv.org/abs/1511.05952)
1892
+ **How it works**:
1893
+ 1. Each transition is assigned a priority based on its TD error: :math:`p_i = |\delta_i| + \epsilon`
1894
+ 2. Trajectory start points are sampled with probability: :math:`P(i) = \frac{p_i^\alpha}{\sum_j p_j^\alpha}`
1895
+ 3. Importance sampling weights correct for bias: :math:`w_i = (N \cdot P(i))^{-\beta}`
1896
+ 4. Complete trajectory slices are extracted from the sampled start points
1847
1897
 
1848
1898
  For more info see :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` and :class:`~torchrl.data.replay_buffers.samplers.PrioritizedSampler`.
1849
1899
 
@@ -1855,15 +1905,42 @@ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
1855
1905
  :meth:`update_priority`.
1856
1906
 
1857
1907
  Args:
1858
- alpha (:obj:`float`): exponent α determines how much prioritization is used,
1859
- with α = 0 corresponding to the uniform case.
1860
- beta (:obj:`float`): importance sampling negative exponent.
1861
- eps (:obj:`float`, optional): delta added to the priorities to ensure that the buffer
1862
- does not contain null priorities. Defaults to 1e-8.
1908
+ max_capacity (int): maximum capacity of the buffer.
1909
+ alpha (:obj:`float`): exponent :math:`\alpha` determines how much prioritization is used.
1910
+ - :math:`\alpha = 0`: uniform sampling of trajectory start points
1911
+ - :math:`\alpha = 1`: full prioritization based on TD error magnitude at start points
1912
+ - Typical values: 0.4-0.7 for balanced prioritization
1913
+ - Higher :math:`\alpha` means more aggressive prioritization of high-error trajectory regions
1914
+ beta (:obj:`float`): importance sampling negative exponent :math:`\beta`.
1915
+ - :math:`\beta` controls the correction for the bias introduced by prioritization
1916
+ - :math:`\beta = 0`: no correction (biased towards high-priority trajectory regions)
1917
+ - :math:`\beta = 1`: full correction (unbiased but potentially unstable)
1918
+ - Typical values: start at 0.4-0.6 and anneal to 1.0 during training
1919
+ - Lower :math:`\beta` early in training provides stability, higher :math:`\beta` later reduces bias
1920
+ eps (:obj:`float`, optional): small constant added to priorities to ensure
1921
+ no transition has zero priority. This prevents trajectory regions from never
1922
+ being sampled. Defaults to 1e-8.
1863
1923
  reduction (str, optional): the reduction method for multidimensional
1864
1924
  tensordicts (i.e., stored trajectory). Can be one of "max", "min",
1865
1925
  "median" or "mean".
1866
1926
 
1927
+ **Parameter Guidelines**:
1928
+ - **:math:`\alpha` (alpha)**: Controls how much to prioritize high-error trajectory regions
1929
+ - 0.4-0.7: Good balance between learning speed and stability
1930
+ - 1.0: Maximum prioritization (may be unstable)
1931
+ - 0.0: Uniform sampling (no prioritization benefit)
1932
+
1933
+ - **:math:`\beta` (beta)**: Controls importance sampling correction
1934
+ - Start at 0.4-0.6 for training stability
1935
+ - Anneal to 1.0 over training to reduce bias
1936
+ - Lower values = more stable but biased
1937
+ - Higher values = less biased but potentially unstable
1938
+
1939
+ - **:math:`\epsilon`**: Small constant to prevent zero priorities
1940
+ - 1e-8: Good default value
1941
+ - Too small: may cause numerical issues
1942
+ - Too large: reduces prioritization effect
1943
+
1867
1944
  Keyword Args:
1868
1945
  num_slices (int): the number of slices to be sampled. The batch-size
1869
1946
  must be greater or equal to the ``num_slices`` argument. Exclusive
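
Putting the arguments above together, a usage sketch follows. It relies on the documented constructor arguments (`max_capacity`, `alpha`, `beta`, `num_slices`) and the generic `ReplayBuffer` API; the keyword names, the default end-of-trajectory key `("next", "done")`, and the shape passed to `update_priority` are assumptions to be checked against the class's own Examples section:

    import torch
    from tensordict import TensorDict
    from torchrl.data.replay_buffers import (
        LazyTensorStorage,
        PrioritizedSliceSampler,
        ReplayBuffer,
    )

    capacity = 1_000
    sampler = PrioritizedSliceSampler(
        max_capacity=capacity,
        alpha=0.6,      # moderate prioritization
        beta=0.4,       # anneal towards 1.0 during training
        num_slices=4,   # sub-trajectories per batch
    )
    rb = ReplayBuffer(storage=LazyTensorStorage(capacity), sampler=sampler, batch_size=32)

    # Two trajectories of 50 steps each, delimited by the done signal.
    data = TensorDict(
        {
            "obs": torch.randn(100, 4),
            ("next", "done"): torch.zeros(100, 1, dtype=torch.bool),
        },
        batch_size=[100],
    )
    data["next", "done"][49] = True
    rb.extend(data)

    sample, info = rb.sample(return_info=True)
    # After computing TD errors for the sampled transitions, write them back as priorities:
    rb.update_priority(info["index"], torch.rand(sample.numel()) + 1e-8)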
@@ -230,15 +230,38 @@ class ListStorage(Storage):
230
230
  max_size (int, optional): the maximum number of elements stored in the storage.
231
231
  If not provided, an unlimited storage is created.
232
232
 
233
+ Keyword Args:
234
+ compilable (bool, optional): if ``True``, the storage will be made compatible with :func:`~torch.compile` at
235
+ the cost of being executable in multiprocessed settings.
236
+ device (str, optional): the device to use for the storage. Defaults to `None` (inputs are not moved to the device).
237
+
233
238
  """
234
239
 
235
240
  _default_checkpointer = ListStorageCheckpointer
236
241
 
237
- def __init__(self, max_size: int | None = None, compilable: bool = False):
242
+ def __init__(
243
+ self,
244
+ max_size: int | None = None,
245
+ *,
246
+ compilable: bool = False,
247
+ device: torch.device | str | int | None = None,
248
+ ):
238
249
  if max_size is None:
239
250
  max_size = torch.iinfo(torch.int64).max
240
251
  super().__init__(max_size, compilable=compilable)
241
252
  self._storage = []
253
+ self.device = device
254
+
255
+ def _to_device(self, data: Any) -> Any:
256
+ """Utility method to move data to the device."""
257
+ if self.device is not None:
258
+ if hasattr(data, "to"):
259
+ data = data.to(self.device)
260
+ else:
261
+ data = tree_map(
262
+ lambda x: x.to(self.device) if hasattr(x, "to") else x, data
263
+ )
264
+ return data
242
265
 
243
266
  def set(
244
267
  self,
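
The `_to_device` helper above moves data directly when it exposes `.to()` and otherwise falls back to a leaf-wise `tree_map`. A standalone sketch of the same pattern, using `torch.utils._pytree.tree_map` (the storage module's actual import is not visible in this hunk):

    import torch
    from torch.utils._pytree import tree_map

    def to_device(data, device):
        """Move `data` to `device`: directly if it has `.to`, else leaf by leaf."""
        if device is None:
            return data
        if hasattr(data, "to"):  # tensors, TensorDicts, modules, ...
            return data.to(device)
        # Arbitrary nested containers: move only the leaves that support `.to`.
        return tree_map(lambda x: x.to(device) if hasattr(x, "to") else x, data)

    batch = {"obs": torch.ones(3), "step": 7}  # the int leaf is left untouched
    print(to_device(batch, "cpu"))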
@@ -254,6 +277,7 @@ class ListStorage(Storage):
254
277
  self.set(int(cursor), data, set_cursor=set_cursor)
255
278
  return
256
279
  if isinstance(cursor, slice):
280
+ data = self._to_device(data)
257
281
  self._storage[cursor] = data
258
282
  return
259
283
  if isinstance(
@@ -290,6 +314,7 @@ class ListStorage(Storage):
290
314
  f"maximum capacity is {self.max_size} "
291
315
  f"and the index of the item to be set is {cursor}."
292
316
  )
317
+ data = self._to_device(data)
293
318
  if cursor == len(self._storage):
294
319
  self._storage.append(data)
295
320
  else:
@@ -387,6 +412,7 @@ class LazyStackStorage(ListStorage):
387
412
  compilable (bool, optional): if ``True``, the storage will be made compatible with :func:`~torch.compile` at
388
413
  the cost of being executable in multiprocessed settings.
389
414
  stack_dim (int, optional): the stack dimension in terms of TensorDict batch sizes. Defaults to `0`.
415
+ device (str, optional): the device to use for the storage. Defaults to `None` (inputs are not moved to the device).
390
416
 
391
417
  Examples:
392
418
  >>> import torch
@@ -421,8 +447,9 @@ class LazyStackStorage(ListStorage):
421
447
  *,
422
448
  compilable: bool = False,
423
449
  stack_dim: int = 0,
450
+ device: torch.device | str | int | None = None,
424
451
  ):
425
- super().__init__(max_size=max_size, compilable=compilable)
452
+ super().__init__(max_size=max_size, compilable=compilable, device=device)
426
453
  self.stack_dim = stack_dim
427
454
 
428
455
  def get(self, index: int | Sequence[int] | slice) -> Any:
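
With the forwarded `device` argument above, a `LazyStackStorage` can move incoming items on write. A hedged usage sketch (buffer API as documented elsewhere in TorchRL; the ragged `tokens` key and the `as_padded_tensor` retrieval are illustrative):

    import torch
    from tensordict import TensorDict
    from torchrl.data.replay_buffers import LazyStackStorage, ReplayBuffer

    rb = ReplayBuffer(storage=LazyStackStorage(max_size=100, device="cpu"))
    # Heterogeneous items are kept as a lazy stack rather than being padded on write.
    rb.add(TensorDict(tokens=torch.randint(0, 100, (11,)), batch_size=[]))
    rb.add(TensorDict(tokens=torch.randint(0, 100, (7,)), batch_size=[]))
    sample = rb.sample(2)
    print(sample.get("tokens", as_padded_tensor=True).shape)  # padded to the longest sampled sequence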
@@ -22,12 +22,14 @@ from .transforms import (
22
22
  KLRewardTransform,
23
23
  MCPToolTransform,
24
24
  PythonInterpreter,
25
+ RetrieveLogProb,
25
26
  TemplateTransform,
26
27
  Tokenizer,
27
28
  )
28
29
 
29
30
  __all__ = [
30
31
  "BrowserTransform",
32
+ "RetrieveLogProb",
31
33
  "ChatEnv",
32
34
  "DataLoadingPrimer",
33
35
  "DatasetChatEnv",
torchrl/envs/llm/chat.py CHANGED
@@ -206,7 +206,10 @@ class ChatEnv(EnvBase):
206
206
  if lh.role != self.policy_role:
207
207
  raise ValueError(
208
208
  "The role received in the last block parsed from the policy "
209
- f"output does not match the expected policy role: received {lh.role} but expected {self.policy_role}."
209
+ f"output does not match the expected policy role: received {lh.role} but expected {self.policy_role}.\n"
210
+ f"Parsed input: {text=}\n"
211
+ f"Parsed history: {parsed_history=}\n"
212
+ f"Final element: {local_history=}"
210
213
  )
211
214
  # Append history item
212
215
  history = history.append(local_history, inplace=False)
@@ -145,25 +145,32 @@ class GSM8KRewardParser(Transform):
145
145
  potential_answer = [potential_answer]
146
146
  if isinstance(cot, str):
147
147
  cot = [cot]
148
- reward_answer = 5.0 * (len(potential_answer) == 1)
149
148
 
149
+ # Format quality rewards (always applied)
150
+ reward_answer = 5.0 * (len(potential_answer) == 1)
150
151
  reward_think = 5.0 * (len(cot) == 1)
151
152
 
152
- # One of the answer tags has the right answer
153
+ # Answer correctness rewards
153
154
  reward_right = 20.0 * (
154
155
  any(attempt == true_answer for attempt in potential_answer)
155
156
  )
156
-
157
- # One of the answer tags contains the right answer (might be e.g. $20 instead of 20)
158
157
  reward_contained = 10.0 * (
159
158
  any((true_answer in attempt) for attempt in potential_answer)
160
159
  )
161
160
 
162
161
  success = len(potential_answer) > 0 and potential_answer[-1] == true_answer
163
- # Compose the rewards
164
- reward = 100.0 * float(success) + (
165
- reward_answer + reward_think + reward_contained + reward_right
166
- ) * (1 - float(success))
162
+
163
+ # Base success reward (lower than before to make format quality more important)
164
+ base_success_reward = 60.0 if success else 0.0
165
+
166
+ # Compose the rewards - always include format quality, even when successful
167
+ reward = (
168
+ base_success_reward
169
+ + reward_answer
170
+ + reward_think
171
+ + reward_contained
172
+ + reward_right
173
+ )
167
174
 
168
175
  rewards = TensorDict(
169
176
  reward_answer=reward_answer,
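
Since the new composition is purely additive, the reward for a given completion can be reproduced by hand. A small sketch scoring one hypothetical completion that is both well formatted and exactly correct (the constants mirror the hunk above):

    true_answer = "42"
    potential_answer = ["42"]        # exactly one answer tag, matching exactly
    cot = ["... reasoning ..."]      # exactly one think block

    reward_answer = 5.0 * (len(potential_answer) == 1)                         # 5.0
    reward_think = 5.0 * (len(cot) == 1)                                       # 5.0
    reward_right = 20.0 * any(a == true_answer for a in potential_answer)      # 20.0
    reward_contained = 10.0 * any(true_answer in a for a in potential_answer)  # 10.0
    success = len(potential_answer) > 0 and potential_answer[-1] == true_answer
    base_success_reward = 60.0 if success else 0.0                             # 60.0

    reward = (base_success_reward + reward_answer + reward_think
              + reward_contained + reward_right)
    print(reward)  # 100.0; a correct but badly formatted completion now scores lower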
@@ -6,7 +6,7 @@
6
6
  from .browser import BrowserTransform
7
7
  from .dataloading import as_nested_tensor, as_padded_tensor, DataLoadingPrimer
8
8
  from .format import TemplateTransform
9
- from .kl import KLRewardTransform
9
+ from .kl import KLRewardTransform, RetrieveLogProb
10
10
  from .policy_version import PolicyVersion
11
11
  from .tokenizer import Tokenizer
12
12
  from .tools import MCPToolTransform, PythonInterpreter
@@ -15,6 +15,7 @@ __all__ = [
15
15
  "BrowserTransform",
16
16
  "DataLoadingPrimer",
17
17
  "KLRewardTransform",
18
+ "RetrieveLogProb",
18
19
  "MCPToolTransform",
19
20
  "PolicyVersion",
20
21
  "PythonInterpreter",
@@ -4,15 +4,25 @@
4
4
  # LICENSE file in the root directory of this source tree.
5
5
  from __future__ import annotations
6
6
 
7
+ import contextlib
8
+ import gc
9
+
7
10
  from copy import copy
8
11
 
9
12
  import torch
10
- from tensordict import NestedKey, TensorDictBase, unravel_key
13
+ from tensordict import NestedKey, set_list_to_stack, TensorDictBase, unravel_key
11
14
  from tensordict.nn import ProbabilisticTensorDictModule
12
- from tensordict.utils import is_seq_of_nested_key
15
+ from tensordict.utils import _zip_strict, is_seq_of_nested_key
13
16
  from torchrl.data import Composite, Unbounded
17
+ from torchrl.data.llm.chat import History
14
18
  from torchrl.envs import EnvBase, Transform
15
19
  from torchrl.envs.transforms.utils import _set_missing_tolerance
20
+ from torchrl.modules.llm.policies.common import CategoricalSequential
21
+
22
+ try:
23
+ import transformers
24
+ except ImportError:
25
+ transformers = None
16
26
 
17
27
 
18
28
  class KLRewardTransform(Transform):
@@ -141,8 +151,8 @@ class KLRewardTransform(Transform):
141
151
  f"action_key is required. Please set a parent for the {type(self).__name__} to recover the action keys automatically, "
142
152
  f"or pass the action_key argument directly to {type(self).__name__} constructor."
143
153
  )
144
- action = tensordict.get(action_key, None)
145
- if action is None:
154
+ response_txt = tensordict.get(action_key, None)
155
+ if response_txt is None:
146
156
  if not self.missing_tolerance:
147
157
  raise RuntimeError(
148
158
  f"Action with key {action_key} not found data {tensordict}"
@@ -269,3 +279,229 @@ class KLRewardTransform(Transform):
269
279
  observation_spec[self.out_keys[1]] = reward_spec.clone()
270
280
 
271
281
  return output_spec
282
+
283
+
284
+ class RetrieveLogProb(Transform):
285
+ """A transform to retrieve the log-probs of a text given a reference model.
286
+
287
+ Args:
288
+ actor (CategoricalSequential): the reference model.
289
+
290
+ Keyword Args:
291
+ history_key (NestedKey): the key where the history is stored. Defaults to `"history"`.
292
+ log_prob_key (NestedKey): the key where the log-probs are stored. Defaults to `"ref_log_prob"`.
293
+ assistant_only (bool): whether to only retrieve the log-probs of the assistant tokens (i.e., steps of history
294
+ where the role is `"assistant"`). Defaults to `False`.
295
+
296
+ .. note:: The template must accommodate the `return_assistant_tokens_mask` keyword argument.
297
+ This may not be the case for all templates. In this case, you can pass a custom template to the `apply_chat_template` method
298
+ via the `tokenizer_kwargs` argument: `tokenizer_kwargs = {"chat_template_name": "qwen"}` or `tokenizer_kwargs = {"chat_template": my_template}`.
299
+
300
+ tokenizer_kwargs (dict): the keyword arguments passed to the tokenizer when applying the chat template to the history (only used when `assistant_only` is `True`).
301
+ To control the tokenization in the actor, pass the tokenizer kwargs to the actor constructor.
302
+ Defaults to `{"return_assistant_tokens_mask": True, "tokenize": True, "return_tensors": "pt", "padding": True, "add_generation_prompt": False}`.
303
+ tokenizer (transformers.AutoTokenizer): the tokenizer used to tokenize the input and compute the assistant mask. If not provided, the tokenizer will be inferred from the `actor`.
304
+ detach (bool): whether to exclude the log-probs from the gradient computation. Defaults to `True`.
305
+ device (torch.device): the device to use for tensor creation. Defaults to `None`.
306
+
307
+ Examples:
308
+ >>> from torchrl.data.llm.chat import History, _CHAT_TEMPLATES
309
+ >>> from torchrl.modules.llm import TransformersWrapper
310
+ >>> from torchrl.objectives.llm.sft import SFTLoss
311
+ >>> from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM
312
+ >>> from tensordict import TensorDict, lazy_stack, set_list_to_stack
313
+ >>> import torch
314
+ >>>
315
+ >>> set_list_to_stack(True).set()
316
+ >>>
317
+ >>> # Create chat data
318
+ >>> chats = [
319
+ ... [
320
+ ... {"role": "system", "content": "You are a helpful assistant."},
321
+ ... {"role": "user", "content": "Hello, how are you?"},
322
+ ... {"role": "assistant", "content": "I'm doing well, thank you!"},
323
+ ... ],
324
+ ... [
325
+ ... {"role": "system", "content": "You are a helpful assistant."},
326
+ ... {"role": "user", "content": "What's the weather like?"},
327
+ ... {"role": "assistant", "content": "I can't check the weather for you."},
328
+ ... ],
329
+ ... ]
330
+ >>> history = History.from_chats(chats)
331
+ >>> print(f"Created history with shape: {history.shape}")
332
+ Created history with shape: torch.Size([2, 3])
333
+ >>>
334
+ >>> # Setup tokenizer and model
335
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
336
+ >>> tokenizer.pad_token = tokenizer.eos_token
337
+ >>> tokenizer.chat_template = _CHAT_TEMPLATES["chatml_format"]
338
+ >>> model = OPTForCausalLM(OPTConfig()).eval()
339
+ >>>
340
+ >>> # Create training and reference policies
341
+ >>> policy_train = TransformersWrapper(
342
+ ... model,
343
+ ... tokenizer=tokenizer,
344
+ ... generate=False,
345
+ ... from_text=True,
346
+ ... chat_template_name="qwen",
347
+ ... )
348
+ >>> policy_ref = TransformersWrapper(
349
+ ... model,
350
+ ... tokenizer=tokenizer,
351
+ ... generate=False,
352
+ ... from_text=True,
353
+ ... return_log_probs=True,
354
+ ... chat_template_name="qwen",
355
+ ... )
356
+ >>>
357
+ >>> # Create the RetrieveLogProb transform
358
+ >>> transform = RetrieveLogProb(
359
+ ... policy_ref,
360
+ ... assistant_only=True,
361
+ ... tokenizer_kwargs={"chat_template_name": "qwen"},
362
+ ... tokenizer=tokenizer,
363
+ ... )
364
+ >>>
365
+ >>> # Prepare data
366
+ >>> text = history[:, :-1].apply_chat_template(
367
+ ... tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=True
368
+ ... )
369
+ >>> text_response = history.apply_chat_template(
370
+ ... tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=False
371
+ ... )
372
+ >>> text_response = [
373
+ ... txt[len(txt_start):] for txt, txt_start in zip(text_response, text)
374
+ ... ]
375
+ >>> td = TensorDict(
376
+ ... text=text,
377
+ ... text_response=text_response,
378
+ ... history=history,
379
+ ... next=TensorDict(
380
+ ... reward=torch.randn(2, 1),
381
+ ... done=torch.zeros(2, dtype=torch.bool),
382
+ ... history=history,
383
+ ... ),
384
+ ... batch_size=(2,),
385
+ ... )
386
+ >>> data = lazy_stack(list(td.unbind(0)))
387
+ >>>
388
+ >>> # Apply the transform to get reference log probabilities
389
+ >>> data = transform(data)
390
+ >>> # You can get a padded tensor for batching:
391
+ >>> ref_log_probs = data.get(("next", "ref_log_prob"), as_padded_tensor=True)
392
+ >>> print(f"Type: {type(ref_log_probs)}, Length: {len(ref_log_probs)}")
393
+ Type: <class 'torch.Tensor'>, Length: 2
394
+ >>> print(f"Example shapes: {[x.shape for x in ref_log_probs]}")
395
+ Example shapes: [torch.Size([35]), torch.Size([35])]
396
+ >>> print(ref_log_probs.shape) # (batch, max_seq_len)
397
+ torch.Size([2, 35])
398
+ >>>
399
+ >>> # Use with SFTLoss for KL regularization
400
+ >>> loss = SFTLoss(
401
+ ... actor_network=policy_train,
402
+ ... tokenizer=tokenizer,
403
+ ... reduction="mean",
404
+ ... normalize_by_seq_length=True,
405
+ ... kl_to_ref_coeff=0.1,
406
+ ... tokenizer_kwargs={"chat_template_name": "qwen"},
407
+ ... )
408
+ >>> loss_vals = loss(data)
409
+ >>> print(f"SFT Loss: {loss_vals.loss_sft.item():.4f}")
410
+ SFT Loss: 10.7856
411
+ >>> print(f"KL to Reference Loss: {loss_vals.loss_kl_to_ref.item():.4f}")
412
+ KL to Reference Loss: 0.0000
413
+ >>> print(f"Total Loss: {loss_vals.sum(reduce=True).item():.4f}")
414
+ Total Loss: 10.7856
415
+
416
+ Note:
417
+ By default, the log-probabilities are stored as a list of tensors (one per sample, with variable length).
418
+ Use `as_padded_tensor=True` in `.get()` to obtain a batchable tensor (with padding).
419
+ The reference log probabilities are computed only for assistant tokens when `assistant_only=True`.
420
+
421
+ """
422
+
423
+ def __init__(
424
+ self,
425
+ actor: CategoricalSequential,
426
+ *,
427
+ history_key: NestedKey | None = None,
428
+ log_prob_key: NestedKey = "ref_log_prob",
429
+ assistant_only: bool = False,
430
+ tokenizer_kwargs: dict | None = None,
431
+ detach: bool = True,
432
+ device: torch.device | None = None,
433
+ tokenizer: transformers.AutoTokenizer | None = None,
434
+ ):
435
+ if history_key is None:
436
+ history_key = "history"
437
+ self.history_key = history_key
438
+ self.log_prob_key = log_prob_key
439
+ super().__init__(in_keys=[history_key], out_keys=[log_prob_key])
440
+ self.actor = actor
441
+ if not getattr(actor, "return_log_probs", True):
442
+ raise ValueError(
443
+ "The actor must have `return_log_probs=True` to use the `AssistantLogProb` transform."
444
+ )
445
+ if getattr(actor, "generate", True):
446
+ raise ValueError(
447
+ "The actor must have `generate=False` to use the `AssistantLogProb` transform."
448
+ )
449
+ if not getattr(actor, "from_text", False):
450
+ raise ValueError(
451
+ "The actor must have `from_text=True` to use the `AssistantLogProb` transform. If `from_text=False` is required, please file an issue on GitHub."
452
+ )
453
+ # if getattr(self.actor, "tokenizer_kwargs", {}).get("add_generation_prompt", True):
454
+ # raise ValueError("The actor must have `tokenizer_kwargs['add_generation_prompt']=False` to use the `AssistantLogProb` transform.")
455
+ self.assistant_only = assistant_only
456
+ if tokenizer_kwargs is None:
457
+ tokenizer_kwargs = {}
458
+ tokenizer_kwargs.setdefault("return_assistant_tokens_mask", True)
459
+ tokenizer_kwargs.setdefault("tokenize", True)
460
+ tokenizer_kwargs.setdefault("return_tensors", "pt")
461
+ tokenizer_kwargs.setdefault("padding", False)
462
+ tokenizer_kwargs.setdefault("add_generation_prompt", False)
463
+ self.tokenizer_kwargs = tokenizer_kwargs
464
+ self.tokenizer = tokenizer
465
+ self.detach = detach
466
+ self.device = device
467
+
468
+ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
469
+ next_td = self._step(tensordict, tensordict.get("next"))
470
+ return tensordict.set("next", next_td)
471
+
472
+ @set_list_to_stack(True)
473
+ def _step(
474
+ self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
475
+ ) -> TensorDictBase:
476
+ td = next_tensordict.select(self.history_key)
477
+ with torch.device(
478
+ self.device
479
+ ) if self.device is not None else contextlib.nullcontext(), torch.no_grad() if self.detach else contextlib.nullcontext():
480
+ result = self.actor(td.select(self.history_key))
481
+ td.update(result.select(getattr(self.actor, "log_prob_key", "log_probs")))
482
+ td.rename_key_(
483
+ getattr(self.actor, "log_prob_key", "log_probs"), self.log_prob_key
484
+ )
485
+ if torch.cuda.is_available():
486
+ gc.collect()
487
+ torch.cuda.empty_cache()
488
+ if self.assistant_only:
489
+ with torch.device(
490
+ self.device
491
+ ) if self.device is not None else contextlib.nullcontext():
492
+ # Get assistant mask
493
+ history: History = td.get(self.history_key)
494
+ proc = history.apply_chat_template(
495
+ tokenizer=self.actor.tokenizer
496
+ if self.tokenizer is None
497
+ else self.tokenizer,
498
+ **self.tokenizer_kwargs,
499
+ )
500
+ assistant_masks = proc.get("assistant_masks", as_list=True)
501
+ log_probs = td.get(self.log_prob_key, as_list=True)
502
+ log_probs = [
503
+ lp[mask.bool()]
504
+ for lp, mask in _zip_strict(log_probs, assistant_masks)
505
+ ]
506
+ td = td.set(self.log_prob_key, log_probs)
507
+ return next_tensordict.update(td)
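
The tail of `_step` above reduces each sample's per-token log-probabilities to the assistant tokens only. A toy illustration of that per-sample masking, detached from the tokenizer and model (lengths and masks are made up):

    import torch

    # Two samples with different sequence lengths: per-token log-probs plus a
    # 0/1 mask marking the tokens produced by the assistant role.
    log_probs = [torch.randn(5), torch.randn(7)]
    assistant_masks = [
        torch.tensor([0, 0, 1, 1, 1]),
        torch.tensor([0, 1, 1, 1, 0, 0, 1]),
    ]

    ref_log_prob = [
        lp[mask.bool()] for lp, mask in zip(log_probs, assistant_masks, strict=True)
    ]
    print([t.shape for t in ref_log_prob])  # [torch.Size([3]), torch.Size([4])]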