torchrl-nightly 2025.7.16__cp310-cp310-win_amd64.whl → 2025.7.17__cp310-cp310-win_amd64.whl

This diff compares publicly available package versions as released to their public registries. It is provided for informational purposes only and reflects the packages exactly as they appear in those registries.
torchrl/_torchrl.cp310-win_amd64.pyd CHANGED (binary file; contents not shown)
torchrl/collectors/collectors.py CHANGED
@@ -686,6 +686,10 @@ class SyncDataCollector(DataCollectorBase):
  policy = RandomPolicy(env.full_action_spec)
  elif policy_factory is not None:
  raise TypeError("policy_factory cannot be used with policy argument.")
+ # If the underlying policy has a state_dict, we keep a reference to the policy and
+ # do all policy weight saving/loading through it
+ if hasattr(policy, "state_dict"):
+ self._policy_w_state_dict = policy

  if trust_policy is None:
  trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule))
@@ -1686,8 +1690,8 @@ class SyncDataCollector(DataCollectorBase):
  else:
  env_state_dict = OrderedDict()

- if hasattr(self.policy, "state_dict"):
- policy_state_dict = self.policy.state_dict()
+ if hasattr(self, "_policy_w_state_dict"):
+ policy_state_dict = self._policy_w_state_dict.state_dict()
  state_dict = OrderedDict(
  policy_state_dict=policy_state_dict,
  env_state_dict=env_state_dict,
@@ -1711,7 +1715,13 @@ class SyncDataCollector(DataCollectorBase):
  if strict or "env_state_dict" in state_dict:
  self.env.load_state_dict(state_dict["env_state_dict"], **kwargs)
  if strict or "policy_state_dict" in state_dict:
- self.policy.load_state_dict(state_dict["policy_state_dict"], **kwargs)
+ if not hasattr(self, "_policy_w_state_dict"):
+ raise ValueError(
+ "Underlying policy does not have state_dict to load policy_state_dict into."
+ )
+ self._policy_w_state_dict.load_state_dict(
+ state_dict["policy_state_dict"], **kwargs
+ )
  self._frames = state_dict["frames"]
  self._iter = state_dict["iter"]

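The hunks above change SyncDataCollector so that policy weights are saved and loaded through a `_policy_w_state_dict` reference, captured only when the policy actually exposes `state_dict()`. A minimal, self-contained sketch of that pattern (a toy class for illustration only, not the torchrl collector API):

    from collections import OrderedDict
    import torch.nn as nn

    class TinyCollector:
        # Mirrors the diff: keep a policy reference only if it can provide a state_dict,
        # and route all weight saving/loading through that reference.
        def __init__(self, policy):
            if hasattr(policy, "state_dict"):
                self._policy_w_state_dict = policy

        def state_dict(self):
            sd = OrderedDict()
            if hasattr(self, "_policy_w_state_dict"):
                sd["policy_state_dict"] = self._policy_w_state_dict.state_dict()
            return sd

        def load_state_dict(self, sd):
            if "policy_state_dict" in sd:
                if not hasattr(self, "_policy_w_state_dict"):
                    raise ValueError("Underlying policy has no state_dict to load into.")
                self._policy_w_state_dict.load_state_dict(sd["policy_state_dict"])

    collector = TinyCollector(nn.Linear(4, 2))
    collector.load_state_dict(collector.state_dict())  # round-trips the policy weights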
torchrl/modules/distributions/discrete.py CHANGED
@@ -352,7 +352,7 @@ class MaskedCategorical(D.Categorical):
  logits = self.logits
  if logits.ndim > 2:
  # Bring channels in 2nd dim
- logits = logits.transpose(-1, 1)
+ logits = logits.permute(0, -1, *range(1, logits.ndim - 1))
  original_value_shape = None
  if logits.ndim == 1 and value.ndim >= 1:
  if value.ndim >= 2:
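The permute above moves the last (category) dimension into position 1, the channels-second layout that losses such as `torch.nn.functional.cross_entropy` expect for >2D logits, while keeping the remaining dimensions in their original order; `transpose(-1, 1)` instead swaps the last dimension with dimension 1, reordering the trailing dims. A standalone illustration in plain PyTorch (shapes are arbitrary examples):

    import torch

    logits = torch.randn(2, 3, 4, 5)  # (batch, d1, d2, num_categories)
    via_transpose = logits.transpose(-1, 1)                          # (2, 5, 4, 3): d1/d2 swapped
    via_permute = logits.permute(0, -1, *range(1, logits.ndim - 1))  # (2, 5, 3, 4): d1/d2 preserved
    print(via_transpose.shape, via_permute.shape)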
torchrl/modules/llm/policies/common.py CHANGED
@@ -9,8 +9,8 @@ import weakref
  from typing import Any, Literal, overload

  import torch
- from tensordict import NestedKey, TensorDictBase
- from tensordict.nn import TensorDictModuleBase, TensorDictSequential
+ from tensordict import lazy_stack, NestedKey, TensorDictBase
+ from tensordict.nn import TensorDictModuleBase
  from tensordict.tensorclass import TensorClass
  from tensordict.utils import _zip_strict
  from torch import distributions as D
@@ -175,29 +175,35 @@ class ChatHistory(TensorClass["nocast"]):
  def __post_init__(self):
  # Check that all history objects have one more batch dimension than the ChatHistory object
  if self.prompt is not None:
- if self.prompt.batch_dims != self.batch_dims + 1:
+ if getattr(self.prompt, "batch_dims", None) == self.batch_dims:
  warnings.warn(
  "Prompt history should have one more batch dimension than the ChatHistory object to handle multi-turn conversations, "
  f"got {self.prompt.batch_dims} and {self.batch_dims}. "
  "The batch dimension of the ChatHistory object will be unsqueezed along the last dimension."
  )
- self.prompt = self.prompt.unsqueeze(-1)
+ self.prompt = lazy_stack(
+ [self.prompt], -1
+ ) # equivalent to unsqueeze(-1) but make sure it's a lazy stack
  if self.response is not None:
- if self.response.batch_dims != self.batch_dims + 1:
+ if getattr(self.response, "batch_dims", None) == self.batch_dims:
  warnings.warn(
  "Response history should have one more batch dimension than the ChatHistory object to handle multi-turn conversations, "
  f"got {self.response.batch_dims} and {self.batch_dims}. "
  "The batch dimension of the ChatHistory object will be unsqueezed along the last dimension."
  )
- self.response = self.response.unsqueeze(-1)
+ self.response = lazy_stack(
+ [self.response], -1
+ ) # equivalent to unsqueeze(-1) but make sure it's a lazy stack
  if self.full is not None:
- if self.full.batch_dims != self.batch_dims + 1:
+ if getattr(self.full, "batch_dims", None) == self.batch_dims:
  warnings.warn(
  "Full history should have one more batch dimension than the ChatHistory object to handle multi-turn conversations, "
  f"got {self.full.batch_dims} and {self.batch_dims}. "
  "The batch dimension of the ChatHistory object will be unsqueezed along the last dimension."
  )
- self.full = self.full.unsqueeze(-1)
+ self.full = lazy_stack(
+ [self.full], -1
+ ) # equivalent to unsqueeze(-1) but make sure it's a lazy stack


  class LogProbs(TensorClass["nocast"]):
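In the ChatHistory hunks above, `unsqueeze(-1)` is replaced by `lazy_stack([obj], -1)`, which adds the same trailing batch dimension but guarantees the result is a LazyStackedTensorDict. A small sketch using the tensordict API imported in this diff (the key and shapes are made up for illustration):

    import torch
    from tensordict import LazyStackedTensorDict, TensorDict, lazy_stack

    td = TensorDict({"tokens": torch.zeros(3, 5)}, batch_size=[3])
    stacked = lazy_stack([td], -1)   # trailing batch dim added: batch_size becomes [3, 1]
    assert isinstance(stacked, LazyStackedTensorDict)
    print(stacked.batch_size)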
@@ -482,7 +488,7 @@ class LLMWrapperBase(TensorDictModuleBase):
  "You can create a new version of this wrapper using the `get_new_version` method."
  )

- td_out = self(tensordict.copy())
+ td_out = self.forward(tensordict.copy(), logits_only=True)

  # Get logits/log-probs
  if as_padded_tensor is None:
@@ -557,7 +563,7 @@ class LLMWrapperBase(TensorDictModuleBase):
  "get_dist_with_prompt_mask is not implemented for generate=True. "
  "You can create a new version of this wrapper using the `get_new_version` method."
  )
- td_out = self(tensordict.copy())
+ td_out = self.forward(tensordict.copy(), logits_only=True)

  # Try to get prompt tokens first
  if self.pad_output:
@@ -668,7 +674,7 @@ class LLMWrapperBase(TensorDictModuleBase):
  "get_dist_with_assistant_mask is not implemented for generate=True. "
  "You can create a new version of this wrapper using the `get_new_version` method."
  )
- td_out = self(tensordict.copy())
+ td_out = self.forward(tensordict.copy(), logits_only=True)
  # Update the tokens key to reflect the tokenized history when querying the log-probs
  tensordict.update(
  td_out,
@@ -737,7 +743,7 @@ class LLMWrapperBase(TensorDictModuleBase):
  "get_dist_with_attention_mask is not implemented for generate=True. "
  "You can create a new version of this wrapper using the `get_new_version` method."
  )
- td_out = self(tensordict.copy())
+ td_out = self.forward(tensordict.copy(), logits_only=True)
  if self.pad_output:
  logits = td_out.get(logits_key)
  attention_mask = td_out.get(attention_mask_key)
@@ -794,7 +800,7 @@ class LLMWrapperBase(TensorDictModuleBase):
  "get_dist_with_custom_mask is not implemented for generate=True. "
  "You can create a new version of this wrapper using the `get_new_version` method."
  )
- td_out = self(tensordict.copy())
+ td_out = self.forward(tensordict.copy(), logits_only=True)
  if self.pad_output:
  logits = td_out.get(logits_key)
  else:
@@ -841,8 +847,24 @@ class LLMWrapperBase(TensorDictModuleBase):
  """
  return self._get_dist_with_attention_mask(tensordict, **kwargs)

- # Sampling is taken care of by the sub-modules
- forward = TensorDictSequential.forward
+ def forward(
+ self,
+ tensordict: TensorDictBase,
+ *,
+ tensordict_out: TensorDictBase | None = None,
+ logits_only: bool = False,
+ **kwargs,
+ ) -> TensorDictBase: # noqa: D417
+ """Forward pass for the LLM policy.
+
+ Args:
+ tensordict (TensorDictBase): The input tensordict.
+
+ Keyword Args:
+ tensordict_out (TensorDictBase | None): The output tensordict.
+ logits_only (bool): Whether to return only the logits. Only effective if generate=False. Defaults to `False`.
+ """
+ raise NotImplementedError

  def _check_padded(self, val: torch.Tensor) -> torch.Tensor:
  """Check that a value is a padded tensor."""
torchrl/modules/llm/policies/transformers_wrapper.py CHANGED
@@ -13,6 +13,7 @@ from typing import Literal
  import torch
  from tensordict import (
  lazy_stack,
+ LazyStackedTensorDict,
  MetaData,
  NonTensorStack,
  set_list_to_stack,
@@ -468,19 +469,32 @@ class TransformersWrapper(LLMWrapperBase):
  def forward(
  self,
  tensordict: TensorDictBase,
+ *,
  tensordict_out: TensorDictBase | None = None,
+ logits_only: bool = False,
  **kwargs,
  ) -> TensorDictBase:
+ tensordict_orig = tensordict
  if not tensordict.ndim:
+ if tensordict_out is not None:
+ raise ValueError(
+ "tensordict_out must not be provided when tensordict.ndim == 0. If this is needed, "
+ "please submit an issue on github."
+ )
  # unsqueeze - squeeze the input
- try:
- return self(lazy_stack([tensordict])).squeeze(0)
- except Exception as e:
- raise RuntimeError(
- f"Unsqueeze/squeeze failed. Inputs to {type(self).__name__} should ideally be 1 dimensional."
- ) from e
+ return self.forward(lazy_stack([tensordict]), logits_only=logits_only)[0]
  elif tensordict.ndim > 1:
- return self(tensordict.reshape(-1)).view(tensordict.shape)
+ if tensordict_out is not None:
+ raise ValueError(
+ "tensordict_out must not be provided when tensordict.ndim > 1. If this is needed, "
+ "please submit an issue on github."
+ )
+ return self.forward(tensordict.reshape(-1), logits_only=logits_only).view(
+ tensordict.shape
+ )
+
+ if not isinstance(tensordict, LazyStackedTensorDict):
+ tensordict = tensordict.to_lazystack(0)

  _source_device = None
  if self._device:
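The rewritten forward above normalizes the batch dimensionality before doing any work: a 0-dimensional input is lazily stacked to 1-D and indexed back, an input with more than one batch dimension is flattened and reshaped afterwards, and the remaining path coerces the input to a lazy stack. A standalone sketch of that control-flow pattern on a plain tensor (illustrative only, not the wrapper itself):

    import torch

    def normalized_forward(x: torch.Tensor) -> torch.Tensor:
        if x.ndim == 0:
            return normalized_forward(x.unsqueeze(0))[0]            # promote to 1-D, then index back
        if x.ndim > 1:
            return normalized_forward(x.reshape(-1)).view(x.shape)  # flatten, process, restore shape
        return x * 2  # stand-in for the real per-batch computation

    print(normalized_forward(torch.arange(6).reshape(2, 3)))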
@@ -517,17 +531,23 @@ class TransformersWrapper(LLMWrapperBase):
  if self.generate:
  out = self._from_transformers_generate_history(tensordict, cfg, out)
  else:
- out = self._from_transformers_logprobs_history(tensordict, cfg, out)
+ out = self._from_transformers_logprobs_history(
+ tensordict, cfg, out, logits_only=logits_only
+ )
  elif self.input_mode == "text":
  if self.generate:
  out = self._from_transformers_generate_text(tensordict, cfg, out)
  else:
- out = self._from_transformers_logprobs_text(tensordict, cfg, out)
+ out = self._from_transformers_logprobs_text(
+ tensordict, cfg, out, logits_only=logits_only
+ )
  elif self.input_mode == "tokens":
  if self.generate:
  out = self._from_transformers_generate_tokens(tensordict, cfg, out)
  else:
- out = self._from_transformers_logprobs_tokens(tensordict, cfg, out)
+ out = self._from_transformers_logprobs_tokens(
+ tensordict, cfg, out, logits_only=logits_only
+ )

  if _source_device:
  out = out.to(_source_device)
@@ -535,7 +555,7 @@ class TransformersWrapper(LLMWrapperBase):
  if tensordict_out is None:
  if self.inplace is True:
  # The output is the input
- tensordict_out = tensordict
+ tensordict_out = tensordict_orig
  elif self.inplace is False:
  # The output is the new structure
  tensordict_out = out
@@ -690,7 +710,7 @@ class TransformersWrapper(LLMWrapperBase):
  result.set(self.history_key, history_chat)
  return result

- def _from_transformers_logprobs_history(self, td, cfg, out):
+ def _from_transformers_logprobs_history(self, td, cfg, out, logits_only=False):
  """Compute log-probs from history input."""
  from torchrl.data.llm import History

@@ -731,7 +751,9 @@ class TransformersWrapper(LLMWrapperBase):
  raise ValueError(
  f"Expected TensorDictBase for history input, got {type(response_tokens)}"
  )
- result = self._logprobs_from_history_tokens(response_tokens, cfg, out)
+ result = self._logprobs_from_history_tokens(
+ response_tokens, cfg, out, logits_only=logits_only
+ )
  text_result = Text._from_tensordict(result.empty())
  result.set(self.text_key, text_result)
  result[self.text_key, "full"] = text_full
@@ -952,7 +974,9 @@ class TransformersWrapper(LLMWrapperBase):
  result = result.to(cast)
  return result

- def _logprobs_from_history_tokens(self, response_tokens, cfg, out):
+ def _logprobs_from_history_tokens(
+ self, response_tokens, cfg, out, logits_only=False
+ ):
  """Compute log-probs from history tokens."""
  pad_val = self.tokenizer.pad_token_id

@@ -996,6 +1020,7 @@ class TransformersWrapper(LLMWrapperBase):
  tokens_full_padded,
  attention_mask_full_padded,
  pad_val,
+ logits_only=logits_only,
  )

  # Build output TensorClass objects
@@ -1051,19 +1076,20 @@ class TransformersWrapper(LLMWrapperBase):
  tokens_obj.padded = MetaData(self.pad_output)
  out.set(self.tokens_key, tokens_obj)

- log_probs_obj = LogProbs._from_tensordict(
- TensorDict(batch_size=out.batch_size).to_lazystack(0)
- )
- if self.pad_output:
- log_probs_obj.full = log_probs_full_padded
- else:
- log_probs_full_unpadded = _unpad_tensors(
- log_probs_full_padded, attention_mask_full_padded, as_nested=False
+ if not logits_only:
+ log_probs_obj = LogProbs._from_tensordict(
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
  )
- log_probs_obj.full = log_probs_full_unpadded
- log_probs_obj.response = None
- log_probs_obj.padded = MetaData(self.pad_output)
- out.set(self.log_probs_key, log_probs_obj)
+ if self.pad_output:
+ log_probs_obj.full = log_probs_full_padded
+ else:
+ log_probs_full_unpadded = _unpad_tensors(
+ log_probs_full_padded, attention_mask_full_padded, as_nested=False
+ )
+ log_probs_obj.full = log_probs_full_unpadded
+ log_probs_obj.response = None
+ log_probs_obj.padded = MetaData(self.pad_output)
+ out.set(self.log_probs_key, log_probs_obj)

  # Add logits to output if we're in a get_dist call
  if self._in_get_dist_call:
@@ -1095,7 +1121,7 @@ class TransformersWrapper(LLMWrapperBase):
  raise ValueError(f"Expected list of text for text input, got {type(text)}")
  return self._generate_from_text(text, cfg, out)

- def _from_transformers_logprobs_text(self, td, cfg, out):
+ def _from_transformers_logprobs_text(self, td, cfg, out, logits_only=False):
  """Compute log-probs from text input."""
  # Validate input
  if self.input_key not in td:
@@ -1168,6 +1194,7 @@ class TransformersWrapper(LLMWrapperBase):
  input_ids_full_padded,
  attention_mask_full_padded,
  self.tokenizer.pad_token_id,
+ logits_only=logits_only,
  )

  # Build output TensorClass objects
@@ -1212,19 +1239,20 @@ class TransformersWrapper(LLMWrapperBase):
  masks_obj.padded = MetaData(self.pad_output)
  out.set(self.masks_key, masks_obj)

- log_probs_obj = LogProbs._from_tensordict(
- TensorDict(batch_size=out.batch_size).to_lazystack(0)
- )
- if self.pad_output:
- log_probs_obj.full = log_probs_full_padded
- else:
- log_probs_full_unpadded = _unpad_tensors(
- log_probs_full_padded, attention_mask_full_padded, as_nested=False
+ if not logits_only:
+ log_probs_obj = LogProbs._from_tensordict(
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
  )
- log_probs_obj.full = log_probs_full_unpadded
- log_probs_obj.response = None
- log_probs_obj.padded = MetaData(self.pad_output)
- out.set(self.log_probs_key, log_probs_obj)
+ if self.pad_output:
+ log_probs_obj.full = log_probs_full_padded
+ else:
+ log_probs_full_unpadded = _unpad_tensors(
+ log_probs_full_padded, attention_mask_full_padded, as_nested=False
+ )
+ log_probs_obj.full = log_probs_full_unpadded
+ log_probs_obj.response = None
+ log_probs_obj.padded = MetaData(self.pad_output)
+ out.set(self.log_probs_key, log_probs_obj)

  # Add logits to output if we're in a get_dist call
  if self._in_get_dist_call:
@@ -1416,7 +1444,11 @@ class TransformersWrapper(LLMWrapperBase):
  return out

  def _from_transformers_logprobs_tokens(
- self, td: TensorDictBase, cfg: dict | None, out: TensorDictBase
+ self,
+ td: TensorDictBase,
+ cfg: dict | None,
+ out: TensorDictBase,
+ logits_only=False,
  ) -> TensorDictBase:
  """Compute log-probs from tokens input."""
  # Validate input
@@ -1470,6 +1502,7 @@ class TransformersWrapper(LLMWrapperBase):
  input_ids_full_padded,
  attention_mask_full_padded,
  self.tokenizer.pad_token_id,
+ logits_only=logits_only,
  )

  # Build output TensorClass objects
@@ -1514,19 +1547,20 @@ class TransformersWrapper(LLMWrapperBase):
  masks_obj.padded = MetaData(self.pad_output)
  out.set(self.masks_key, masks_obj)

- log_probs_obj = LogProbs._from_tensordict(
- TensorDict(batch_size=out.batch_size).to_lazystack(0)
- )
- if self.pad_output:
- log_probs_obj.full = log_probs_full_padded
- else:
- log_probs_full_unpadded = _unpad_tensors(
- log_probs_full_padded, attention_mask_full_padded, as_nested=False
+ if not logits_only:
+ log_probs_obj = LogProbs._from_tensordict(
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
  )
- log_probs_obj.full = log_probs_full_unpadded
- log_probs_obj.response = None
- log_probs_obj.padded = MetaData(self.pad_output)
- out.set(self.log_probs_key, log_probs_obj)
+ if self.pad_output:
+ log_probs_obj.full = log_probs_full_padded
+ else:
+ log_probs_full_unpadded = _unpad_tensors(
+ log_probs_full_padded, attention_mask_full_padded, as_nested=False
+ )
+ log_probs_obj.full = log_probs_full_unpadded
+ log_probs_obj.response = None
+ log_probs_obj.padded = MetaData(self.pad_output)
+ out.set(self.log_probs_key, log_probs_obj)

  # Add logits to output if we're in a get_dist call
  if self._in_get_dist_call:
@@ -1567,7 +1601,7 @@ class TransformersWrapper(LLMWrapperBase):
  return log_probs, logits

  def _compute_log_probs_from_model_output(
- self, model_output, input_ids, attention_mask, pad_val
+ self, model_output, input_ids, attention_mask, pad_val, logits_only=False
  ):
  """Compute log-probs from model output without modifying original tensors.

@@ -1576,6 +1610,7 @@ class TransformersWrapper(LLMWrapperBase):
  input_ids: Original input token ids
  attention_mask: Original attention mask
  pad_val: Padding token value to ignore in loss computation
+ logits_only: Whether to return only the logits.

  Returns:
  tuple: (log_probs, shifted_logits) where log_probs are the computed log probabilities
@@ -1600,6 +1635,8 @@ class TransformersWrapper(LLMWrapperBase):
  raise ValueError(
  f"The logits shape {shifted_logits.shape} does not match the input ids shape {shifted_input_ids.shape}"
  )
+ if logits_only:
+ return None, shifted_logits

  # Compute log-probs
  td = TensorDict(
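With `logits_only=True`, `_compute_log_probs_from_model_output` now stops right after the shape check and returns `(None, shifted_logits)`, skipping the per-token log-probability computation that the `get_dist*` helpers do not need. A rough sketch of the kind of computation being skipped, in plain PyTorch (shapes and names are illustrative, not the torchrl implementation):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 6, 32)             # (batch, seq, vocab) from a causal LM
    input_ids = torch.randint(0, 32, (2, 6))

    shifted_logits = logits[:, :-1]            # logits at position t predict token t+1
    shifted_targets = input_ids[:, 1:]
    log_probs = torch.gather(
        F.log_softmax(shifted_logits, dim=-1), -1, shifted_targets.unsqueeze(-1)
    ).squeeze(-1)                              # (batch, seq-1) per-token log-probs
    print(log_probs.shape)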
torchrl/modules/llm/policies/vllm_wrapper.py CHANGED
@@ -11,6 +11,7 @@ from typing import Any, Literal
  import torch
  from tensordict import (
  lazy_stack,
+ LazyStackedTensorDict,
  MetaData,
  NonTensorStack,
  set_list_to_stack,
@@ -500,19 +501,32 @@ class vLLMWrapper(LLMWrapperBase):
  def forward(
  self,
  tensordict: TensorDictBase,
+ *,
  tensordict_out: TensorDictBase | None = None,
+ logits_only: bool = False,
  **kwargs,
  ) -> TensorDictBase:
+ tensordict_orig = tensordict
  if not tensordict.ndim:
+ if tensordict_out is not None:
+ raise ValueError(
+ "tensordict_out must not be provided when tensordict.ndim == 0. If this is needed, "
+ "please submit an issue on github."
+ )
  # unsqueeze - squeeze the input
- try:
- return self(lazy_stack([tensordict])).squeeze(0)
- except Exception as e:
- raise RuntimeError(
- f"Unsqueeze/squeeze failed. Inputs to {type(self).__name__} should ideally be 1 dimensional."
- ) from e
+ return self.forward(lazy_stack([tensordict]), logits_only=logits_only)[0]
  elif tensordict.ndim > 1:
- return self(tensordict.reshape(-1)).view(tensordict.shape)
+ if tensordict_out is not None:
+ raise ValueError(
+ "tensordict_out must not be provided when tensordict.ndim > 1. If this is needed, "
+ "please submit an issue on github."
+ )
+ return self.forward(tensordict.reshape(-1), logits_only=logits_only).view(
+ tensordict.shape
+ )
+
+ if not isinstance(tensordict, LazyStackedTensorDict):
+ tensordict = tensordict.to_lazystack(0)

  _source_device = None
  if self._device:
@@ -567,7 +581,7 @@ class vLLMWrapper(LLMWrapperBase):
  if tensordict_out is None:
  if self.inplace is True:
  # The output is the input
- tensordict_out = tensordict
+ tensordict_out = tensordict_orig
  elif self.inplace is False:
  # The output is the new structure
  tensordict_out = out
@@ -1242,12 +1256,14 @@ class vLLMWrapper(LLMWrapperBase):

  generate_kwargs = {"sampling_params": sampling_params}
  args = ()
+ empirical_attention_mask = None

  if tokens_prompt_unpadded is None:
  # TODO: To be on the safe side, we may do this even in the unpadded case since we're not sure
  # the user passed an unpadded tensor in the first place.
+ empirical_attention_mask = tokens_prompt_padded != self.padding_value
  tokens_prompt_list = self._to_list(
- tokens_prompt_padded, tokens_prompt_padded != self.padding_value
+ tokens_prompt_padded, empirical_attention_mask
  )
  else:
  tokens_prompt_list = self._to_list(tokens_prompt_unpadded, None)
@@ -1365,6 +1381,22 @@ class vLLMWrapper(LLMWrapperBase):
  padding_value=self.padding_value,
  padding_side="right",
  )
+ if (
+ prompt_logprobs_padded.shape[-1]
+ != tokens_prompt_padded.shape[-1]
+ ):
+ tshape = tokens_prompt_padded.shape
+ oshape = prompt_logprobs_padded.shape
+ # it could be that the input was padded already - padding again then
+ prompt_logprobs_padded = torch.cat(
+ [
+ prompt_logprobs_padded.new_zeros(
+ tshape[:-1] + (tshape[-1] - oshape[-1],)
+ ),
+ prompt_logprobs_padded,
+ ],
+ -1,
+ )
  else:
  prompt_logprobs_list = request_output_tc.get(
  "prompt_logprobs",
@@ -1490,26 +1522,21 @@ class vLLMWrapper(LLMWrapperBase):

  request_output_tc = _RequestOutput_tc.from_request_output(tokens_out_stuct)

+ # For unpadded case, extract from each sequence
+ log_probs_full_unpadded = request_output_tc.get("prompt_logprobs", as_list=True)
+
  # Extract log-probs from prompt_logprobs
  if self.pad_output:
  # For padded case, use all prompt_logprobs
- log_probs_full_padded = request_output_tc.get(
- "prompt_logprobs",
- as_padded_tensor=True,
- padding_value=0,
- padding_side="left",
+ if attention_mask_full_padded is not None:
+ attention_mask_full_padded = tokens_full_padded != self.padding_value
+ log_probs_full_padded = torch.zeros_like(
+ tokens_full_padded, dtype=torch.get_default_dtype()
  )
-
- # Mask out padding
- attention_mask_full_padded = tokens_full_padded != self.padding_value
- log_probs_full_padded = torch.where(
- attention_mask_full_padded, log_probs_full_padded, 0.0
+ log_probs_full_padded[attention_mask_full_padded] = torch.cat(
+ log_probs_full_unpadded, -1
  )
  else:
- # For unpadded case, extract from each sequence
- log_probs_full_unpadded = request_output_tc.get(
- "prompt_logprobs", as_list=True
- )
  self._check_not_padded(log_probs_full_unpadded)

  assistant_mask_full_padded = None
torchrl/version.py CHANGED
@@ -1,2 +1,2 @@
- __version__ = '2025.7.16'
- git_version = '361a8da6edc77979e17409cf19396230d18c18a9'
+ __version__ = '2025.7.17'
+ git_version = 'bec8f0382b9694a87c04385f55fc9f8f3ee1724f'
torchrl_nightly-2025.7.16.dist-info/METADATA → torchrl_nightly-2025.7.17.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: torchrl-nightly
- Version: 2025.7.16
+ Version: 2025.7.17
  Summary: UNKNOWN
  Home-page: https://github.com/pytorch/rl
  Author: torchrl contributors
torchrl_nightly-2025.7.16.dist-info/RECORD → torchrl_nightly-2025.7.17.dist-info/RECORD
@@ -3,11 +3,11 @@ build_tools/setup_helpers/__init__.py,sha256=l9zlK7Nm5bT7P_onQx-hZeIGzKKyCFm1PFk
  build_tools/setup_helpers/extension.py,sha256=ihV8jz8kqOvpqzuD006XqF1oNX5ukKGlwIOJRb1Vd-o,6075
  torchrl/__init__.py,sha256=76lKYwYKmAKORhyVt2tURmYAIRTifxxO3gWsskrHAXU,3054
  torchrl/_extension.py,sha256=x6Nqj2brF3VhlEwxmNA2fYbmpxq1HHGrHMnP0YnQwdc,2412
- torchrl/_torchrl.cp310-win_amd64.pyd,sha256=MHnLXT4hRARJhDr5PwHXmfL7xW3wUX83LHClN4P4Kmo,440832
+ torchrl/_torchrl.cp310-win_amd64.pyd,sha256=EmXEHv8OHR8lomhJKYRoDdRNcjSZrLKyQdUKYT1-BR4,440832
  torchrl/_utils.py,sha256=2N35rdD65U1khMi5gVIz8-nMjlZsoVq0kCiQftVRSxw,42297
- torchrl/version.py,sha256=8tM5vPhf-adaEbP-CudaaDTPQwY79ENoxHk7DIqTeDk,85
+ torchrl/version.py,sha256=hd0Oai-1iVD9JdvzREmm-B6pS3vtNIorROtxZuyZf7A,85
  torchrl/collectors/__init__.py,sha256=LzTyfxmkNGPSa5-3rS5unQK7HfT5ZEdr2NV291rAOlU,832
- torchrl/collectors/collectors.py,sha256=i-7ANxLstwaj4ruTkxFvp4YV42oHm_9M95_uPJZZock,181631
+ torchrl/collectors/collectors.py,sha256=UbXtDMHrXVQ-cd95TBpss2SIbNFKze2HTPYjFz2cPNQ,182146
  torchrl/collectors/utils.py,sha256=aBmBLpphhfplqQjRCyn1jtWWJ-Wtc7TWvM0rOBN8SsE,11579
  torchrl/collectors/weight_update.py,sha256=Ydq5nJSTV3Q1uqLtJ_1Nj1JB5rwHwrG5StaLxymWFV4,21572
  torchrl/collectors/distributed/__init__.py,sha256=cKDWdNlwx2LoJkTwf-DKUXbq3Y-0Z1DctPYPcdgOSU0,730
@@ -139,7 +139,7 @@ torchrl/envs/transforms/vip.py,sha256=r8Ni0hAYY1gispLj0TXV2VIedrgC4eW3hAhJBv47Q7
  torchrl/modules/__init__.py,sha256=TuJj3WUlvilYY39nUH-ykXkyprTxjq9NLW0QQXADqJk,4343
  torchrl/modules/distributions/__init__.py,sha256=Evkiz96ZPs7VUZp2n03h9kd7rmUCEEvMVl2f7RhzMhQ,1670
  torchrl/modules/distributions/continuous.py,sha256=tahVKeI_uFgnVmskJ-_NsXhWsSR8sr7FJ_281rCq4LE,26434
- torchrl/modules/distributions/discrete.py,sha256=HSuaJ0O71eBlToGDh31FRxGDO62k_Gf7iT9qgJjna84,36462
+ torchrl/modules/distributions/discrete.py,sha256=QXDv-nllK6i1tXj0KiP6TnHDjrI05YRp05D3mBin6Pc,36488
  torchrl/modules/distributions/truncated_normal.py,sha256=l5G3TePasl7q12DjwisyQC_E0OfZZo2g_HzBhZREVxc,6122
  torchrl/modules/distributions/utils.py,sha256=q4AFDKFpacRhrl4rjJ54UhxQzjOcj_SKlz0UIcZlUVc,7796
  torchrl/modules/llm/__init__.py,sha256=_gH2JzO4sXWYIyDtPaGvrPJCBCGCRA5T0SXZtETeeoQ,775
@@ -147,9 +147,9 @@ torchrl/modules/llm/utils.py,sha256=b2s9ngHwXnNbLggygU3-ScNwk0MWICketq2pZBshGqM,
  torchrl/modules/llm/backends/__init__.py,sha256=ABKK4mJeRtoLXEqfnMvIuiovs7VJoCxnDDo6QYvPMVk,457
  torchrl/modules/llm/backends/vllm.py,sha256=5P78jEtAIytgYHzEkOrg-wwqh1ryhiMVy4M_AxNQ9JQ,9649
  torchrl/modules/llm/policies/__init__.py,sha256=x5gk4ja20-yjsPHY0F_Ymw1G7u4mCDULwxnTAaRTJN8,567
- torchrl/modules/llm/policies/common.py,sha256=Aev4EKEogWFr4C7wPqFZ8lmk041ZzrtoFx3QVVverc0,39068
- torchrl/modules/llm/policies/transformers_wrapper.py,sha256=EgXlpxue2K4cAUCabrCKgTLhFRosg-OcakuITstL2Zw,76137
- torchrl/modules/llm/policies/vllm_wrapper.py,sha256=tNil8XybcGQaVBW5q81MDimXjuLYZTUbBbVG4jNYPuc,80114
+ torchrl/modules/llm/policies/common.py,sha256=qFc1Di76qFjTvf38_FfpVKZsz4d4Nva2tOFk9F9vUOM,40085
+ torchrl/modules/llm/policies/transformers_wrapper.py,sha256=G4nZbtqcEch1BD3URWfn0pwiZtNF7O1f6P5qpotTJVc,77625
+ torchrl/modules/llm/policies/vllm_wrapper.py,sha256=WRB1t-7_CcXn0JmQpRSUikLhgPe6CoasHYfIlgOXx-Q,81542
  torchrl/modules/models/__init__.py,sha256=Y1XTkBOB5EMj6IaMru6V3CDwFLnkUtxzsHcqzeqq_4Y,1829
  torchrl/modules/models/batchrenorm.py,sha256=bR4ZhaJ5E1cSK5o8L2dNX5KVLIb-bgrYxcq6yhx0I1A,4869
  torchrl/modules/models/decision_transformer.py,sha256=ANFTOm3k9_3Uv1vKGdXumRy3meBPnDdT8HqhVvJ2RCo,6783
@@ -223,8 +223,8 @@ torchrl/trainers/helpers/losses.py,sha256=HwrovwbMOhY-5-hlOz-YHclKnoJhMijVjDNuAT
  torchrl/trainers/helpers/models.py,sha256=VujBq9H92sEzpCtU1iTrJQNlwvyOO-Rho4bzsMonX6s,22465
  torchrl/trainers/helpers/replay_buffer.py,sha256=RaZqXnHimmadiibvDBcLbtIhpPaVMTPhYMOBvX4v3CA,2060
  torchrl/trainers/helpers/trainers.py,sha256=hB1FtHtP-S0PBQ4LF6WPy37caaLpacyaLThj1BNl5Ho,12372
- torchrl_nightly-2025.7.16.dist-info/LICENSE,sha256=PGO-oZsq4EzhE1-WQS2xGiEF3UCVb9YawfQ09cIMV_8,1119
- torchrl_nightly-2025.7.16.dist-info/METADATA,sha256=35ji9dwgAOpYAOrjBU-SAq54qS_DfvkVXzrs0I0xQGQ,44000
- torchrl_nightly-2025.7.16.dist-info/WHEEL,sha256=NVXpD7b4Gxps0cd2ds5rr5TG8W4ApEwx_i5J99qMZ5E,102
- torchrl_nightly-2025.7.16.dist-info/top_level.txt,sha256=JeTJ1jV7QJwLcUS1nr21aPn_wb-XlAZ9c-z_EH472JA,20
- torchrl_nightly-2025.7.16.dist-info/RECORD,,
+ torchrl_nightly-2025.7.17.dist-info/LICENSE,sha256=PGO-oZsq4EzhE1-WQS2xGiEF3UCVb9YawfQ09cIMV_8,1119
+ torchrl_nightly-2025.7.17.dist-info/METADATA,sha256=SaETkAw3q6ZUIazIqAwcLeC4WvSoghlUNoSQlhQ4nXQ,44000
+ torchrl_nightly-2025.7.17.dist-info/WHEEL,sha256=NVXpD7b4Gxps0cd2ds5rr5TG8W4ApEwx_i5J99qMZ5E,102
+ torchrl_nightly-2025.7.17.dist-info/top_level.txt,sha256=JeTJ1jV7QJwLcUS1nr21aPn_wb-XlAZ9c-z_EH472JA,20
+ torchrl_nightly-2025.7.17.dist-info/RECORD,,