torchrl-nightly 2025.7.15__cp313-cp313-macosx_10_13_universal2.whl → 2025.7.18__cp313-cp313-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Binary file
@@ -686,6 +686,10 @@ class SyncDataCollector(DataCollectorBase):
686
686
  policy = RandomPolicy(env.full_action_spec)
687
687
  elif policy_factory is not None:
688
688
  raise TypeError("policy_factory cannot be used with policy argument.")
689
+ # If the underlying policy has a state_dict, we keep a reference to the policy and
690
+ # do all policy weight saving/loading through it
691
+ if hasattr(policy, "state_dict"):
692
+ self._policy_w_state_dict = policy
689
693
 
690
694
  if trust_policy is None:
691
695
  trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule))
@@ -1686,8 +1690,8 @@ class SyncDataCollector(DataCollectorBase):
1686
1690
  else:
1687
1691
  env_state_dict = OrderedDict()
1688
1692
 
1689
- if hasattr(self.policy, "state_dict"):
1690
- policy_state_dict = self.policy.state_dict()
1693
+ if hasattr(self, "_policy_w_state_dict"):
1694
+ policy_state_dict = self._policy_w_state_dict.state_dict()
1691
1695
  state_dict = OrderedDict(
1692
1696
  policy_state_dict=policy_state_dict,
1693
1697
  env_state_dict=env_state_dict,
@@ -1711,7 +1715,13 @@ class SyncDataCollector(DataCollectorBase):
1711
1715
  if strict or "env_state_dict" in state_dict:
1712
1716
  self.env.load_state_dict(state_dict["env_state_dict"], **kwargs)
1713
1717
  if strict or "policy_state_dict" in state_dict:
1714
- self.policy.load_state_dict(state_dict["policy_state_dict"], **kwargs)
1718
+ if not hasattr(self, "_policy_w_state_dict"):
1719
+ raise ValueError(
1720
+ "Underlying policy does not have state_dict to load policy_state_dict into."
1721
+ )
1722
+ self._policy_w_state_dict.load_state_dict(
1723
+ state_dict["policy_state_dict"], **kwargs
1724
+ )
1715
1725
  self._frames = state_dict["frames"]
1716
1726
  self._iter = state_dict["iter"]
1717
1727
 
@@ -713,6 +713,42 @@ class History(TensorClass["nocast"]):
713
713
  | transformers.AutoProcessor # noqa: F821
714
714
  | None = None,
715
715
  ) -> History:
716
+ r"""Inverts a chat template into a History object.
717
+
718
+ Args:
719
+ text (str | list[str]): The chat template to invert.
720
+ chat_template_name (str, optional): The name of the chat template to use.
721
+ tokenizer (transformers.AutoTokenizer | transformers.AutoProcessor, optional): The tokenizer to use.
722
+
723
+ Returns:
724
+ History: The inverted History object.
725
+
726
+ Examples:
727
+ >>> from torchrl.data.llm.history import History
728
+ >>> from transformers import AutoTokenizer
729
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
730
+ >>> text = "<|im_start|>system\nYou are a helpful assistant.\n<|im_end|>\n<|im_start|>user\nWrite a python script that gives the capital of France or Germany.\n<|im_end|>\n<|im_start|>assistant\n<think>The capital of France is Paris, the capital of Germany is Berlin.</think>\n<answer><python>\n"
731
+ >>> history = History.from_text(text, tokenizer=tokenizer)
732
+ >>> print(history)
733
+ History(
734
+ content=NonTensorStack(
735
+ ['You are a helpful assistant.', 'Write a python s...,
736
+ batch_size=torch.Size([3]),
737
+ device=None),
738
+ is_complete=NonTensorStack(
739
+ [True, True, False],
740
+ batch_size=torch.Size([3]),
741
+ device=None),
742
+ role=NonTensorStack(
743
+ ['system', 'user', 'assistant'],
744
+ batch_size=torch.Size([3]),
745
+ device=None),
746
+ tool_calls=None,
747
+ tool_responses=None,
748
+ batch_size=torch.Size([3]),
749
+ device=None,
750
+ is_shared=False)
751
+ """
716
752
  if chat_template_name is None:
717
753
  if chat_template is not None:
718
754
  # TODO: find best match given template
@@ -4449,12 +4449,18 @@ class Binary(Categorical):
4449
4449
  f"shape of the {self.__class__.__name__} spec in expand()."
4450
4450
  )
4451
4451
  return self.__class__(
4452
- n=self.shape[-1], shape=shape, device=self.device, dtype=self.dtype
4452
+ n=self.shape[-1] if len(self.shape) > 0 else None,
4453
+ shape=shape,
4454
+ device=self.device,
4455
+ dtype=self.dtype,
4453
4456
  )
4454
4457
 
4455
4458
  def _reshape(self, shape):
4456
4459
  return self.__class__(
4457
- n=self.shape[-1], shape=shape, device=self.device, dtype=self.dtype
4460
+ n=self.shape[-1] if len(self.shape) > 0 else None,
4461
+ shape=shape,
4462
+ device=self.device,
4463
+ dtype=self.dtype,
4458
4464
  )
4459
4465
 
4460
4466
  def _unflatten(self, dim, sizes):
@@ -4464,7 +4470,10 @@ class Binary(Categorical):
4464
4470
  .shape
4465
4471
  )
4466
4472
  return self.__class__(
4467
- n=self.shape[-1], shape=shape, device=self.device, dtype=self.dtype
4473
+ n=self.shape[-1] if len(self.shape) > 0 else None,
4474
+ shape=shape,
4475
+ device=self.device,
4476
+ dtype=self.dtype,
4468
4477
  )
4469
4478
 
4470
4479
  def squeeze(self, dim=None):
@@ -4472,13 +4481,19 @@ class Binary(Categorical):
4472
4481
  if shape is None:
4473
4482
  return self
4474
4483
  return self.__class__(
4475
- n=self.shape[-1], shape=shape, device=self.device, dtype=self.dtype
4484
+ n=self.shape[-1] if len(self.shape) > 0 else None,
4485
+ shape=shape,
4486
+ device=self.device,
4487
+ dtype=self.dtype,
4476
4488
  )
4477
4489
 
4478
4490
  def unsqueeze(self, dim: int):
4479
4491
  shape = _unsqueezed_shape(self.shape, dim)
4480
4492
  return self.__class__(
4481
- n=self.shape[-1], shape=shape, device=self.device, dtype=self.dtype
4493
+ n=self.shape[-1] if len(self.shape) > 0 else None,
4494
+ shape=shape,
4495
+ device=self.device,
4496
+ dtype=self.dtype,
4482
4497
  )
4483
4498
 
4484
4499
  def unbind(self, dim: int = 0):
@@ -4495,7 +4510,10 @@ class Binary(Categorical):
4495
4510
  shape = tuple(s for i, s in enumerate(self.shape) if i != dim)
4496
4511
  return tuple(
4497
4512
  self.__class__(
4498
- n=self.shape[-1], shape=shape, device=self.device, dtype=self.dtype
4513
+ n=self.shape[-1] if len(self.shape) > 0 else None,
4514
+ shape=shape,
4515
+ device=self.device,
4516
+ dtype=self.dtype,
4499
4517
  )
4500
4518
  for i in range(self.shape[dim])
4501
4519
  )
@@ -4512,12 +4530,15 @@ class Binary(Categorical):
4512
4530
  if dest_device == self.device and dest_dtype == self.dtype:
4513
4531
  return self
4514
4532
  return self.__class__(
4515
- n=self.shape[-1], shape=self.shape, device=dest_device, dtype=dest_dtype
4533
+ n=self.shape[-1] if len(self.shape) > 0 else None,
4534
+ shape=self.shape,
4535
+ device=dest_device,
4536
+ dtype=dest_dtype,
4516
4537
  )
4517
4538
 
4518
4539
  def clone(self) -> Binary:
4519
4540
  return self.__class__(
4520
- n=self.shape[-1],
4541
+ n=self.shape[-1] if len(self.shape) > 0 else None,
4521
4542
  shape=self.shape,
4522
4543
  device=self.device,
4523
4544
  dtype=self.dtype,
@@ -4528,6 +4549,8 @@ class Binary(Categorical):
4528
4549
 
4529
4550
  The last dimension of the spec (length n of the binary vector) cannot be indexed.
4530
4551
  """
4552
+ if not len(self.shape):
4553
+ raise ValueError("Cannot index a Binary spec with an empty shape")
4531
4554
  indexed_shape = _shape_indexing(self.shape[:-1], idx)
4532
4555
  return self.__class__(
4533
4556
  n=self.shape[-1],
@@ -5533,8 +5556,10 @@ class Composite(TensorSpec):
5533
5556
  sub_str = [
5534
5557
  indent(f"{k}: {str(item)}", 4 * " ") for k, item in self._specs.items()
5535
5558
  ]
5559
+ if len(sub_str) == 0:
5560
+ return f"{self.__class__.__name__}(device={self._device}, shape={self.shape}, data_cls={self.data_cls})"
5536
5561
  sub_str = ",\n".join(sub_str)
5537
- return f"Composite(\n{sub_str},\n device={self._device},\n shape={self.shape})"
5562
+ return f"{self.__class__.__name__}(\n{sub_str},\n device={self._device},\n shape={self.shape},\n data_cls={self.data_cls})"
5538
5563
 
5539
5564
  def type_check(
5540
5565
  self,
@@ -1211,7 +1211,6 @@ but got an object of type {type(transform)}."""
1211
1211
  if tensordict is not None:
1212
1212
  # We must avoid modifying the original tensordict so a shallow copy is necessary.
1213
1213
  # We just select the input data and reset signal, which is all we need.
1214
- self.transform.transform_input_spec(self.base_env.input_spec.unlock_())
1215
1214
  tensordict = tensordict.select(
1216
1215
  *self.reset_keys, *self.state_spec.keys(True, True), strict=False
1217
1216
  )
@@ -352,7 +352,7 @@ class MaskedCategorical(D.Categorical):
352
352
  logits = self.logits
353
353
  if logits.ndim > 2:
354
354
  # Bring channels in 2nd dim
355
- logits = logits.transpose(-1, 1)
355
+ logits = logits.permute(0, -1, *range(1, logits.ndim - 1))
356
356
  original_value_shape = None
357
357
  if logits.ndim == 1 and value.ndim >= 1:
358
358
  if value.ndim >= 2:
@@ -4,12 +4,13 @@
4
4
  # LICENSE file in the root directory of this source tree.
5
5
  from __future__ import annotations
6
6
 
7
+ import warnings
7
8
  import weakref
8
9
  from typing import Any, Literal, overload
9
10
 
10
11
  import torch
11
- from tensordict import NestedKey, TensorDictBase
12
- from tensordict.nn import TensorDictModuleBase, TensorDictSequential
12
+ from tensordict import lazy_stack, NestedKey, TensorDictBase
13
+ from tensordict.nn import TensorDictModuleBase
13
14
  from tensordict.tensorclass import TensorClass
14
15
  from tensordict.utils import _zip_strict
15
16
  from torch import distributions as D
@@ -171,6 +172,39 @@ class ChatHistory(TensorClass["nocast"]):
171
172
  step_mdp_static=True,
172
173
  )
173
174
 
175
+ def __post_init__(self):
176
+ # Check that all history objects have one more batch dimension than the ChatHistory object
177
+ if self.prompt is not None:
178
+ if getattr(self.prompt, "batch_dims", None) == self.batch_dims:
179
+ warnings.warn(
180
+ "Prompt history should have one more batch dimension than the ChatHistory object to handle multi-turn conversations, "
181
+ f"got {self.prompt.batch_dims} and {self.batch_dims}. "
182
+ "The batch dimension of the ChatHistory object will be unsqueezed along the last dimension."
183
+ )
184
+ self.prompt = lazy_stack(
185
+ [self.prompt], -1
186
+ ) # equivalent to unsqueeze(-1) but make sure it's a lazy stack
187
+ if self.response is not None:
188
+ if getattr(self.response, "batch_dims", None) == self.batch_dims:
189
+ warnings.warn(
190
+ "Response history should have one more batch dimension than the ChatHistory object to handle multi-turn conversations, "
191
+ f"got {self.response.batch_dims} and {self.batch_dims}. "
192
+ "The batch dimension of the ChatHistory object will be unsqueezed along the last dimension."
193
+ )
194
+ self.response = lazy_stack(
195
+ [self.response], -1
196
+ ) # equivalent to unsqueeze(-1) but make sure it's a lazy stack
197
+ if self.full is not None:
198
+ if getattr(self.full, "batch_dims", None) == self.batch_dims:
199
+ warnings.warn(
200
+ "Full history should have one more batch dimension than the ChatHistory object to handle multi-turn conversations, "
201
+ f"got {self.full.batch_dims} and {self.batch_dims}. "
202
+ "The batch dimension of the ChatHistory object will be unsqueezed along the last dimension."
203
+ )
204
+ self.full = lazy_stack(
205
+ [self.full], -1
206
+ ) # equivalent to unsqueeze(-1) but make sure it's a lazy stack
207
+
174
208
 
175
209
  class LogProbs(TensorClass["nocast"]):
176
210
  """A log-probability container.
@@ -454,7 +488,7 @@ class LLMWrapperBase(TensorDictModuleBase):
454
488
  "You can create a new version of this wrapper using the `get_new_version` method."
455
489
  )
456
490
 
457
- td_out = self(tensordict.copy())
491
+ td_out = self.forward(tensordict.copy(), logits_only=True)
458
492
 
459
493
  # Get logits/log-probs
460
494
  if as_padded_tensor is None:
@@ -529,7 +563,7 @@ class LLMWrapperBase(TensorDictModuleBase):
529
563
  "get_dist_with_prompt_mask is not implemented for generate=True. "
530
564
  "You can create a new version of this wrapper using the `get_new_version` method."
531
565
  )
532
- td_out = self(tensordict.copy())
566
+ td_out = self.forward(tensordict.copy(), logits_only=True)
533
567
 
534
568
  # Try to get prompt tokens first
535
569
  if self.pad_output:
@@ -640,7 +674,7 @@ class LLMWrapperBase(TensorDictModuleBase):
640
674
  "get_dist_with_assistant_mask is not implemented for generate=True. "
641
675
  "You can create a new version of this wrapper using the `get_new_version` method."
642
676
  )
643
- td_out = self(tensordict.copy())
677
+ td_out = self.forward(tensordict.copy(), logits_only=True)
644
678
  # Update the tokens key to reflect the tokenized history when querying the log-probs
645
679
  tensordict.update(
646
680
  td_out,
@@ -709,7 +743,7 @@ class LLMWrapperBase(TensorDictModuleBase):
709
743
  "get_dist_with_attention_mask is not implemented for generate=True. "
710
744
  "You can create a new version of this wrapper using the `get_new_version` method."
711
745
  )
712
- td_out = self(tensordict.copy())
746
+ td_out = self.forward(tensordict.copy(), logits_only=True)
713
747
  if self.pad_output:
714
748
  logits = td_out.get(logits_key)
715
749
  attention_mask = td_out.get(attention_mask_key)
@@ -766,7 +800,7 @@ class LLMWrapperBase(TensorDictModuleBase):
766
800
  "get_dist_with_custom_mask is not implemented for generate=True. "
767
801
  "You can create a new version of this wrapper using the `get_new_version` method."
768
802
  )
769
- td_out = self(tensordict.copy())
803
+ td_out = self.forward(tensordict.copy(), logits_only=True)
770
804
  if self.pad_output:
771
805
  logits = td_out.get(logits_key)
772
806
  else:
@@ -813,8 +847,24 @@ class LLMWrapperBase(TensorDictModuleBase):
813
847
  """
814
848
  return self._get_dist_with_attention_mask(tensordict, **kwargs)
815
849
 
816
- # Sampling is taken care of by the sub-modules
817
- forward = TensorDictSequential.forward
850
+ def forward(
851
+ self,
852
+ tensordict: TensorDictBase,
853
+ *,
854
+ tensordict_out: TensorDictBase | None = None,
855
+ logits_only: bool = False,
856
+ **kwargs,
857
+ ) -> TensorDictBase: # noqa: D417
858
+ """Forward pass for the LLM policy.
859
+
860
+ Args:
861
+ tensordict (TensorDictBase): The input tensordict.
862
+
863
+ Keyword Args:
864
+ tensordict_out (TensorDictBase | None): The output tensordict.
865
+ logits_only (bool): Whether to return only the logits. Only effective if generate=False. Defaults to `False`.
866
+ """
867
+ raise NotImplementedError
818
868
 
819
869
  def _check_padded(self, val: torch.Tensor) -> torch.Tensor:
820
870
  """Check that a value is a padded tensor."""
@@ -13,6 +13,7 @@ from typing import Literal
13
13
  import torch
14
14
  from tensordict import (
15
15
  lazy_stack,
16
+ LazyStackedTensorDict,
16
17
  MetaData,
17
18
  NonTensorStack,
18
19
  set_list_to_stack,
@@ -468,19 +469,32 @@ class TransformersWrapper(LLMWrapperBase):
468
469
  def forward(
469
470
  self,
470
471
  tensordict: TensorDictBase,
472
+ *,
471
473
  tensordict_out: TensorDictBase | None = None,
474
+ logits_only: bool = False,
472
475
  **kwargs,
473
476
  ) -> TensorDictBase:
477
+ tensordict_orig = tensordict
474
478
  if not tensordict.ndim:
479
+ if tensordict_out is not None:
480
+ raise ValueError(
481
+ "tensordict_out must not be provided when tensordict.ndim == 0. If this is needed, "
482
+ "please submit an issue on github."
483
+ )
475
484
  # unsqueeze - squeeze the input
476
- try:
477
- return self(lazy_stack([tensordict])).squeeze(0)
478
- except Exception as e:
479
- raise RuntimeError(
480
- f"Unsqueeze/squeeze failed. Inputs to {type(self).__name__} should ideally be 1 dimensional."
481
- ) from e
485
+ return self.forward(lazy_stack([tensordict]), logits_only=logits_only)[0]
482
486
  elif tensordict.ndim > 1:
483
- return self(tensordict.reshape(-1)).view(tensordict.shape)
487
+ if tensordict_out is not None:
488
+ raise ValueError(
489
+ "tensordict_out must not be provided when tensordict.ndim > 1. If this is needed, "
490
+ "please submit an issue on github."
491
+ )
492
+ return self.forward(tensordict.reshape(-1), logits_only=logits_only).view(
493
+ tensordict.shape
494
+ )
495
+
496
+ if not isinstance(tensordict, LazyStackedTensorDict):
497
+ tensordict = tensordict.to_lazystack(0)
484
498
 
485
499
  _source_device = None
486
500
  if self._device:
@@ -517,17 +531,23 @@ class TransformersWrapper(LLMWrapperBase):
517
531
  if self.generate:
518
532
  out = self._from_transformers_generate_history(tensordict, cfg, out)
519
533
  else:
520
- out = self._from_transformers_logprobs_history(tensordict, cfg, out)
534
+ out = self._from_transformers_logprobs_history(
535
+ tensordict, cfg, out, logits_only=logits_only
536
+ )
521
537
  elif self.input_mode == "text":
522
538
  if self.generate:
523
539
  out = self._from_transformers_generate_text(tensordict, cfg, out)
524
540
  else:
525
- out = self._from_transformers_logprobs_text(tensordict, cfg, out)
541
+ out = self._from_transformers_logprobs_text(
542
+ tensordict, cfg, out, logits_only=logits_only
543
+ )
526
544
  elif self.input_mode == "tokens":
527
545
  if self.generate:
528
546
  out = self._from_transformers_generate_tokens(tensordict, cfg, out)
529
547
  else:
530
- out = self._from_transformers_logprobs_tokens(tensordict, cfg, out)
548
+ out = self._from_transformers_logprobs_tokens(
549
+ tensordict, cfg, out, logits_only=logits_only
550
+ )
531
551
 
532
552
  if _source_device:
533
553
  out = out.to(_source_device)
@@ -535,7 +555,7 @@ class TransformersWrapper(LLMWrapperBase):
535
555
  if tensordict_out is None:
536
556
  if self.inplace is True:
537
557
  # The output is the input
538
- tensordict_out = tensordict
558
+ tensordict_out = tensordict_orig
539
559
  elif self.inplace is False:
540
560
  # The output is the new structure
541
561
  tensordict_out = out
@@ -690,7 +710,7 @@ class TransformersWrapper(LLMWrapperBase):
690
710
  result.set(self.history_key, history_chat)
691
711
  return result
692
712
 
693
- def _from_transformers_logprobs_history(self, td, cfg, out):
713
+ def _from_transformers_logprobs_history(self, td, cfg, out, logits_only=False):
694
714
  """Compute log-probs from history input."""
695
715
  from torchrl.data.llm import History
696
716
 
@@ -731,7 +751,9 @@ class TransformersWrapper(LLMWrapperBase):
731
751
  raise ValueError(
732
752
  f"Expected TensorDictBase for history input, got {type(response_tokens)}"
733
753
  )
734
- result = self._logprobs_from_history_tokens(response_tokens, cfg, out)
754
+ result = self._logprobs_from_history_tokens(
755
+ response_tokens, cfg, out, logits_only=logits_only
756
+ )
735
757
  text_result = Text._from_tensordict(result.empty())
736
758
  result.set(self.text_key, text_result)
737
759
  result[self.text_key, "full"] = text_full
@@ -952,7 +974,9 @@ class TransformersWrapper(LLMWrapperBase):
952
974
  result = result.to(cast)
953
975
  return result
954
976
 
955
- def _logprobs_from_history_tokens(self, response_tokens, cfg, out):
977
+ def _logprobs_from_history_tokens(
978
+ self, response_tokens, cfg, out, logits_only=False
979
+ ):
956
980
  """Compute log-probs from history tokens."""
957
981
  pad_val = self.tokenizer.pad_token_id
958
982
 
@@ -996,6 +1020,7 @@ class TransformersWrapper(LLMWrapperBase):
996
1020
  tokens_full_padded,
997
1021
  attention_mask_full_padded,
998
1022
  pad_val,
1023
+ logits_only=logits_only,
999
1024
  )
1000
1025
 
1001
1026
  # Build output TensorClass objects
@@ -1051,19 +1076,20 @@ class TransformersWrapper(LLMWrapperBase):
1051
1076
  tokens_obj.padded = MetaData(self.pad_output)
1052
1077
  out.set(self.tokens_key, tokens_obj)
1053
1078
 
1054
- log_probs_obj = LogProbs._from_tensordict(
1055
- TensorDict(batch_size=out.batch_size).to_lazystack(0)
1056
- )
1057
- if self.pad_output:
1058
- log_probs_obj.full = log_probs_full_padded
1059
- else:
1060
- log_probs_full_unpadded = _unpad_tensors(
1061
- log_probs_full_padded, attention_mask_full_padded, as_nested=False
1079
+ if not logits_only:
1080
+ log_probs_obj = LogProbs._from_tensordict(
1081
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
1062
1082
  )
1063
- log_probs_obj.full = log_probs_full_unpadded
1064
- log_probs_obj.response = None
1065
- log_probs_obj.padded = MetaData(self.pad_output)
1066
- out.set(self.log_probs_key, log_probs_obj)
1083
+ if self.pad_output:
1084
+ log_probs_obj.full = log_probs_full_padded
1085
+ else:
1086
+ log_probs_full_unpadded = _unpad_tensors(
1087
+ log_probs_full_padded, attention_mask_full_padded, as_nested=False
1088
+ )
1089
+ log_probs_obj.full = log_probs_full_unpadded
1090
+ log_probs_obj.response = None
1091
+ log_probs_obj.padded = MetaData(self.pad_output)
1092
+ out.set(self.log_probs_key, log_probs_obj)
1067
1093
 
1068
1094
  # Add logits to output if we're in a get_dist call
1069
1095
  if self._in_get_dist_call:
@@ -1095,7 +1121,7 @@ class TransformersWrapper(LLMWrapperBase):
1095
1121
  raise ValueError(f"Expected list of text for text input, got {type(text)}")
1096
1122
  return self._generate_from_text(text, cfg, out)
1097
1123
 
1098
- def _from_transformers_logprobs_text(self, td, cfg, out):
1124
+ def _from_transformers_logprobs_text(self, td, cfg, out, logits_only=False):
1099
1125
  """Compute log-probs from text input."""
1100
1126
  # Validate input
1101
1127
  if self.input_key not in td:
@@ -1168,6 +1194,7 @@ class TransformersWrapper(LLMWrapperBase):
1168
1194
  input_ids_full_padded,
1169
1195
  attention_mask_full_padded,
1170
1196
  self.tokenizer.pad_token_id,
1197
+ logits_only=logits_only,
1171
1198
  )
1172
1199
 
1173
1200
  # Build output TensorClass objects
@@ -1212,19 +1239,20 @@ class TransformersWrapper(LLMWrapperBase):
1212
1239
  masks_obj.padded = MetaData(self.pad_output)
1213
1240
  out.set(self.masks_key, masks_obj)
1214
1241
 
1215
- log_probs_obj = LogProbs._from_tensordict(
1216
- TensorDict(batch_size=out.batch_size).to_lazystack(0)
1217
- )
1218
- if self.pad_output:
1219
- log_probs_obj.full = log_probs_full_padded
1220
- else:
1221
- log_probs_full_unpadded = _unpad_tensors(
1222
- log_probs_full_padded, attention_mask_full_padded, as_nested=False
1242
+ if not logits_only:
1243
+ log_probs_obj = LogProbs._from_tensordict(
1244
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
1223
1245
  )
1224
- log_probs_obj.full = log_probs_full_unpadded
1225
- log_probs_obj.response = None
1226
- log_probs_obj.padded = MetaData(self.pad_output)
1227
- out.set(self.log_probs_key, log_probs_obj)
1246
+ if self.pad_output:
1247
+ log_probs_obj.full = log_probs_full_padded
1248
+ else:
1249
+ log_probs_full_unpadded = _unpad_tensors(
1250
+ log_probs_full_padded, attention_mask_full_padded, as_nested=False
1251
+ )
1252
+ log_probs_obj.full = log_probs_full_unpadded
1253
+ log_probs_obj.response = None
1254
+ log_probs_obj.padded = MetaData(self.pad_output)
1255
+ out.set(self.log_probs_key, log_probs_obj)
1228
1256
 
1229
1257
  # Add logits to output if we're in a get_dist call
1230
1258
  if self._in_get_dist_call:
@@ -1416,7 +1444,11 @@ class TransformersWrapper(LLMWrapperBase):
1416
1444
  return out
1417
1445
 
1418
1446
  def _from_transformers_logprobs_tokens(
1419
- self, td: TensorDictBase, cfg: dict | None, out: TensorDictBase
1447
+ self,
1448
+ td: TensorDictBase,
1449
+ cfg: dict | None,
1450
+ out: TensorDictBase,
1451
+ logits_only=False,
1420
1452
  ) -> TensorDictBase:
1421
1453
  """Compute log-probs from tokens input."""
1422
1454
  # Validate input
@@ -1470,6 +1502,7 @@ class TransformersWrapper(LLMWrapperBase):
1470
1502
  input_ids_full_padded,
1471
1503
  attention_mask_full_padded,
1472
1504
  self.tokenizer.pad_token_id,
1505
+ logits_only=logits_only,
1473
1506
  )
1474
1507
 
1475
1508
  # Build output TensorClass objects
@@ -1514,19 +1547,20 @@ class TransformersWrapper(LLMWrapperBase):
1514
1547
  masks_obj.padded = MetaData(self.pad_output)
1515
1548
  out.set(self.masks_key, masks_obj)
1516
1549
 
1517
- log_probs_obj = LogProbs._from_tensordict(
1518
- TensorDict(batch_size=out.batch_size).to_lazystack(0)
1519
- )
1520
- if self.pad_output:
1521
- log_probs_obj.full = log_probs_full_padded
1522
- else:
1523
- log_probs_full_unpadded = _unpad_tensors(
1524
- log_probs_full_padded, attention_mask_full_padded, as_nested=False
1550
+ if not logits_only:
1551
+ log_probs_obj = LogProbs._from_tensordict(
1552
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
1525
1553
  )
1526
- log_probs_obj.full = log_probs_full_unpadded
1527
- log_probs_obj.response = None
1528
- log_probs_obj.padded = MetaData(self.pad_output)
1529
- out.set(self.log_probs_key, log_probs_obj)
1554
+ if self.pad_output:
1555
+ log_probs_obj.full = log_probs_full_padded
1556
+ else:
1557
+ log_probs_full_unpadded = _unpad_tensors(
1558
+ log_probs_full_padded, attention_mask_full_padded, as_nested=False
1559
+ )
1560
+ log_probs_obj.full = log_probs_full_unpadded
1561
+ log_probs_obj.response = None
1562
+ log_probs_obj.padded = MetaData(self.pad_output)
1563
+ out.set(self.log_probs_key, log_probs_obj)
1530
1564
 
1531
1565
  # Add logits to output if we're in a get_dist call
1532
1566
  if self._in_get_dist_call:
@@ -1567,7 +1601,7 @@ class TransformersWrapper(LLMWrapperBase):
1567
1601
  return log_probs, logits
1568
1602
 
1569
1603
  def _compute_log_probs_from_model_output(
1570
- self, model_output, input_ids, attention_mask, pad_val
1604
+ self, model_output, input_ids, attention_mask, pad_val, logits_only=False
1571
1605
  ):
1572
1606
  """Compute log-probs from model output without modifying original tensors.
1573
1607
 
@@ -1576,6 +1610,7 @@ class TransformersWrapper(LLMWrapperBase):
1576
1610
  input_ids: Original input token ids
1577
1611
  attention_mask: Original attention mask
1578
1612
  pad_val: Padding token value to ignore in loss computation
1613
+ logits_only: Whether to return only the logits.
1579
1614
 
1580
1615
  Returns:
1581
1616
  tuple: (log_probs, shifted_logits) where log_probs are the computed log probabilities
@@ -1600,6 +1635,8 @@ class TransformersWrapper(LLMWrapperBase):
1600
1635
  raise ValueError(
1601
1636
  f"The logits shape {shifted_logits.shape} does not match the input ids shape {shifted_input_ids.shape}"
1602
1637
  )
1638
+ if logits_only:
1639
+ return None, shifted_logits
1603
1640
 
1604
1641
  # Compute log-probs
1605
1642
  td = TensorDict(
@@ -11,6 +11,7 @@ from typing import Any, Literal
11
11
  import torch
12
12
  from tensordict import (
13
13
  lazy_stack,
14
+ LazyStackedTensorDict,
14
15
  MetaData,
15
16
  NonTensorStack,
16
17
  set_list_to_stack,
@@ -500,19 +501,32 @@ class vLLMWrapper(LLMWrapperBase):
500
501
  def forward(
501
502
  self,
502
503
  tensordict: TensorDictBase,
504
+ *,
503
505
  tensordict_out: TensorDictBase | None = None,
506
+ logits_only: bool = False,
504
507
  **kwargs,
505
508
  ) -> TensorDictBase:
509
+ tensordict_orig = tensordict
506
510
  if not tensordict.ndim:
511
+ if tensordict_out is not None:
512
+ raise ValueError(
513
+ "tensordict_out must not be provided when tensordict.ndim == 0. If this is needed, "
514
+ "please submit an issue on github."
515
+ )
507
516
  # unsqueeze - squeeze the input
508
- try:
509
- return self(lazy_stack([tensordict])).squeeze(0)
510
- except Exception as e:
511
- raise RuntimeError(
512
- f"Unsqueeze/squeeze failed. Inputs to {type(self).__name__} should ideally be 1 dimensional."
513
- ) from e
517
+ return self.forward(lazy_stack([tensordict]), logits_only=logits_only)[0]
514
518
  elif tensordict.ndim > 1:
515
- return self(tensordict.reshape(-1)).view(tensordict.shape)
519
+ if tensordict_out is not None:
520
+ raise ValueError(
521
+ "tensordict_out must not be provided when tensordict.ndim > 1. If this is needed, "
522
+ "please submit an issue on github."
523
+ )
524
+ return self.forward(tensordict.reshape(-1), logits_only=logits_only).view(
525
+ tensordict.shape
526
+ )
527
+
528
+ if not isinstance(tensordict, LazyStackedTensorDict):
529
+ tensordict = tensordict.to_lazystack(0)
516
530
 
517
531
  _source_device = None
518
532
  if self._device:
@@ -567,7 +581,7 @@ class vLLMWrapper(LLMWrapperBase):
567
581
  if tensordict_out is None:
568
582
  if self.inplace is True:
569
583
  # The output is the input
570
- tensordict_out = tensordict
584
+ tensordict_out = tensordict_orig
571
585
  elif self.inplace is False:
572
586
  # The output is the new structure
573
587
  tensordict_out = out
@@ -1242,12 +1256,14 @@ class vLLMWrapper(LLMWrapperBase):
1242
1256
 
1243
1257
  generate_kwargs = {"sampling_params": sampling_params}
1244
1258
  args = ()
1259
+ empirical_attention_mask = None
1245
1260
 
1246
1261
  if tokens_prompt_unpadded is None:
1247
1262
  # TODO: To be on the safe side, we may do this even in the unpadded case since we're not sure
1248
1263
  # the user passed an unpadded tensor in the first place.
1264
+ empirical_attention_mask = tokens_prompt_padded != self.padding_value
1249
1265
  tokens_prompt_list = self._to_list(
1250
- tokens_prompt_padded, tokens_prompt_padded != self.padding_value
1266
+ tokens_prompt_padded, empirical_attention_mask
1251
1267
  )
1252
1268
  else:
1253
1269
  tokens_prompt_list = self._to_list(tokens_prompt_unpadded, None)
@@ -1365,6 +1381,22 @@ class vLLMWrapper(LLMWrapperBase):
1365
1381
  padding_value=self.padding_value,
1366
1382
  padding_side="right",
1367
1383
  )
1384
+ if (
1385
+ prompt_logprobs_padded.shape[-1]
1386
+ != tokens_prompt_padded.shape[-1]
1387
+ ):
1388
+ tshape = tokens_prompt_padded.shape
1389
+ oshape = prompt_logprobs_padded.shape
1390
+ # it could be that the input was padded already - padding again then
1391
+ prompt_logprobs_padded = torch.cat(
1392
+ [
1393
+ prompt_logprobs_padded.new_zeros(
1394
+ tshape[:-1] + (tshape[-1] - oshape[-1],)
1395
+ ),
1396
+ prompt_logprobs_padded,
1397
+ ],
1398
+ -1,
1399
+ )
1368
1400
  else:
1369
1401
  prompt_logprobs_list = request_output_tc.get(
1370
1402
  "prompt_logprobs",
@@ -1490,26 +1522,21 @@ class vLLMWrapper(LLMWrapperBase):
1490
1522
 
1491
1523
  request_output_tc = _RequestOutput_tc.from_request_output(tokens_out_stuct)
1492
1524
 
1525
+ # For unpadded case, extract from each sequence
1526
+ log_probs_full_unpadded = request_output_tc.get("prompt_logprobs", as_list=True)
1527
+
1493
1528
  # Extract log-probs from prompt_logprobs
1494
1529
  if self.pad_output:
1495
1530
  # For padded case, use all prompt_logprobs
1496
- log_probs_full_padded = request_output_tc.get(
1497
- "prompt_logprobs",
1498
- as_padded_tensor=True,
1499
- padding_value=0,
1500
- padding_side="left",
1531
+ if attention_mask_full_padded is not None:
1532
+ attention_mask_full_padded = tokens_full_padded != self.padding_value
1533
+ log_probs_full_padded = torch.zeros_like(
1534
+ tokens_full_padded, dtype=torch.get_default_dtype()
1501
1535
  )
1502
-
1503
- # Mask out padding
1504
- attention_mask_full_padded = tokens_full_padded != self.padding_value
1505
- log_probs_full_padded = torch.where(
1506
- attention_mask_full_padded, log_probs_full_padded, 0.0
1536
+ log_probs_full_padded[attention_mask_full_padded] = torch.cat(
1537
+ log_probs_full_unpadded, -1
1507
1538
  )
1508
1539
  else:
1509
- # For unpadded case, extract from each sequence
1510
- log_probs_full_unpadded = request_output_tc.get(
1511
- "prompt_logprobs", as_list=True
1512
- )
1513
1540
  self._check_not_padded(log_probs_full_unpadded)
1514
1541
 
1515
1542
  assistant_mask_full_padded = None
torchrl/objectives/a2c.py CHANGED
@@ -70,7 +70,7 @@ class A2CLoss(LossModule):
70
70
  samples will be used to compute this estimate.
71
71
  Defaults to ``1``.
72
72
  entropy_coeff (:obj:`float`): the weight of the entropy loss. Defaults to `0.01``.
73
- critic_coef (:obj:`float`): the weight of the critic loss. Defaults to ``1.0``. If ``None``, the critic
73
+ critic_coeff (:obj:`float`): the weight of the critic loss. Defaults to ``1.0``. If ``None``, the critic
74
74
  loss won't be included and the in-keys will miss the critic inputs.
75
75
  loss_critic_type (str): loss function for the value discrepancy.
76
76
  Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
@@ -156,7 +156,7 @@ class A2CLoss(LossModule):
156
156
  the expected keyword arguments are:
157
157
  ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and critic.
158
158
  The return value is a tuple of tensors in the following order:
159
- ``["loss_objective"]`` + ``["loss_critic"]`` if critic_coef is not None + ``["entropy", "loss_entropy"]`` if entropy_bonus is True and critic_coef is not None
159
+ ``["loss_objective"]`` + ``["loss_critic"]`` if critic_coeff is not None + ``["entropy", "loss_entropy"]`` if entropy_bonus is True and critic_coeff is not None
160
160
 
161
161
  Examples:
162
162
  >>> import torch
@@ -277,8 +277,8 @@ class A2CLoss(LossModule):
277
277
  *,
278
278
  entropy_bonus: bool = True,
279
279
  samples_mc_entropy: int = 1,
280
- entropy_coeff: float = 0.01,
281
- critic_coef: float = 1.0,
280
+ entropy_coeff: float | None = None,
281
+ critic_coeff: float = 1.0,
282
282
  loss_critic_type: str = "smooth_l1",
283
283
  gamma: float | None = None,
284
284
  separate_losses: bool = False,
@@ -291,13 +291,32 @@ class A2CLoss(LossModule):
291
291
  clip_value: float | None = None,
292
292
  **kwargs,
293
293
  ):
294
+ # Handle deprecated entropy_coef argument
294
295
  if "entropy_coef" in kwargs:
296
+ if entropy_coeff is not None: # Check if entropy_coeff was explicitly set
297
+ raise ValueError(
298
+ "Cannot specify both 'entropy_coef' and 'entropy_coeff'"
299
+ )
295
300
  warnings.warn(
296
301
  "'entropy_coef' is deprecated and will be removed in torchrl v0.11. Please use 'entropy_coeff' instead.",
297
302
  DeprecationWarning,
298
303
  )
299
304
  entropy_coeff = kwargs.pop("entropy_coef")
300
305
 
306
+ # Set default value if None
307
+ if entropy_coeff is None:
308
+ entropy_coeff = 0.01
309
+
310
+ # Handle deprecated critic_coef argument
311
+ if "critic_coef" in kwargs:
312
+ if critic_coeff != 1.0: # Check if critic_coeff was explicitly set
313
+ raise ValueError("Cannot specify both 'critic_coef' and 'critic_coeff'")
314
+ warnings.warn(
315
+ "'critic_coef' is deprecated and will be removed in torchrl v0.11. Please use 'critic_coeff' instead.",
316
+ DeprecationWarning,
317
+ )
318
+ critic_coeff = kwargs.pop("critic_coef")
319
+
301
320
  if actor is not None:
302
321
  actor_network = actor
303
322
  del actor
@@ -349,12 +368,12 @@ class A2CLoss(LossModule):
349
368
  self.register_buffer(
350
369
  "entropy_coeff", torch.as_tensor(entropy_coeff, device=device)
351
370
  )
352
- if critic_coef is not None:
371
+ if critic_coeff is not None:
353
372
  self.register_buffer(
354
- "critic_coef", torch.as_tensor(critic_coef, device=device)
373
+ "critic_coeff", torch.as_tensor(critic_coeff, device=device)
355
374
  )
356
375
  else:
357
- self.critic_coef = None
376
+ self.critic_coeff = None
358
377
 
359
378
  if gamma is not None:
360
379
  raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
@@ -399,7 +418,7 @@ class A2CLoss(LossModule):
399
418
  *self.actor_network.in_keys,
400
419
  *[("next", key) for key in self.actor_network.in_keys],
401
420
  ]
402
- if self.critic_coef is not None:
421
+ if self.critic_coeff is not None:
403
422
  keys.extend(self.critic_network.in_keys)
404
423
  return list(set(keys))
405
424
 
@@ -407,7 +426,7 @@ class A2CLoss(LossModule):
407
426
  def out_keys(self):
408
427
  if self._out_keys is None:
409
428
  outs = ["loss_objective"]
410
- if self.critic_coef is not None:
429
+ if self.critic_coeff is not None:
411
430
  outs.append("loss_critic")
412
431
  if self.entropy_bonus:
413
432
  outs.append("entropy")
@@ -478,7 +497,7 @@ class A2CLoss(LossModule):
478
497
  return log_prob, dist
479
498
 
480
499
  def loss_critic(self, tensordict: TensorDictBase) -> tuple[torch.Tensor, float]:
481
- """Returns the loss value of the critic, multiplied by ``critic_coef`` if it is not ``None``.
500
+ """Returns the loss value of the critic, multiplied by ``critic_coeff`` if it is not ``None``.
482
501
 
483
502
  Returns the loss and the clip-fraction.
484
503
 
@@ -539,8 +558,8 @@ class A2CLoss(LossModule):
539
558
  "target_actor_network_params",
540
559
  "target_critic_network_params",
541
560
  )
542
- if self.critic_coef is not None:
543
- return self.critic_coef * loss_value, clip_fraction
561
+ if self.critic_coeff is not None:
562
+ return self.critic_coeff * loss_value, clip_fraction
544
563
  return loss_value, clip_fraction
545
564
 
546
565
  @property
@@ -568,7 +587,7 @@ class A2CLoss(LossModule):
568
587
  entropy = self.get_entropy_bonus(dist)
569
588
  td_out.set("entropy", entropy.detach().mean()) # for logging
570
589
  td_out.set("loss_entropy", -self.entropy_coeff * entropy)
571
- if self.critic_coef is not None:
590
+ if self.critic_coeff is not None:
572
591
  loss_critic, value_clip_fraction = self.loss_critic(tensordict)
573
592
  td_out.set("loss_critic", loss_critic)
574
593
  if value_clip_fraction is not None:
torchrl/objectives/ppo.py CHANGED
@@ -102,13 +102,13 @@ class PPOLoss(LossModule):
102
102
  Defaults to ``1``.
103
103
  entropy_coeff: scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
104
104
  * **Scalar**: one value applied to the summed entropy of every action head.
105
- * **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action-head's entropy.
105
+ * **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
106
106
  Defaults to ``0.01``.
107
107
  log_explained_variance (bool, optional): if ``True``, the explained variance of the critic
108
108
  predictions w.r.t. value targets will be computed and logged as ``"explained_variance"``.
109
109
  This can help monitor critic quality during training. Best possible score is 1.0, lower values are worse. Defaults to ``True``.
110
- critic_coef (scalar, optional): critic loss multiplier when computing the total
111
- loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value
110
+ critic_coeff (scalar, optional): critic loss multiplier when computing the total
111
+ loss. Defaults to ``1.0``. Set ``critic_coeff`` to ``None`` to exclude the value
112
112
  loss from the forward outputs.
113
113
  loss_critic_type (str, optional): loss function for the value discrepancy.
114
114
  Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
@@ -239,7 +239,7 @@ class PPOLoss(LossModule):
239
239
  the expected keyword arguments are:
240
240
  ``["action", "sample_log_prob", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and value network.
241
241
  The return value is a tuple of tensors in the following order:
242
- ``["loss_objective"]`` + ``["entropy", "loss_entropy"]`` if entropy_bonus is set + ``"loss_critic"`` if critic_coef is not ``None``.
242
+ ``["loss_objective"]`` + ``["entropy", "loss_entropy"]`` if entropy_bonus is set + ``"loss_critic"`` if critic_coeff is not ``None``.
243
243
  The output keys can also be filtered using :meth:`PPOLoss.select_out_keys` method.
244
244
 
245
245
  Examples:
@@ -351,9 +351,9 @@ class PPOLoss(LossModule):
351
351
  *,
352
352
  entropy_bonus: bool = True,
353
353
  samples_mc_entropy: int = 1,
354
- entropy_coeff: float | Mapping[str, float] = 0.01,
354
+ entropy_coeff: float | Mapping[str, float] | None = None,
355
355
  log_explained_variance: bool = True,
356
- critic_coef: float | None = None,
356
+ critic_coeff: float | None = None,
357
357
  loss_critic_type: str = "smooth_l1",
358
358
  normalize_advantage: bool = False,
359
359
  normalize_advantage_exclude_dims: tuple[int] = (),
@@ -377,13 +377,23 @@ class PPOLoss(LossModule):
377
377
  critic_network = critic
378
378
  del critic
379
379
 
380
- if critic_coef is None and critic_network is not None:
381
- critic_coef = 1.0
382
- elif critic_coef in (None, 0) and critic_network is not None:
383
- critic_coef = None
380
+ # Handle deprecated critic_coef argument
381
+ if "critic_coef" in kwargs:
382
+ if critic_coeff is not None:
383
+ raise ValueError("Cannot specify both 'critic_coef' and 'critic_coeff'")
384
+ warnings.warn(
385
+ "'critic_coef' is deprecated and will be removed in torchrl v0.11. Please use 'critic_coeff' instead.",
386
+ DeprecationWarning,
387
+ )
388
+ critic_coeff = kwargs.pop("critic_coef")
389
+
390
+ if critic_coeff is None and critic_network is not None:
391
+ critic_coeff = 1.0
392
+ elif critic_coeff in (None, 0) and critic_network is not None:
393
+ critic_coeff = None
384
394
 
385
395
  if actor_network is None or (
386
- critic_network is None and critic_coef not in (None, 0.0)
396
+ critic_network is None and critic_coeff not in (None, 0.0)
387
397
  ):
388
398
  raise TypeError(
389
399
  "Missing positional arguments actor_network or critic_network."
@@ -431,13 +441,21 @@ class PPOLoss(LossModule):
431
441
  torch, "get_default_device", lambda: torch.device("cpu")
432
442
  )()
433
443
 
434
- # Handle deprecated entropy_coeff argument
435
- if "entropy_coeff" in kwargs:
444
+ # Handle deprecated entropy_coef argument
445
+ if "entropy_coef" in kwargs:
446
+ if entropy_coeff is not None: # Check if entropy_coeff was explicitly set
447
+ raise ValueError(
448
+ "Cannot specify both 'entropy_coef' and 'entropy_coeff'"
449
+ )
436
450
  warnings.warn(
437
- "'entropy_coeff' is deprecated and will be removed in torchrl v0.11. Please use 'entropy_coeff' instead.",
451
+ "'entropy_coef' is deprecated and will be removed in torchrl v0.11. Please use 'entropy_coeff' instead.",
438
452
  DeprecationWarning,
439
453
  )
440
- entropy_coeff = kwargs.pop("entropy_coeff")
454
+ entropy_coeff = kwargs.pop("entropy_coef")
455
+
456
+ # Set default value if None
457
+ if entropy_coeff is None:
458
+ entropy_coeff = 0.01
441
459
 
442
460
  if isinstance(entropy_coeff, Mapping):
443
461
  # Store the mapping for per-head coefficients
@@ -457,13 +475,13 @@ class PPOLoss(LossModule):
457
475
  self._entropy_coeff_map = None
458
476
  else:
459
477
  raise TypeError("entropy_coeff must be a float or a Mapping[str, float]")
460
- if critic_coef is not None:
478
+ if critic_coeff is not None:
461
479
  self.register_buffer(
462
- "critic_coef", torch.tensor(critic_coef, device=device)
480
+ "critic_coeff", torch.tensor(critic_coeff, device=device)
463
481
  )
464
482
  else:
465
- self.critic_coef = None
466
- self._has_critic = bool(self.critic_coef is not None and self.critic_coef > 0)
483
+ self.critic_coeff = None
484
+ self._has_critic = bool(self.critic_coeff is not None and self.critic_coeff > 0)
467
485
  self.loss_critic_type = loss_critic_type
468
486
  self.normalize_advantage = normalize_advantage
469
487
  self.normalize_advantage_exclude_dims = normalize_advantage_exclude_dims
@@ -692,7 +710,7 @@ class PPOLoss(LossModule):
692
710
  def loss_critic(
693
711
  self, tensordict: TensorDictBase
694
712
  ) -> tuple[torch.Tensor | TensorDict, ...]:
695
- """Returns the critic loss multiplied by ``critic_coef``, if it is not ``None``."""
713
+ """Returns the critic loss multiplied by ``critic_coeff``, if it is not ``None``."""
696
714
  # TODO: if the advantage is gathered by forward, this introduces an
697
715
  # overhead that we could easily reduce.
698
716
  if self.separate_losses:
@@ -766,7 +784,7 @@ class PPOLoss(LossModule):
766
784
  "target_critic_network_params",
767
785
  )
768
786
  if self._has_critic:
769
- return self.critic_coef * loss_value, clip_fraction, explained_variance
787
+ return self.critic_coeff * loss_value, clip_fraction, explained_variance
770
788
  return loss_value, clip_fraction, explained_variance
771
789
 
772
790
  @property
@@ -954,10 +972,10 @@ class ClipPPOLoss(PPOLoss):
954
972
  Defaults to ``1``.
955
973
  entropy_coeff: (scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
956
974
  * **Scalar**: one value applied to the summed entropy of every action head.
957
- * **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action-head's entropy.
975
+ * **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
958
976
  Defaults to ``0.01``.
959
- critic_coef (scalar, optional): critic loss multiplier when computing the total
960
- loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value
977
+ critic_coeff (scalar, optional): critic loss multiplier when computing the total
978
+ loss. Defaults to ``1.0``. Set ``critic_coeff`` to ``None`` to exclude the value
961
979
  loss from the forward outputs.
962
980
  loss_critic_type (str, optional): loss function for the value discrepancy.
963
981
  Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
@@ -1057,8 +1075,8 @@ class ClipPPOLoss(PPOLoss):
1057
1075
  clip_epsilon: float = 0.2,
1058
1076
  entropy_bonus: bool = True,
1059
1077
  samples_mc_entropy: int = 1,
1060
- entropy_coeff: float | Mapping[str, float] = 0.01,
1061
- critic_coef: float | None = None,
1078
+ entropy_coeff: float | Mapping[str, float] | None = None,
1079
+ critic_coeff: float | None = None,
1062
1080
  loss_critic_type: str = "smooth_l1",
1063
1081
  normalize_advantage: bool = False,
1064
1082
  normalize_advantage_exclude_dims: tuple[int] = (),
@@ -1079,7 +1097,7 @@ class ClipPPOLoss(PPOLoss):
1079
1097
  entropy_bonus=entropy_bonus,
1080
1098
  samples_mc_entropy=samples_mc_entropy,
1081
1099
  entropy_coeff=entropy_coeff,
1082
- critic_coef=critic_coef,
1100
+ critic_coeff=critic_coeff,
1083
1101
  loss_critic_type=loss_critic_type,
1084
1102
  normalize_advantage=normalize_advantage,
1085
1103
  normalize_advantage_exclude_dims=normalize_advantage_exclude_dims,
@@ -1247,9 +1265,9 @@ class KLPENPPOLoss(PPOLoss):
1247
1265
  Defaults to ``1``.
1248
1266
  entropy_coeff: scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
1249
1267
  * **Scalar**: one value applied to the summed entropy of every action head.
1250
- * **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action-head's entropy.
1268
+ * **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
1251
1269
  Defaults to ``0.01``.
1252
- critic_coef (scalar, optional): critic loss multiplier when computing the total
1270
+ critic_coeff (scalar, optional): critic loss multiplier when computing the total
1253
1271
  loss. Defaults to ``1.0``.
1254
1272
  loss_critic_type (str, optional): loss function for the value discrepancy.
1255
1273
  Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
@@ -1351,8 +1369,8 @@ class KLPENPPOLoss(PPOLoss):
1351
1369
  samples_mc_kl: int = 1,
1352
1370
  entropy_bonus: bool = True,
1353
1371
  samples_mc_entropy: int = 1,
1354
- entropy_coeff: float | Mapping[str, float] = 0.01,
1355
- critic_coef: float | None = None,
1372
+ entropy_coeff: float | Mapping[str, float] | None = None,
1373
+ critic_coeff: float | None = None,
1356
1374
  loss_critic_type: str = "smooth_l1",
1357
1375
  normalize_advantage: bool = False,
1358
1376
  normalize_advantage_exclude_dims: tuple[int] = (),
@@ -1369,7 +1387,7 @@ class KLPENPPOLoss(PPOLoss):
1369
1387
  entropy_bonus=entropy_bonus,
1370
1388
  samples_mc_entropy=samples_mc_entropy,
1371
1389
  entropy_coeff=entropy_coeff,
1372
- critic_coef=critic_coef,
1390
+ critic_coeff=critic_coeff,
1373
1391
  loss_critic_type=loss_critic_type,
1374
1392
  normalize_advantage=normalize_advantage,
1375
1393
  normalize_advantage_exclude_dims=normalize_advantage_exclude_dims,
@@ -86,7 +86,7 @@ class A2CLossConfig:
86
86
  # Decay factor for return computation. Default=0.99.
87
87
  entropy_coeff: float = 1e-3
88
88
  # Entropy factor for the A2C loss
89
- critic_coef: float = 1.0
89
+ critic_coeff: float = 1.0
90
90
  # Critic factor for the A2C loss
91
91
  critic_loss_function: str = "smooth_l1"
92
92
  # loss function for the value network. Either one of l1, l2 or smooth_l1 (default).
@@ -112,7 +112,7 @@ class PPOLossConfig:
112
112
  # Number of samples to use for a Monte-Carlo estimate if the policy distribution has not closed formula.
113
113
  loss_function: str = "smooth_l1"
114
114
  # loss function for the value network. Either one of l1, l2 or smooth_l1 (default).
115
- critic_coef: float = 1.0
115
+ critic_coeff: float = 1.0
116
116
  # Critic loss multiplier when computing the total loss.
117
117
 
118
118
  # ClipPPOLoss parameters:
torchrl/version.py CHANGED
@@ -1,2 +1,2 @@
1
- __version__ = '2025.7.15'
2
- git_version = '77c00b910e6fdd85aa94b4d354390b724af4ec94'
1
+ __version__ = '2025.7.18'
2
+ git_version = '4001d9cb73cea4498b0fdfe420effc58a5a336be'
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torchrl-nightly
3
- Version: 2025.7.15
3
+ Version: 2025.7.18
4
4
  Home-page: https://github.com/pytorch/rl
5
5
  Author: torchrl contributors
6
6
  Author-email: vmoens@fb.com
@@ -3,11 +3,11 @@ build_tools/setup_helpers/__init__.py,sha256=7l8TvVqxKezgzKCLuRv20mvGLloprFVZYm8
3
3
  build_tools/setup_helpers/extension.py,sha256=4-PDLr-pw40bJnd9SfxnTaSjUyuXU_Tg8yOg69Kl0o4,5914
4
4
  torchrl/__init__.py,sha256=mhDBx2UIuBKc0gmi8dVNHokQ6tCbIovruZmyAxcSsy8,2938
5
5
  torchrl/_extension.py,sha256=z7wQ8i1iYWYcnygq_j0nq9sT-koY13tfHhTLNbMk17Q,2353
6
- torchrl/_torchrl.cpython-313-darwin.so,sha256=Oc5ssTpuyTv6h4iaS_pHmxT44bNN96SH0V9beynYRSc,1692464
6
+ torchrl/_torchrl.cpython-313-darwin.so,sha256=xjRDwIjMdFaN3InficNo2b0v9ju7g-ONAgZ0DcGbk38,1692464
7
7
  torchrl/_utils.py,sha256=Cw5EG6x5oSZF1iE3YCs1a32VUKp0rTXIs2u67q9zKUI,41078
8
- torchrl/version.py,sha256=Fow5OPjVvk1yM4tQyBX-t6Un4hGKcYsr4kuvYN_gGPs,83
8
+ torchrl/version.py,sha256=MHs4CxNjQupYI_f84bY7dOAAfPSU9yN6TOyxiS7tS8c,83
9
9
  torchrl/collectors/__init__.py,sha256=hJ3JD6shRku0BL6SzJQq44FZ5Q1RGR8LealFyU3FRn4,799
10
- torchrl/collectors/collectors.py,sha256=WoeR-MAfzcLiy8EHPWQ3uknm_jTWjA9Wi45CODG8NZI,177782
10
+ torchrl/collectors/collectors.py,sha256=HpaW-y0bQOaOql8_7VyEPJ084CulrVwn6iBpGYoHyH4,178287
11
11
  torchrl/collectors/utils.py,sha256=MlXrkYuDmV0Em-tVNQiLL32FWgPNDgceYYG_GgpiviA,11320
12
12
  torchrl/collectors/weight_update.py,sha256=nSIfs8ALsfggLoC2ylg1oOAqdGku1tt4e-50JCZJBww,21073
13
13
  torchrl/collectors/distributed/__init__.py,sha256=_24P0ALFunLhL-ls7EsssGUhJkZ_m3nw7krfMTwPqS0,705
@@ -25,7 +25,7 @@ torchrl/collectors/llm/weight_update/__init__.py,sha256=bKjvD7yZG5VnHgvYc4EmKI1s
25
25
  torchrl/collectors/llm/weight_update/vllm.py,sha256=slKUmrIo4eL6R4J1oEnmlP6Q7Zer09p92JU8zbIHFUM,11515
26
26
  torchrl/data/__init__.py,sha256=oowsio6ZUOZnJV8JV43xgs17B37XO1yKAYIQPdk8yt0,4819
27
27
  torchrl/data/rlhf.py,sha256=JUmdYBWgkN229DwpXuDrhy9ddjduNvU2kyHzHR6MoA0,963
28
- torchrl/data/tensor_specs.py,sha256=_t6-iobtJClJ50zvo1KzHSaYS5CvL2Ca6x8btlAc3rs,253067
28
+ torchrl/data/tensor_specs.py,sha256=RlMckj6PJo9MQMzneHzbcVe9xUyMB_n7pnSz0jytB9s,253907
29
29
  torchrl/data/utils.py,sha256=attuNwzfgjszyp0lJSrV06f2peX3r0qTjRZWEwfl6Yg,12108
30
30
  torchrl/data/datasets/__init__.py,sha256=NQpXsHecbZmza8AocX9mkqQQNkdFzeUrMTZoi6hbbU4,733
31
31
  torchrl/data/datasets/atari_dqn.py,sha256=3ij6-UGfKev-QJuUEhZEEmn_3yL210CqKJALaFvlc5M,40739
@@ -41,7 +41,7 @@ torchrl/data/datasets/utils.py,sha256=nAFDTlBIPyEoPoJC-Hc_fcOhzE7UZQE4BwKxq15Vhv
41
41
  torchrl/data/datasets/vd4rl.py,sha256=z90MqrxKzod8TPGK0uzkC6vw5wQIE4cgrDAC4e72jyk,18262
42
42
  torchrl/data/llm/__init__.py,sha256=B4Ekok-w5PMiWcfmAGXaseaN6hWdNOr4WebeLrHfBVQ,975
43
43
  torchrl/data/llm/dataset.py,sha256=t-41hAzQcjrdoKwpHIMbcrT7pRcQ7DHl2a1-lr6E7W4,20703
44
- torchrl/data/llm/history.py,sha256=Tzkwmc37C9vYjVw_x1wblyENNZSV67srBEioO2j4v2c,57857
44
+ torchrl/data/llm/history.py,sha256=l9JSxIO5eLUFwHH5IZkANSrByYa8BGmtxMlNXYf2fbs,59640
45
45
  torchrl/data/llm/prompt.py,sha256=bg5LzJfwOq5Ns72KQMciIprMWAmDDinzdopwdopU04c,8380
46
46
  torchrl/data/llm/reward.py,sha256=FbPchNXG3smJV9NCbB5Yk4grsCa2Se4KZ_tojVLKWQM,8404
47
47
  torchrl/data/llm/topk.py,sha256=mYXCgJS4TuEVLZfTNccQd6kmC858AAh2Ygy0q_K1hlY,8365
@@ -131,7 +131,7 @@ torchrl/envs/transforms/llm.py,sha256=rQDzuut807wvFpSPCm5tynt8-cMKTgVKVjSVu9D99P
131
131
  torchrl/envs/transforms/r3m.py,sha256=sdTVLpnxHfzFVo5rO8WnXf2uUg9cr4LBOLBsWaFgGT8,13478
132
132
  torchrl/envs/transforms/rb_transforms.py,sha256=6ohnKXHHAEh2Hz3Seaw6eDrcFMu-1IVQrT7RVywh3YQ,7447
133
133
  torchrl/envs/transforms/rlhf.py,sha256=lOVXYqQaoDfm4_n77Dxw_wjicBpMtDvavKmBIK2N3lU,628
134
- torchrl/envs/transforms/transforms.py,sha256=QnPV5R0sDbR9bHJnRSG8JBy6cnMIeKG7vYUQjRVw5a8,482966
134
+ torchrl/envs/transforms/transforms.py,sha256=cDv_NxElzTOW8qQO-2krvOBmlKVGPOKMfqM6XyuLckU,482882
135
135
  torchrl/envs/transforms/utils.py,sha256=7ToVFnD4-DkOMtML91g4bqXeY0bZ-gmCaSLxC93oaKM,3264
136
136
  torchrl/envs/transforms/vc1.py,sha256=mho5BvdAK-f9hD9t-iah52wT2B06qPmaJO7chrfIOWY,10534
137
137
  torchrl/envs/transforms/vecnorm.py,sha256=XahMcWvK3zjOB6EACSZtJ6UMP3yQ2zD9xf87UEB37Eg,34047
@@ -139,7 +139,7 @@ torchrl/envs/transforms/vip.py,sha256=kmygbenw75rEYsKRq4X1hzEH_CRe1406NZZ8Hg2R_V
139
139
  torchrl/modules/__init__.py,sha256=XlAO0hulhDQNcKhbu3cFi8KJOHXNiAgmXeTfny0WBqE,4157
140
140
  torchrl/modules/distributions/__init__.py,sha256=RDFoYD9IX1FhwXk5R4M8khq42gdTOcVnUnKHfWCTZBQ,1597
141
141
  torchrl/modules/distributions/continuous.py,sha256=VPBugDuavJmyZ-RzemyLIFA02UCMLsm-rzBQrKcTlIA,25667
142
- torchrl/modules/distributions/discrete.py,sha256=czQSNkacxgZcExKONzDRPZjCJPbfAVaS7fC7Igdp708,35555
142
+ torchrl/modules/distributions/discrete.py,sha256=7UE6X8LeTZkaTRFvKNcFSOoug_tOcD_u-FOh-39ZSC4,35581
143
143
  torchrl/modules/distributions/truncated_normal.py,sha256=-qM8vwxTzv3VsWphZwcueDQpHQ67IRnkDFKlTDkQQnY,5937
144
144
  torchrl/modules/distributions/utils.py,sha256=kXRvNHeKUePIgKgn7DnKqbhQ6ImFGgkFVRxITX2dwNU,7567
145
145
  torchrl/modules/llm/__init__.py,sha256=BTkn-8QKp_8sW_NTKP02yoWSJUsX0XL6L9chTJl6epc,737
@@ -147,9 +147,9 @@ torchrl/modules/llm/utils.py,sha256=gf_F-4bEMwkcI3jLQM7ifB7nsjRctGebB5E2c-AznO0,
147
147
  torchrl/modules/llm/backends/__init__.py,sha256=WdVy9EdiAfk8i5zFa49TEkRvcUd0L4Un4v6wqWBy8l8,438
148
148
  torchrl/modules/llm/backends/vllm.py,sha256=x57Xop1xd5ZShicsh47ZFmz4VpfZ3eCzVx7k0COvpqQ,9387
149
149
  torchrl/modules/llm/policies/__init__.py,sha256=nfZ2mcVuucxnY3WCuzIQrTLIf1yEd36k8-AlvwnSa8Y,545
150
- torchrl/modules/llm/policies/common.py,sha256=zuaw0CVBAuMcd857JkdVWfSaxGFgwDXWOPF8GflqIkw,36379
151
- torchrl/modules/llm/policies/transformers_wrapper.py,sha256=HTkubIsbEui2hWqAZ3GwsATI2NGmA0kry1nW5RjnEJ0,74326
152
- torchrl/modules/llm/policies/vllm_wrapper.py,sha256=u0ITRdVI8pNhpRRMy2yXEh9bK_TkYRUOUEzix2m2aR0,78231
150
+ torchrl/modules/llm/policies/common.py,sha256=Kvn1cJQbp1EZtxWpAQ50TzZkwVtLAmryqiBHH2nK_wM,39112
151
+ torchrl/modules/llm/policies/transformers_wrapper.py,sha256=oi-2KALM0pkH-u-Kd6WlnxfH9eGV2GzBqM410ANpPeM,75777
152
+ torchrl/modules/llm/policies/vllm_wrapper.py,sha256=ReBvi2M9IAiwwBAR7GpDLSQhX0aC-dXPnHYb082Q0To,79632
153
153
  torchrl/modules/models/__init__.py,sha256=DrOG-7hynjjUh_tc2EqysiUiNMRiDR0WLtZql9TPNcI,1743
154
154
  torchrl/modules/models/batchrenorm.py,sha256=TojpTUluIcFdTSemIVRLGtB2O5q54mRHy3vJP6DuI5I,4750
155
155
  torchrl/modules/models/decision_transformer.py,sha256=Lttf_wZMNqXbB_vpxMYgEp18gEzOvm3NvMnxQkHkH4M,6604
@@ -176,7 +176,7 @@ torchrl/modules/utils/__init__.py,sha256=KXaF_xEghKSPsNg0JyfxChK6KWHFRy0lwkL2Rip
176
176
  torchrl/modules/utils/mappings.py,sha256=VMYrPxDk1ywgl2l_f6HXZaRsVOKcYR7VF5DNkmi3lHk,362
177
177
  torchrl/modules/utils/utils.py,sha256=WPfcE-AoemnrP7Ny4FxJ-_LoQsBnX-y77Zb7MnZjXV0,2916
178
178
  torchrl/objectives/__init__.py,sha256=pnprzIXA6E9Ph7isYgNLh4SFTU0pxIQg4oUNcaQ6doc,2148
179
- torchrl/objectives/a2c.py,sha256=K8mWcLVLUnuW5DgPZCS8P9nN1t30Gvw0j-EgcnO-QGE,27895
179
+ torchrl/objectives/a2c.py,sha256=_xdp8D2ErOPyHwpxqPHtUr-EvZw7MqcuhhK9Isnewgo,28791
180
180
  torchrl/objectives/common.py,sha256=40inZ0z3bFdQUkXuup3PWP_KmCx1m13cKTksjOp_b6I,28571
181
181
  torchrl/objectives/cql.py,sha256=8faIZmA9e65NQ39HAi6torMofr98bkngjtBXm0UbnVM,54925
182
182
  torchrl/objectives/crossq.py,sha256=a_vAjET5GG-2U7zZDgMnA0QP1iPCtv2ho6q-XvvLsnc,28858
@@ -188,7 +188,7 @@ torchrl/objectives/dreamer.py,sha256=vIJQN91oPXYnPubDFQpaF5d3fR_WwIYuIVYtoCvw0TY
188
188
  torchrl/objectives/functional.py,sha256=ZaglBjEGuOTNGeFA-Ox-ugZVcNegQMUj--KWHDRBmaU,2106
189
189
  torchrl/objectives/gail.py,sha256=0m34XmcN-EDk5OfNIo5bKYbKKZfATsYRv4zQe3v2UwA,9576
190
190
  torchrl/objectives/iql.py,sha256=1jvlSznWke6NZSwfuYyHVnVBE7Cz3q169GnCRC7iel4,42991
191
- torchrl/objectives/ppo.py,sha256=x3wJ3k7jVZWPAZCxdk4bgzhoTYukPwTj39Yo6ZgBbCM,75250
191
+ torchrl/objectives/ppo.py,sha256=0soC2aiCOFNM5hCL20-99LX_NZi6XIXDmG2IkGEHSek,76082
192
192
  torchrl/objectives/redq.py,sha256=4usM-nG2UWujeL-VEqzf7-uOwRFx6itkKCeitKuJhtw,28507
193
193
  torchrl/objectives/reinforce.py,sha256=ySXLp5C-OOUYayqjrf4taQmL8LgRvMgPCgHDsle8JDc,22339
194
194
  torchrl/objectives/sac.py,sha256=Oq9Iq90s9KFbnM4KSRUd2onU1JfW6aW80LWGdtO0CY8,63993
@@ -219,12 +219,12 @@ torchrl/trainers/helpers/__init__.py,sha256=HhDB2Ubq2gZodV-hB6xw4ZgCgwaZKUoZgOfV
219
219
  torchrl/trainers/helpers/collectors.py,sha256=NjMMvGWEe4TWkVXzx7AlJ_Qa_AxEzMl6EUmEgUzHkoE,18715
220
220
  torchrl/trainers/helpers/envs.py,sha256=1yqJZgz7mc5wa58HmSDGpPQINeDHFZB0_KTgwdKm9QE,22084
221
221
  torchrl/trainers/helpers/logger.py,sha256=FtuEiLnK4NmxVVNyEEWaoCu3nG7WbNpHP3UYGQRJmgo,1278
222
- torchrl/trainers/helpers/losses.py,sha256=qH-2YJwMtDAYAPXTTYy3cOPiq4ILC6xTjfnGUU__6vo,5270
222
+ torchrl/trainers/helpers/losses.py,sha256=sHlJqjh02t8cKN73X35Azd_OoWGurohLuviB8Yeo4JQ,5272
223
223
  torchrl/trainers/helpers/models.py,sha256=ihTERG2c96E8cS3Tnul6a_ys6iDEEJmHh05p9blQTW8,21807
224
224
  torchrl/trainers/helpers/replay_buffer.py,sha256=ZUZHOa0TILyeWJ3iahzTJ6UvMl_0FdxuZfJEja94Bn8,2001
225
225
  torchrl/trainers/helpers/trainers.py,sha256=j6B5XA7_FFHMQeOIQwjNcO0CGE_4mZKUC9_jH_iqqh4,12071
226
- torchrl_nightly-2025.7.15.dist-info/licenses/LICENSE,sha256=xdjS4_xk-IwnLuIFCvTYTl9Y8aXRejqpmke3dGam_nI,1098
227
- torchrl_nightly-2025.7.15.dist-info/METADATA,sha256=j4RRTr55v80t_WJvysde-14_KWj9VMI3H7eXvuAmbeQ,42990
228
- torchrl_nightly-2025.7.15.dist-info/WHEEL,sha256=A6iggJuFsuu67bHdjxJADhwSEJmqwgO3xFoNCIwjOxc,115
229
- torchrl_nightly-2025.7.15.dist-info/top_level.txt,sha256=JeTJ1jV7QJwLcUS1nr21aPn_wb-XlAZ9c-z_EH472JA,20
230
- torchrl_nightly-2025.7.15.dist-info/RECORD,,
226
+ torchrl_nightly-2025.7.18.dist-info/licenses/LICENSE,sha256=xdjS4_xk-IwnLuIFCvTYTl9Y8aXRejqpmke3dGam_nI,1098
227
+ torchrl_nightly-2025.7.18.dist-info/METADATA,sha256=K_Nmn84sw1xeD28lqIPqdhLjdaFchSMXuG2vjAajTn0,42990
228
+ torchrl_nightly-2025.7.18.dist-info/WHEEL,sha256=A6iggJuFsuu67bHdjxJADhwSEJmqwgO3xFoNCIwjOxc,115
229
+ torchrl_nightly-2025.7.18.dist-info/top_level.txt,sha256=JeTJ1jV7QJwLcUS1nr21aPn_wb-XlAZ9c-z_EH472JA,20
230
+ torchrl_nightly-2025.7.18.dist-info/RECORD,,