torchrl-nightly 2025.7.15__cp313-cp313-macosx_10_13_universal2.whl → 2025.7.18__cp313-cp313-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrl/_torchrl.cpython-313-darwin.so +0 -0
- torchrl/collectors/collectors.py +13 -3
- torchrl/data/llm/history.py +36 -0
- torchrl/data/tensor_specs.py +34 -9
- torchrl/envs/transforms/transforms.py +0 -1
- torchrl/modules/distributions/discrete.py +1 -1
- torchrl/modules/llm/policies/common.py +59 -9
- torchrl/modules/llm/policies/transformers_wrapper.py +90 -53
- torchrl/modules/llm/policies/vllm_wrapper.py +50 -23
- torchrl/objectives/a2c.py +32 -13
- torchrl/objectives/ppo.py +50 -32
- torchrl/trainers/helpers/losses.py +2 -2
- torchrl/version.py +2 -2
- {torchrl_nightly-2025.7.15.dist-info → torchrl_nightly-2025.7.18.dist-info}/METADATA +1 -1
- {torchrl_nightly-2025.7.15.dist-info → torchrl_nightly-2025.7.18.dist-info}/RECORD +18 -18
- {torchrl_nightly-2025.7.15.dist-info → torchrl_nightly-2025.7.18.dist-info}/WHEEL +0 -0
- {torchrl_nightly-2025.7.15.dist-info → torchrl_nightly-2025.7.18.dist-info}/licenses/LICENSE +0 -0
- {torchrl_nightly-2025.7.15.dist-info → torchrl_nightly-2025.7.18.dist-info}/top_level.txt +0 -0
Binary file
|
torchrl/collectors/collectors.py
CHANGED
@@ -686,6 +686,10 @@ class SyncDataCollector(DataCollectorBase):
|
|
686
686
|
policy = RandomPolicy(env.full_action_spec)
|
687
687
|
elif policy_factory is not None:
|
688
688
|
raise TypeError("policy_factory cannot be used with policy argument.")
|
689
|
+
# If the underlying policy has a state_dict, we keep a reference to the policy and
|
690
|
+
# do all policy weight saving/loading through it
|
691
|
+
if hasattr(policy, "state_dict"):
|
692
|
+
self._policy_w_state_dict = policy
|
689
693
|
|
690
694
|
if trust_policy is None:
|
691
695
|
trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule))
|
@@ -1686,8 +1690,8 @@ class SyncDataCollector(DataCollectorBase):
|
|
1686
1690
|
else:
|
1687
1691
|
env_state_dict = OrderedDict()
|
1688
1692
|
|
1689
|
-
if hasattr(self
|
1690
|
-
policy_state_dict = self.
|
1693
|
+
if hasattr(self, "_policy_w_state_dict"):
|
1694
|
+
policy_state_dict = self._policy_w_state_dict.state_dict()
|
1691
1695
|
state_dict = OrderedDict(
|
1692
1696
|
policy_state_dict=policy_state_dict,
|
1693
1697
|
env_state_dict=env_state_dict,
|
@@ -1711,7 +1715,13 @@ class SyncDataCollector(DataCollectorBase):
|
|
1711
1715
|
if strict or "env_state_dict" in state_dict:
|
1712
1716
|
self.env.load_state_dict(state_dict["env_state_dict"], **kwargs)
|
1713
1717
|
if strict or "policy_state_dict" in state_dict:
|
1714
|
-
self
|
1718
|
+
if not hasattr(self, "_policy_w_state_dict"):
|
1719
|
+
raise ValueError(
|
1720
|
+
"Underlying policy does not have state_dict to load policy_state_dict into."
|
1721
|
+
)
|
1722
|
+
self._policy_w_state_dict.load_state_dict(
|
1723
|
+
state_dict["policy_state_dict"], **kwargs
|
1724
|
+
)
|
1715
1725
|
self._frames = state_dict["frames"]
|
1716
1726
|
self._iter = state_dict["iter"]
|
1717
1727
|
|
torchrl/data/llm/history.py
CHANGED
@@ -713,6 +713,42 @@ class History(TensorClass["nocast"]):
|
|
713
713
|
| transformers.AutoProcessor # noqa: F821
|
714
714
|
| None = None,
|
715
715
|
) -> History:
|
716
|
+
r"""Inverts a chat template into a History object.
|
717
|
+
|
718
|
+
Args:
|
719
|
+
text (str | list[str]): The chat template to invert.
|
720
|
+
chat_template_name (str, optional): The name of the chat template to use.
|
721
|
+
tokenizer (transformers.AutoTokenizer | transformers.AutoProcessor, optional): The tokenizer to use.
|
722
|
+
|
723
|
+
Returns:
|
724
|
+
History: The inverted History object.
|
725
|
+
|
726
|
+
Examples:
|
727
|
+
>>> from torchrl.data.llm.history import History
|
728
|
+
>>> from transformers import AutoTokenizer
|
729
|
+
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
|
730
|
+
>>> text = "<|im_start|>system\nYou are a helpful assistant.\n<|im_end|>\n<|im_start|>user\nWrite a python script that gives the capital of France or Germany.\n<|im_end|>\n<|im_start|>assistant\n<think>The capital of France is Paris, the capital of Germany is Berlin.</think>\n<answer><python>\n"
|
731
|
+
>>> history = History.from_text(text, tokenizer=tokenizer)
|
732
|
+
>>> print(history)
|
733
|
+
History(
|
734
|
+
content=NonTensorStack(
|
735
|
+
['You are a helpful assistant.', 'Write a python s...,
|
736
|
+
batch_size=torch.Size([3]),
|
737
|
+
device=None),
|
738
|
+
is_complete=NonTensorStack(
|
739
|
+
[True, True, False],
|
740
|
+
batch_size=torch.Size([3]),
|
741
|
+
device=None),
|
742
|
+
role=NonTensorStack(
|
743
|
+
['system', 'user', 'assistant'],
|
744
|
+
batch_size=torch.Size([3]),
|
745
|
+
device=None),
|
746
|
+
tool_calls=None,
|
747
|
+
tool_responses=None,
|
748
|
+
batch_size=torch.Size([3]),
|
749
|
+
device=None,
|
750
|
+
is_shared=False)
|
751
|
+
"""
|
716
752
|
if chat_template_name is None:
|
717
753
|
if chat_template is not None:
|
718
754
|
# TODO: find best match given template
|
torchrl/data/tensor_specs.py
CHANGED
@@ -4449,12 +4449,18 @@ class Binary(Categorical):
|
|
4449
4449
|
f"shape of the {self.__class__.__name__} spec in expand()."
|
4450
4450
|
)
|
4451
4451
|
return self.__class__(
|
4452
|
-
n=self.shape[-1]
|
4452
|
+
n=self.shape[-1] if len(self.shape) > 0 else None,
|
4453
|
+
shape=shape,
|
4454
|
+
device=self.device,
|
4455
|
+
dtype=self.dtype,
|
4453
4456
|
)
|
4454
4457
|
|
4455
4458
|
def _reshape(self, shape):
|
4456
4459
|
return self.__class__(
|
4457
|
-
n=self.shape[-1]
|
4460
|
+
n=self.shape[-1] if len(self.shape) > 0 else None,
|
4461
|
+
shape=shape,
|
4462
|
+
device=self.device,
|
4463
|
+
dtype=self.dtype,
|
4458
4464
|
)
|
4459
4465
|
|
4460
4466
|
def _unflatten(self, dim, sizes):
|
@@ -4464,7 +4470,10 @@ class Binary(Categorical):
|
|
4464
4470
|
.shape
|
4465
4471
|
)
|
4466
4472
|
return self.__class__(
|
4467
|
-
n=self.shape[-1]
|
4473
|
+
n=self.shape[-1] if len(self.shape) > 0 else None,
|
4474
|
+
shape=shape,
|
4475
|
+
device=self.device,
|
4476
|
+
dtype=self.dtype,
|
4468
4477
|
)
|
4469
4478
|
|
4470
4479
|
def squeeze(self, dim=None):
|
@@ -4472,13 +4481,19 @@ class Binary(Categorical):
|
|
4472
4481
|
if shape is None:
|
4473
4482
|
return self
|
4474
4483
|
return self.__class__(
|
4475
|
-
n=self.shape[-1]
|
4484
|
+
n=self.shape[-1] if len(self.shape) > 0 else None,
|
4485
|
+
shape=shape,
|
4486
|
+
device=self.device,
|
4487
|
+
dtype=self.dtype,
|
4476
4488
|
)
|
4477
4489
|
|
4478
4490
|
def unsqueeze(self, dim: int):
|
4479
4491
|
shape = _unsqueezed_shape(self.shape, dim)
|
4480
4492
|
return self.__class__(
|
4481
|
-
n=self.shape[-1]
|
4493
|
+
n=self.shape[-1] if len(self.shape) > 0 else None,
|
4494
|
+
shape=shape,
|
4495
|
+
device=self.device,
|
4496
|
+
dtype=self.dtype,
|
4482
4497
|
)
|
4483
4498
|
|
4484
4499
|
def unbind(self, dim: int = 0):
|
@@ -4495,7 +4510,10 @@ class Binary(Categorical):
|
|
4495
4510
|
shape = tuple(s for i, s in enumerate(self.shape) if i != dim)
|
4496
4511
|
return tuple(
|
4497
4512
|
self.__class__(
|
4498
|
-
n=self.shape[-1]
|
4513
|
+
n=self.shape[-1] if len(self.shape) > 0 else None,
|
4514
|
+
shape=shape,
|
4515
|
+
device=self.device,
|
4516
|
+
dtype=self.dtype,
|
4499
4517
|
)
|
4500
4518
|
for i in range(self.shape[dim])
|
4501
4519
|
)
|
@@ -4512,12 +4530,15 @@ class Binary(Categorical):
|
|
4512
4530
|
if dest_device == self.device and dest_dtype == self.dtype:
|
4513
4531
|
return self
|
4514
4532
|
return self.__class__(
|
4515
|
-
n=self.shape[-1]
|
4533
|
+
n=self.shape[-1] if len(self.shape) > 0 else None,
|
4534
|
+
shape=self.shape,
|
4535
|
+
device=dest_device,
|
4536
|
+
dtype=dest_dtype,
|
4516
4537
|
)
|
4517
4538
|
|
4518
4539
|
def clone(self) -> Binary:
|
4519
4540
|
return self.__class__(
|
4520
|
-
n=self.shape[-1],
|
4541
|
+
n=self.shape[-1] if len(self.shape) > 0 else None,
|
4521
4542
|
shape=self.shape,
|
4522
4543
|
device=self.device,
|
4523
4544
|
dtype=self.dtype,
|
@@ -4528,6 +4549,8 @@ class Binary(Categorical):
|
|
4528
4549
|
|
4529
4550
|
The last dimension of the spec (length n of the binary vector) cannot be indexed.
|
4530
4551
|
"""
|
4552
|
+
if not len(self.shape):
|
4553
|
+
raise ValueError("Cannot index a Binary spec with an empty shape")
|
4531
4554
|
indexed_shape = _shape_indexing(self.shape[:-1], idx)
|
4532
4555
|
return self.__class__(
|
4533
4556
|
n=self.shape[-1],
|
@@ -5533,8 +5556,10 @@ class Composite(TensorSpec):
|
|
5533
5556
|
sub_str = [
|
5534
5557
|
indent(f"{k}: {str(item)}", 4 * " ") for k, item in self._specs.items()
|
5535
5558
|
]
|
5559
|
+
if len(sub_str) == 0:
|
5560
|
+
return f"{self.__class__.__name__}(device={self._device}, shape={self.shape}, data_cls={self.data_cls})"
|
5536
5561
|
sub_str = ",\n".join(sub_str)
|
5537
|
-
return f"
|
5562
|
+
return f"{self.__class__.__name__}(\n{sub_str},\n device={self._device},\n shape={self.shape},\n data_cls={self.data_cls})"
|
5538
5563
|
|
5539
5564
|
def type_check(
|
5540
5565
|
self,
|
@@ -1211,7 +1211,6 @@ but got an object of type {type(transform)}."""
|
|
1211
1211
|
if tensordict is not None:
|
1212
1212
|
# We must avoid modifying the original tensordict so a shallow copy is necessary.
|
1213
1213
|
# We just select the input data and reset signal, which is all we need.
|
1214
|
-
self.transform.transform_input_spec(self.base_env.input_spec.unlock_())
|
1215
1214
|
tensordict = tensordict.select(
|
1216
1215
|
*self.reset_keys, *self.state_spec.keys(True, True), strict=False
|
1217
1216
|
)
|
@@ -352,7 +352,7 @@ class MaskedCategorical(D.Categorical):
|
|
352
352
|
logits = self.logits
|
353
353
|
if logits.ndim > 2:
|
354
354
|
# Bring channels in 2nd dim
|
355
|
-
logits = logits.
|
355
|
+
logits = logits.permute(0, -1, *range(1, logits.ndim - 1))
|
356
356
|
original_value_shape = None
|
357
357
|
if logits.ndim == 1 and value.ndim >= 1:
|
358
358
|
if value.ndim >= 2:
|
@@ -4,12 +4,13 @@
|
|
4
4
|
# LICENSE file in the root directory of this source tree.
|
5
5
|
from __future__ import annotations
|
6
6
|
|
7
|
+
import warnings
|
7
8
|
import weakref
|
8
9
|
from typing import Any, Literal, overload
|
9
10
|
|
10
11
|
import torch
|
11
|
-
from tensordict import NestedKey, TensorDictBase
|
12
|
-
from tensordict.nn import TensorDictModuleBase
|
12
|
+
from tensordict import lazy_stack, NestedKey, TensorDictBase
|
13
|
+
from tensordict.nn import TensorDictModuleBase
|
13
14
|
from tensordict.tensorclass import TensorClass
|
14
15
|
from tensordict.utils import _zip_strict
|
15
16
|
from torch import distributions as D
|
@@ -171,6 +172,39 @@ class ChatHistory(TensorClass["nocast"]):
|
|
171
172
|
step_mdp_static=True,
|
172
173
|
)
|
173
174
|
|
175
|
+
def __post_init__(self):
|
176
|
+
# Check that all history objects have one more batch dimension than the ChatHistory object
|
177
|
+
if self.prompt is not None:
|
178
|
+
if getattr(self.prompt, "batch_dims", None) == self.batch_dims:
|
179
|
+
warnings.warn(
|
180
|
+
"Prompt history should have one more batch dimension than the ChatHistory object to handle multi-turn conversations, "
|
181
|
+
f"got {self.prompt.batch_dims} and {self.batch_dims}. "
|
182
|
+
"The batch dimension of the ChatHistory object will be unsqueezed along the last dimension."
|
183
|
+
)
|
184
|
+
self.prompt = lazy_stack(
|
185
|
+
[self.prompt], -1
|
186
|
+
) # equivalent to unsqueeze(-1) but make sure it's a lazy stack
|
187
|
+
if self.response is not None:
|
188
|
+
if getattr(self.response, "batch_dims", None) == self.batch_dims:
|
189
|
+
warnings.warn(
|
190
|
+
"Response history should have one more batch dimension than the ChatHistory object to handle multi-turn conversations, "
|
191
|
+
f"got {self.response.batch_dims} and {self.batch_dims}. "
|
192
|
+
"The batch dimension of the ChatHistory object will be unsqueezed along the last dimension."
|
193
|
+
)
|
194
|
+
self.response = lazy_stack(
|
195
|
+
[self.response], -1
|
196
|
+
) # equivalent to unsqueeze(-1) but make sure it's a lazy stack
|
197
|
+
if self.full is not None:
|
198
|
+
if getattr(self.full, "batch_dims", None) == self.batch_dims:
|
199
|
+
warnings.warn(
|
200
|
+
"Full history should have one more batch dimension than the ChatHistory object to handle multi-turn conversations, "
|
201
|
+
f"got {self.full.batch_dims} and {self.batch_dims}. "
|
202
|
+
"The batch dimension of the ChatHistory object will be unsqueezed along the last dimension."
|
203
|
+
)
|
204
|
+
self.full = lazy_stack(
|
205
|
+
[self.full], -1
|
206
|
+
) # equivalent to unsqueeze(-1) but make sure it's a lazy stack
|
207
|
+
|
174
208
|
|
175
209
|
class LogProbs(TensorClass["nocast"]):
|
176
210
|
"""A log-probability container.
|
@@ -454,7 +488,7 @@ class LLMWrapperBase(TensorDictModuleBase):
|
|
454
488
|
"You can create a new version of this wrapper using the `get_new_version` method."
|
455
489
|
)
|
456
490
|
|
457
|
-
td_out = self(tensordict.copy())
|
491
|
+
td_out = self.forward(tensordict.copy(), logits_only=True)
|
458
492
|
|
459
493
|
# Get logits/log-probs
|
460
494
|
if as_padded_tensor is None:
|
@@ -529,7 +563,7 @@ class LLMWrapperBase(TensorDictModuleBase):
|
|
529
563
|
"get_dist_with_prompt_mask is not implemented for generate=True. "
|
530
564
|
"You can create a new version of this wrapper using the `get_new_version` method."
|
531
565
|
)
|
532
|
-
td_out = self(tensordict.copy())
|
566
|
+
td_out = self.forward(tensordict.copy(), logits_only=True)
|
533
567
|
|
534
568
|
# Try to get prompt tokens first
|
535
569
|
if self.pad_output:
|
@@ -640,7 +674,7 @@ class LLMWrapperBase(TensorDictModuleBase):
|
|
640
674
|
"get_dist_with_assistant_mask is not implemented for generate=True. "
|
641
675
|
"You can create a new version of this wrapper using the `get_new_version` method."
|
642
676
|
)
|
643
|
-
td_out = self(tensordict.copy())
|
677
|
+
td_out = self.forward(tensordict.copy(), logits_only=True)
|
644
678
|
# Update the tokens key to reflect the tokenized history when querying the log-probs
|
645
679
|
tensordict.update(
|
646
680
|
td_out,
|
@@ -709,7 +743,7 @@ class LLMWrapperBase(TensorDictModuleBase):
|
|
709
743
|
"get_dist_with_attention_mask is not implemented for generate=True. "
|
710
744
|
"You can create a new version of this wrapper using the `get_new_version` method."
|
711
745
|
)
|
712
|
-
td_out = self(tensordict.copy())
|
746
|
+
td_out = self.forward(tensordict.copy(), logits_only=True)
|
713
747
|
if self.pad_output:
|
714
748
|
logits = td_out.get(logits_key)
|
715
749
|
attention_mask = td_out.get(attention_mask_key)
|
@@ -766,7 +800,7 @@ class LLMWrapperBase(TensorDictModuleBase):
|
|
766
800
|
"get_dist_with_custom_mask is not implemented for generate=True. "
|
767
801
|
"You can create a new version of this wrapper using the `get_new_version` method."
|
768
802
|
)
|
769
|
-
td_out = self(tensordict.copy())
|
803
|
+
td_out = self.forward(tensordict.copy(), logits_only=True)
|
770
804
|
if self.pad_output:
|
771
805
|
logits = td_out.get(logits_key)
|
772
806
|
else:
|
@@ -813,8 +847,24 @@ class LLMWrapperBase(TensorDictModuleBase):
|
|
813
847
|
"""
|
814
848
|
return self._get_dist_with_attention_mask(tensordict, **kwargs)
|
815
849
|
|
816
|
-
|
817
|
-
|
850
|
+
def forward(
|
851
|
+
self,
|
852
|
+
tensordict: TensorDictBase,
|
853
|
+
*,
|
854
|
+
tensordict_out: TensorDictBase | None = None,
|
855
|
+
logits_only: bool = False,
|
856
|
+
**kwargs,
|
857
|
+
) -> TensorDictBase: # noqa: D417
|
858
|
+
"""Forward pass for the LLM policy.
|
859
|
+
|
860
|
+
Args:
|
861
|
+
tensordict (TensorDictBase): The input tensordict.
|
862
|
+
|
863
|
+
Keyword Args:
|
864
|
+
tensordict_out (TensorDictBase | None): The output tensordict.
|
865
|
+
logits_only (bool): Whether to return only the logits. Only effective if generate=False. Defaults to `False`.
|
866
|
+
"""
|
867
|
+
raise NotImplementedError
|
818
868
|
|
819
869
|
def _check_padded(self, val: torch.Tensor) -> torch.Tensor:
|
820
870
|
"""Check that a value is a padded tensor."""
|
@@ -13,6 +13,7 @@ from typing import Literal
|
|
13
13
|
import torch
|
14
14
|
from tensordict import (
|
15
15
|
lazy_stack,
|
16
|
+
LazyStackedTensorDict,
|
16
17
|
MetaData,
|
17
18
|
NonTensorStack,
|
18
19
|
set_list_to_stack,
|
@@ -468,19 +469,32 @@ class TransformersWrapper(LLMWrapperBase):
|
|
468
469
|
def forward(
|
469
470
|
self,
|
470
471
|
tensordict: TensorDictBase,
|
472
|
+
*,
|
471
473
|
tensordict_out: TensorDictBase | None = None,
|
474
|
+
logits_only: bool = False,
|
472
475
|
**kwargs,
|
473
476
|
) -> TensorDictBase:
|
477
|
+
tensordict_orig = tensordict
|
474
478
|
if not tensordict.ndim:
|
479
|
+
if tensordict_out is not None:
|
480
|
+
raise ValueError(
|
481
|
+
"tensordict_out must not be provided when tensordict.ndim == 0. If this is needed, "
|
482
|
+
"please submit an issue on github."
|
483
|
+
)
|
475
484
|
# unsqueeze - squeeze the input
|
476
|
-
|
477
|
-
return self(lazy_stack([tensordict])).squeeze(0)
|
478
|
-
except Exception as e:
|
479
|
-
raise RuntimeError(
|
480
|
-
f"Unsqueeze/squeeze failed. Inputs to {type(self).__name__} should ideally be 1 dimensional."
|
481
|
-
) from e
|
485
|
+
return self.forward(lazy_stack([tensordict]), logits_only=logits_only)[0]
|
482
486
|
elif tensordict.ndim > 1:
|
483
|
-
|
487
|
+
if tensordict_out is not None:
|
488
|
+
raise ValueError(
|
489
|
+
"tensordict_out must not be provided when tensordict.ndim > 1. If this is needed, "
|
490
|
+
"please submit an issue on github."
|
491
|
+
)
|
492
|
+
return self.forward(tensordict.reshape(-1), logits_only=logits_only).view(
|
493
|
+
tensordict.shape
|
494
|
+
)
|
495
|
+
|
496
|
+
if not isinstance(tensordict, LazyStackedTensorDict):
|
497
|
+
tensordict = tensordict.to_lazystack(0)
|
484
498
|
|
485
499
|
_source_device = None
|
486
500
|
if self._device:
|
@@ -517,17 +531,23 @@ class TransformersWrapper(LLMWrapperBase):
|
|
517
531
|
if self.generate:
|
518
532
|
out = self._from_transformers_generate_history(tensordict, cfg, out)
|
519
533
|
else:
|
520
|
-
out = self._from_transformers_logprobs_history(
|
534
|
+
out = self._from_transformers_logprobs_history(
|
535
|
+
tensordict, cfg, out, logits_only=logits_only
|
536
|
+
)
|
521
537
|
elif self.input_mode == "text":
|
522
538
|
if self.generate:
|
523
539
|
out = self._from_transformers_generate_text(tensordict, cfg, out)
|
524
540
|
else:
|
525
|
-
out = self._from_transformers_logprobs_text(
|
541
|
+
out = self._from_transformers_logprobs_text(
|
542
|
+
tensordict, cfg, out, logits_only=logits_only
|
543
|
+
)
|
526
544
|
elif self.input_mode == "tokens":
|
527
545
|
if self.generate:
|
528
546
|
out = self._from_transformers_generate_tokens(tensordict, cfg, out)
|
529
547
|
else:
|
530
|
-
out = self._from_transformers_logprobs_tokens(
|
548
|
+
out = self._from_transformers_logprobs_tokens(
|
549
|
+
tensordict, cfg, out, logits_only=logits_only
|
550
|
+
)
|
531
551
|
|
532
552
|
if _source_device:
|
533
553
|
out = out.to(_source_device)
|
@@ -535,7 +555,7 @@ class TransformersWrapper(LLMWrapperBase):
|
|
535
555
|
if tensordict_out is None:
|
536
556
|
if self.inplace is True:
|
537
557
|
# The output is the input
|
538
|
-
tensordict_out =
|
558
|
+
tensordict_out = tensordict_orig
|
539
559
|
elif self.inplace is False:
|
540
560
|
# The output is the new structure
|
541
561
|
tensordict_out = out
|
@@ -690,7 +710,7 @@ class TransformersWrapper(LLMWrapperBase):
|
|
690
710
|
result.set(self.history_key, history_chat)
|
691
711
|
return result
|
692
712
|
|
693
|
-
def _from_transformers_logprobs_history(self, td, cfg, out):
|
713
|
+
def _from_transformers_logprobs_history(self, td, cfg, out, logits_only=False):
|
694
714
|
"""Compute log-probs from history input."""
|
695
715
|
from torchrl.data.llm import History
|
696
716
|
|
@@ -731,7 +751,9 @@ class TransformersWrapper(LLMWrapperBase):
|
|
731
751
|
raise ValueError(
|
732
752
|
f"Expected TensorDictBase for history input, got {type(response_tokens)}"
|
733
753
|
)
|
734
|
-
result = self._logprobs_from_history_tokens(
|
754
|
+
result = self._logprobs_from_history_tokens(
|
755
|
+
response_tokens, cfg, out, logits_only=logits_only
|
756
|
+
)
|
735
757
|
text_result = Text._from_tensordict(result.empty())
|
736
758
|
result.set(self.text_key, text_result)
|
737
759
|
result[self.text_key, "full"] = text_full
|
@@ -952,7 +974,9 @@ class TransformersWrapper(LLMWrapperBase):
|
|
952
974
|
result = result.to(cast)
|
953
975
|
return result
|
954
976
|
|
955
|
-
def _logprobs_from_history_tokens(
|
977
|
+
def _logprobs_from_history_tokens(
|
978
|
+
self, response_tokens, cfg, out, logits_only=False
|
979
|
+
):
|
956
980
|
"""Compute log-probs from history tokens."""
|
957
981
|
pad_val = self.tokenizer.pad_token_id
|
958
982
|
|
@@ -996,6 +1020,7 @@ class TransformersWrapper(LLMWrapperBase):
|
|
996
1020
|
tokens_full_padded,
|
997
1021
|
attention_mask_full_padded,
|
998
1022
|
pad_val,
|
1023
|
+
logits_only=logits_only,
|
999
1024
|
)
|
1000
1025
|
|
1001
1026
|
# Build output TensorClass objects
|
@@ -1051,19 +1076,20 @@ class TransformersWrapper(LLMWrapperBase):
|
|
1051
1076
|
tokens_obj.padded = MetaData(self.pad_output)
|
1052
1077
|
out.set(self.tokens_key, tokens_obj)
|
1053
1078
|
|
1054
|
-
|
1055
|
-
|
1056
|
-
|
1057
|
-
if self.pad_output:
|
1058
|
-
log_probs_obj.full = log_probs_full_padded
|
1059
|
-
else:
|
1060
|
-
log_probs_full_unpadded = _unpad_tensors(
|
1061
|
-
log_probs_full_padded, attention_mask_full_padded, as_nested=False
|
1079
|
+
if not logits_only:
|
1080
|
+
log_probs_obj = LogProbs._from_tensordict(
|
1081
|
+
TensorDict(batch_size=out.batch_size).to_lazystack(0)
|
1062
1082
|
)
|
1063
|
-
|
1064
|
-
|
1065
|
-
|
1066
|
-
|
1083
|
+
if self.pad_output:
|
1084
|
+
log_probs_obj.full = log_probs_full_padded
|
1085
|
+
else:
|
1086
|
+
log_probs_full_unpadded = _unpad_tensors(
|
1087
|
+
log_probs_full_padded, attention_mask_full_padded, as_nested=False
|
1088
|
+
)
|
1089
|
+
log_probs_obj.full = log_probs_full_unpadded
|
1090
|
+
log_probs_obj.response = None
|
1091
|
+
log_probs_obj.padded = MetaData(self.pad_output)
|
1092
|
+
out.set(self.log_probs_key, log_probs_obj)
|
1067
1093
|
|
1068
1094
|
# Add logits to output if we're in a get_dist call
|
1069
1095
|
if self._in_get_dist_call:
|
@@ -1095,7 +1121,7 @@ class TransformersWrapper(LLMWrapperBase):
|
|
1095
1121
|
raise ValueError(f"Expected list of text for text input, got {type(text)}")
|
1096
1122
|
return self._generate_from_text(text, cfg, out)
|
1097
1123
|
|
1098
|
-
def _from_transformers_logprobs_text(self, td, cfg, out):
|
1124
|
+
def _from_transformers_logprobs_text(self, td, cfg, out, logits_only=False):
|
1099
1125
|
"""Compute log-probs from text input."""
|
1100
1126
|
# Validate input
|
1101
1127
|
if self.input_key not in td:
|
@@ -1168,6 +1194,7 @@ class TransformersWrapper(LLMWrapperBase):
|
|
1168
1194
|
input_ids_full_padded,
|
1169
1195
|
attention_mask_full_padded,
|
1170
1196
|
self.tokenizer.pad_token_id,
|
1197
|
+
logits_only=logits_only,
|
1171
1198
|
)
|
1172
1199
|
|
1173
1200
|
# Build output TensorClass objects
|
@@ -1212,19 +1239,20 @@ class TransformersWrapper(LLMWrapperBase):
|
|
1212
1239
|
masks_obj.padded = MetaData(self.pad_output)
|
1213
1240
|
out.set(self.masks_key, masks_obj)
|
1214
1241
|
|
1215
|
-
|
1216
|
-
|
1217
|
-
|
1218
|
-
if self.pad_output:
|
1219
|
-
log_probs_obj.full = log_probs_full_padded
|
1220
|
-
else:
|
1221
|
-
log_probs_full_unpadded = _unpad_tensors(
|
1222
|
-
log_probs_full_padded, attention_mask_full_padded, as_nested=False
|
1242
|
+
if not logits_only:
|
1243
|
+
log_probs_obj = LogProbs._from_tensordict(
|
1244
|
+
TensorDict(batch_size=out.batch_size).to_lazystack(0)
|
1223
1245
|
)
|
1224
|
-
|
1225
|
-
|
1226
|
-
|
1227
|
-
|
1246
|
+
if self.pad_output:
|
1247
|
+
log_probs_obj.full = log_probs_full_padded
|
1248
|
+
else:
|
1249
|
+
log_probs_full_unpadded = _unpad_tensors(
|
1250
|
+
log_probs_full_padded, attention_mask_full_padded, as_nested=False
|
1251
|
+
)
|
1252
|
+
log_probs_obj.full = log_probs_full_unpadded
|
1253
|
+
log_probs_obj.response = None
|
1254
|
+
log_probs_obj.padded = MetaData(self.pad_output)
|
1255
|
+
out.set(self.log_probs_key, log_probs_obj)
|
1228
1256
|
|
1229
1257
|
# Add logits to output if we're in a get_dist call
|
1230
1258
|
if self._in_get_dist_call:
|
@@ -1416,7 +1444,11 @@ class TransformersWrapper(LLMWrapperBase):
|
|
1416
1444
|
return out
|
1417
1445
|
|
1418
1446
|
def _from_transformers_logprobs_tokens(
|
1419
|
-
self,
|
1447
|
+
self,
|
1448
|
+
td: TensorDictBase,
|
1449
|
+
cfg: dict | None,
|
1450
|
+
out: TensorDictBase,
|
1451
|
+
logits_only=False,
|
1420
1452
|
) -> TensorDictBase:
|
1421
1453
|
"""Compute log-probs from tokens input."""
|
1422
1454
|
# Validate input
|
@@ -1470,6 +1502,7 @@ class TransformersWrapper(LLMWrapperBase):
|
|
1470
1502
|
input_ids_full_padded,
|
1471
1503
|
attention_mask_full_padded,
|
1472
1504
|
self.tokenizer.pad_token_id,
|
1505
|
+
logits_only=logits_only,
|
1473
1506
|
)
|
1474
1507
|
|
1475
1508
|
# Build output TensorClass objects
|
@@ -1514,19 +1547,20 @@ class TransformersWrapper(LLMWrapperBase):
|
|
1514
1547
|
masks_obj.padded = MetaData(self.pad_output)
|
1515
1548
|
out.set(self.masks_key, masks_obj)
|
1516
1549
|
|
1517
|
-
|
1518
|
-
|
1519
|
-
|
1520
|
-
if self.pad_output:
|
1521
|
-
log_probs_obj.full = log_probs_full_padded
|
1522
|
-
else:
|
1523
|
-
log_probs_full_unpadded = _unpad_tensors(
|
1524
|
-
log_probs_full_padded, attention_mask_full_padded, as_nested=False
|
1550
|
+
if not logits_only:
|
1551
|
+
log_probs_obj = LogProbs._from_tensordict(
|
1552
|
+
TensorDict(batch_size=out.batch_size).to_lazystack(0)
|
1525
1553
|
)
|
1526
|
-
|
1527
|
-
|
1528
|
-
|
1529
|
-
|
1554
|
+
if self.pad_output:
|
1555
|
+
log_probs_obj.full = log_probs_full_padded
|
1556
|
+
else:
|
1557
|
+
log_probs_full_unpadded = _unpad_tensors(
|
1558
|
+
log_probs_full_padded, attention_mask_full_padded, as_nested=False
|
1559
|
+
)
|
1560
|
+
log_probs_obj.full = log_probs_full_unpadded
|
1561
|
+
log_probs_obj.response = None
|
1562
|
+
log_probs_obj.padded = MetaData(self.pad_output)
|
1563
|
+
out.set(self.log_probs_key, log_probs_obj)
|
1530
1564
|
|
1531
1565
|
# Add logits to output if we're in a get_dist call
|
1532
1566
|
if self._in_get_dist_call:
|
@@ -1567,7 +1601,7 @@ class TransformersWrapper(LLMWrapperBase):
|
|
1567
1601
|
return log_probs, logits
|
1568
1602
|
|
1569
1603
|
def _compute_log_probs_from_model_output(
|
1570
|
-
self, model_output, input_ids, attention_mask, pad_val
|
1604
|
+
self, model_output, input_ids, attention_mask, pad_val, logits_only=False
|
1571
1605
|
):
|
1572
1606
|
"""Compute log-probs from model output without modifying original tensors.
|
1573
1607
|
|
@@ -1576,6 +1610,7 @@ class TransformersWrapper(LLMWrapperBase):
|
|
1576
1610
|
input_ids: Original input token ids
|
1577
1611
|
attention_mask: Original attention mask
|
1578
1612
|
pad_val: Padding token value to ignore in loss computation
|
1613
|
+
logits_only: Whether to return only the logits.
|
1579
1614
|
|
1580
1615
|
Returns:
|
1581
1616
|
tuple: (log_probs, shifted_logits) where log_probs are the computed log probabilities
|
@@ -1600,6 +1635,8 @@ class TransformersWrapper(LLMWrapperBase):
|
|
1600
1635
|
raise ValueError(
|
1601
1636
|
f"The logits shape {shifted_logits.shape} does not match the input ids shape {shifted_input_ids.shape}"
|
1602
1637
|
)
|
1638
|
+
if logits_only:
|
1639
|
+
return None, shifted_logits
|
1603
1640
|
|
1604
1641
|
# Compute log-probs
|
1605
1642
|
td = TensorDict(
|
@@ -11,6 +11,7 @@ from typing import Any, Literal
|
|
11
11
|
import torch
|
12
12
|
from tensordict import (
|
13
13
|
lazy_stack,
|
14
|
+
LazyStackedTensorDict,
|
14
15
|
MetaData,
|
15
16
|
NonTensorStack,
|
16
17
|
set_list_to_stack,
|
@@ -500,19 +501,32 @@ class vLLMWrapper(LLMWrapperBase):
|
|
500
501
|
def forward(
|
501
502
|
self,
|
502
503
|
tensordict: TensorDictBase,
|
504
|
+
*,
|
503
505
|
tensordict_out: TensorDictBase | None = None,
|
506
|
+
logits_only: bool = False,
|
504
507
|
**kwargs,
|
505
508
|
) -> TensorDictBase:
|
509
|
+
tensordict_orig = tensordict
|
506
510
|
if not tensordict.ndim:
|
511
|
+
if tensordict_out is not None:
|
512
|
+
raise ValueError(
|
513
|
+
"tensordict_out must not be provided when tensordict.ndim == 0. If this is needed, "
|
514
|
+
"please submit an issue on github."
|
515
|
+
)
|
507
516
|
# unsqueeze - squeeze the input
|
508
|
-
|
509
|
-
return self(lazy_stack([tensordict])).squeeze(0)
|
510
|
-
except Exception as e:
|
511
|
-
raise RuntimeError(
|
512
|
-
f"Unsqueeze/squeeze failed. Inputs to {type(self).__name__} should ideally be 1 dimensional."
|
513
|
-
) from e
|
517
|
+
return self.forward(lazy_stack([tensordict]), logits_only=logits_only)[0]
|
514
518
|
elif tensordict.ndim > 1:
|
515
|
-
|
519
|
+
if tensordict_out is not None:
|
520
|
+
raise ValueError(
|
521
|
+
"tensordict_out must not be provided when tensordict.ndim > 1. If this is needed, "
|
522
|
+
"please submit an issue on github."
|
523
|
+
)
|
524
|
+
return self.forward(tensordict.reshape(-1), logits_only=logits_only).view(
|
525
|
+
tensordict.shape
|
526
|
+
)
|
527
|
+
|
528
|
+
if not isinstance(tensordict, LazyStackedTensorDict):
|
529
|
+
tensordict = tensordict.to_lazystack(0)
|
516
530
|
|
517
531
|
_source_device = None
|
518
532
|
if self._device:
|
@@ -567,7 +581,7 @@ class vLLMWrapper(LLMWrapperBase):
|
|
567
581
|
if tensordict_out is None:
|
568
582
|
if self.inplace is True:
|
569
583
|
# The output is the input
|
570
|
-
tensordict_out =
|
584
|
+
tensordict_out = tensordict_orig
|
571
585
|
elif self.inplace is False:
|
572
586
|
# The output is the new structure
|
573
587
|
tensordict_out = out
|
@@ -1242,12 +1256,14 @@ class vLLMWrapper(LLMWrapperBase):
|
|
1242
1256
|
|
1243
1257
|
generate_kwargs = {"sampling_params": sampling_params}
|
1244
1258
|
args = ()
|
1259
|
+
empirical_attention_mask = None
|
1245
1260
|
|
1246
1261
|
if tokens_prompt_unpadded is None:
|
1247
1262
|
# TODO: To be on the safe side, we may do this even in the unpadded case since we're not sure
|
1248
1263
|
# the user passed an unpadded tensor in the first place.
|
1264
|
+
empirical_attention_mask = tokens_prompt_padded != self.padding_value
|
1249
1265
|
tokens_prompt_list = self._to_list(
|
1250
|
-
tokens_prompt_padded,
|
1266
|
+
tokens_prompt_padded, empirical_attention_mask
|
1251
1267
|
)
|
1252
1268
|
else:
|
1253
1269
|
tokens_prompt_list = self._to_list(tokens_prompt_unpadded, None)
|
@@ -1365,6 +1381,22 @@ class vLLMWrapper(LLMWrapperBase):
|
|
1365
1381
|
padding_value=self.padding_value,
|
1366
1382
|
padding_side="right",
|
1367
1383
|
)
|
1384
|
+
if (
|
1385
|
+
prompt_logprobs_padded.shape[-1]
|
1386
|
+
!= tokens_prompt_padded.shape[-1]
|
1387
|
+
):
|
1388
|
+
tshape = tokens_prompt_padded.shape
|
1389
|
+
oshape = prompt_logprobs_padded.shape
|
1390
|
+
# it could be that the input was padded already - padding again then
|
1391
|
+
prompt_logprobs_padded = torch.cat(
|
1392
|
+
[
|
1393
|
+
prompt_logprobs_padded.new_zeros(
|
1394
|
+
tshape[:-1] + (tshape[-1] - oshape[-1],)
|
1395
|
+
),
|
1396
|
+
prompt_logprobs_padded,
|
1397
|
+
],
|
1398
|
+
-1,
|
1399
|
+
)
|
1368
1400
|
else:
|
1369
1401
|
prompt_logprobs_list = request_output_tc.get(
|
1370
1402
|
"prompt_logprobs",
|
@@ -1490,26 +1522,21 @@ class vLLMWrapper(LLMWrapperBase):
|
|
1490
1522
|
|
1491
1523
|
request_output_tc = _RequestOutput_tc.from_request_output(tokens_out_stuct)
|
1492
1524
|
|
1525
|
+
# For unpadded case, extract from each sequence
|
1526
|
+
log_probs_full_unpadded = request_output_tc.get("prompt_logprobs", as_list=True)
|
1527
|
+
|
1493
1528
|
# Extract log-probs from prompt_logprobs
|
1494
1529
|
if self.pad_output:
|
1495
1530
|
# For padded case, use all prompt_logprobs
|
1496
|
-
|
1497
|
-
|
1498
|
-
|
1499
|
-
|
1500
|
-
padding_side="left",
|
1531
|
+
if attention_mask_full_padded is not None:
|
1532
|
+
attention_mask_full_padded = tokens_full_padded != self.padding_value
|
1533
|
+
log_probs_full_padded = torch.zeros_like(
|
1534
|
+
tokens_full_padded, dtype=torch.get_default_dtype()
|
1501
1535
|
)
|
1502
|
-
|
1503
|
-
|
1504
|
-
attention_mask_full_padded = tokens_full_padded != self.padding_value
|
1505
|
-
log_probs_full_padded = torch.where(
|
1506
|
-
attention_mask_full_padded, log_probs_full_padded, 0.0
|
1536
|
+
log_probs_full_padded[attention_mask_full_padded] = torch.cat(
|
1537
|
+
log_probs_full_unpadded, -1
|
1507
1538
|
)
|
1508
1539
|
else:
|
1509
|
-
# For unpadded case, extract from each sequence
|
1510
|
-
log_probs_full_unpadded = request_output_tc.get(
|
1511
|
-
"prompt_logprobs", as_list=True
|
1512
|
-
)
|
1513
1540
|
self._check_not_padded(log_probs_full_unpadded)
|
1514
1541
|
|
1515
1542
|
assistant_mask_full_padded = None
|
torchrl/objectives/a2c.py
CHANGED
@@ -70,7 +70,7 @@ class A2CLoss(LossModule):
|
|
70
70
|
samples will be used to compute this estimate.
|
71
71
|
Defaults to ``1``.
|
72
72
|
entropy_coeff (:obj:`float`): the weight of the entropy loss. Defaults to `0.01``.
|
73
|
-
|
73
|
+
critic_coeff (:obj:`float`): the weight of the critic loss. Defaults to ``1.0``. If ``None``, the critic
|
74
74
|
loss won't be included and the in-keys will miss the critic inputs.
|
75
75
|
loss_critic_type (str): loss function for the value discrepancy.
|
76
76
|
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
|
@@ -156,7 +156,7 @@ class A2CLoss(LossModule):
|
|
156
156
|
the expected keyword arguments are:
|
157
157
|
``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and critic.
|
158
158
|
The return value is a tuple of tensors in the following order:
|
159
|
-
``["loss_objective"]`` + ``["loss_critic"]`` if
|
159
|
+
``["loss_objective"]`` + ``["loss_critic"]`` if critic_coeff is not None + ``["entropy", "loss_entropy"]`` if entropy_bonus is True and critic_coeff is not None
|
160
160
|
|
161
161
|
Examples:
|
162
162
|
>>> import torch
|
@@ -277,8 +277,8 @@ class A2CLoss(LossModule):
|
|
277
277
|
*,
|
278
278
|
entropy_bonus: bool = True,
|
279
279
|
samples_mc_entropy: int = 1,
|
280
|
-
entropy_coeff: float =
|
281
|
-
|
280
|
+
entropy_coeff: float | None = None,
|
281
|
+
critic_coeff: float = 1.0,
|
282
282
|
loss_critic_type: str = "smooth_l1",
|
283
283
|
gamma: float | None = None,
|
284
284
|
separate_losses: bool = False,
|
@@ -291,13 +291,32 @@ class A2CLoss(LossModule):
|
|
291
291
|
clip_value: float | None = None,
|
292
292
|
**kwargs,
|
293
293
|
):
|
294
|
+
# Handle deprecated entropy_coef argument
|
294
295
|
if "entropy_coef" in kwargs:
|
296
|
+
if entropy_coeff is not None: # Check if entropy_coeff was explicitly set
|
297
|
+
raise ValueError(
|
298
|
+
"Cannot specify both 'entropy_coef' and 'entropy_coeff'"
|
299
|
+
)
|
295
300
|
warnings.warn(
|
296
301
|
"'entropy_coef' is deprecated and will be removed in torchrl v0.11. Please use 'entropy_coeff' instead.",
|
297
302
|
DeprecationWarning,
|
298
303
|
)
|
299
304
|
entropy_coeff = kwargs.pop("entropy_coef")
|
300
305
|
|
306
|
+
# Set default value if None
|
307
|
+
if entropy_coeff is None:
|
308
|
+
entropy_coeff = 0.01
|
309
|
+
|
310
|
+
# Handle deprecated critic_coef argument
|
311
|
+
if "critic_coef" in kwargs:
|
312
|
+
if critic_coeff != 1.0: # Check if critic_coeff was explicitly set
|
313
|
+
raise ValueError("Cannot specify both 'critic_coef' and 'critic_coeff'")
|
314
|
+
warnings.warn(
|
315
|
+
"'critic_coef' is deprecated and will be removed in torchrl v0.11. Please use 'critic_coeff' instead.",
|
316
|
+
DeprecationWarning,
|
317
|
+
)
|
318
|
+
critic_coeff = kwargs.pop("critic_coef")
|
319
|
+
|
301
320
|
if actor is not None:
|
302
321
|
actor_network = actor
|
303
322
|
del actor
|
@@ -349,12 +368,12 @@ class A2CLoss(LossModule):
|
|
349
368
|
self.register_buffer(
|
350
369
|
"entropy_coeff", torch.as_tensor(entropy_coeff, device=device)
|
351
370
|
)
|
352
|
-
if
|
371
|
+
if critic_coeff is not None:
|
353
372
|
self.register_buffer(
|
354
|
-
"
|
373
|
+
"critic_coeff", torch.as_tensor(critic_coeff, device=device)
|
355
374
|
)
|
356
375
|
else:
|
357
|
-
self.
|
376
|
+
self.critic_coeff = None
|
358
377
|
|
359
378
|
if gamma is not None:
|
360
379
|
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
|
@@ -399,7 +418,7 @@ class A2CLoss(LossModule):
|
|
399
418
|
*self.actor_network.in_keys,
|
400
419
|
*[("next", key) for key in self.actor_network.in_keys],
|
401
420
|
]
|
402
|
-
if self.
|
421
|
+
if self.critic_coeff is not None:
|
403
422
|
keys.extend(self.critic_network.in_keys)
|
404
423
|
return list(set(keys))
|
405
424
|
|
@@ -407,7 +426,7 @@ class A2CLoss(LossModule):
|
|
407
426
|
def out_keys(self):
|
408
427
|
if self._out_keys is None:
|
409
428
|
outs = ["loss_objective"]
|
410
|
-
if self.
|
429
|
+
if self.critic_coeff is not None:
|
411
430
|
outs.append("loss_critic")
|
412
431
|
if self.entropy_bonus:
|
413
432
|
outs.append("entropy")
|
@@ -478,7 +497,7 @@ class A2CLoss(LossModule):
|
|
478
497
|
return log_prob, dist
|
479
498
|
|
480
499
|
def loss_critic(self, tensordict: TensorDictBase) -> tuple[torch.Tensor, float]:
|
481
|
-
"""Returns the loss value of the critic, multiplied by ``
|
500
|
+
"""Returns the loss value of the critic, multiplied by ``critic_coeff`` if it is not ``None``.
|
482
501
|
|
483
502
|
Returns the loss and the clip-fraction.
|
484
503
|
|
@@ -539,8 +558,8 @@ class A2CLoss(LossModule):
|
|
539
558
|
"target_actor_network_params",
|
540
559
|
"target_critic_network_params",
|
541
560
|
)
|
542
|
-
if self.
|
543
|
-
return self.
|
561
|
+
if self.critic_coeff is not None:
|
562
|
+
return self.critic_coeff * loss_value, clip_fraction
|
544
563
|
return loss_value, clip_fraction
|
545
564
|
|
546
565
|
@property
|
@@ -568,7 +587,7 @@ class A2CLoss(LossModule):
|
|
568
587
|
entropy = self.get_entropy_bonus(dist)
|
569
588
|
td_out.set("entropy", entropy.detach().mean()) # for logging
|
570
589
|
td_out.set("loss_entropy", -self.entropy_coeff * entropy)
|
571
|
-
if self.
|
590
|
+
if self.critic_coeff is not None:
|
572
591
|
loss_critic, value_clip_fraction = self.loss_critic(tensordict)
|
573
592
|
td_out.set("loss_critic", loss_critic)
|
574
593
|
if value_clip_fraction is not None:
|
torchrl/objectives/ppo.py
CHANGED
@@ -102,13 +102,13 @@ class PPOLoss(LossModule):
|
|
102
102
|
Defaults to ``1``.
|
103
103
|
entropy_coeff: scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
|
104
104
|
* **Scalar**: one value applied to the summed entropy of every action head.
|
105
|
-
* **Mapping** ``{head_name:
|
105
|
+
* **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
|
106
106
|
Defaults to ``0.01``.
|
107
107
|
log_explained_variance (bool, optional): if ``True``, the explained variance of the critic
|
108
108
|
predictions w.r.t. value targets will be computed and logged as ``"explained_variance"``.
|
109
109
|
This can help monitor critic quality during training. Best possible score is 1.0, lower values are worse. Defaults to ``True``.
|
110
|
-
|
111
|
-
loss. Defaults to ``1.0``. Set ``
|
110
|
+
critic_coeff (scalar, optional): critic loss multiplier when computing the total
|
111
|
+
loss. Defaults to ``1.0``. Set ``critic_coeff`` to ``None`` to exclude the value
|
112
112
|
loss from the forward outputs.
|
113
113
|
loss_critic_type (str, optional): loss function for the value discrepancy.
|
114
114
|
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
|
@@ -239,7 +239,7 @@ class PPOLoss(LossModule):
|
|
239
239
|
the expected keyword arguments are:
|
240
240
|
``["action", "sample_log_prob", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and value network.
|
241
241
|
The return value is a tuple of tensors in the following order:
|
242
|
-
``["loss_objective"]`` + ``["entropy", "loss_entropy"]`` if entropy_bonus is set + ``"loss_critic"`` if
|
242
|
+
``["loss_objective"]`` + ``["entropy", "loss_entropy"]`` if entropy_bonus is set + ``"loss_critic"`` if critic_coeff is not ``None``.
|
243
243
|
The output keys can also be filtered using :meth:`PPOLoss.select_out_keys` method.
|
244
244
|
|
245
245
|
Examples:
|
@@ -351,9 +351,9 @@ class PPOLoss(LossModule):
|
|
351
351
|
*,
|
352
352
|
entropy_bonus: bool = True,
|
353
353
|
samples_mc_entropy: int = 1,
|
354
|
-
entropy_coeff: float | Mapping[str, float] =
|
354
|
+
entropy_coeff: float | Mapping[str, float] | None = None,
|
355
355
|
log_explained_variance: bool = True,
|
356
|
-
|
356
|
+
critic_coeff: float | None = None,
|
357
357
|
loss_critic_type: str = "smooth_l1",
|
358
358
|
normalize_advantage: bool = False,
|
359
359
|
normalize_advantage_exclude_dims: tuple[int] = (),
|
@@ -377,13 +377,23 @@ class PPOLoss(LossModule):
|
|
377
377
|
critic_network = critic
|
378
378
|
del critic
|
379
379
|
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
380
|
+
# Handle deprecated critic_coef argument
|
381
|
+
if "critic_coef" in kwargs:
|
382
|
+
if critic_coeff is not None:
|
383
|
+
raise ValueError("Cannot specify both 'critic_coef' and 'critic_coeff'")
|
384
|
+
warnings.warn(
|
385
|
+
"'critic_coef' is deprecated and will be removed in torchrl v0.11. Please use 'critic_coeff' instead.",
|
386
|
+
DeprecationWarning,
|
387
|
+
)
|
388
|
+
critic_coeff = kwargs.pop("critic_coef")
|
389
|
+
|
390
|
+
if critic_coeff is None and critic_network is not None:
|
391
|
+
critic_coeff = 1.0
|
392
|
+
elif critic_coeff in (None, 0) and critic_network is not None:
|
393
|
+
critic_coeff = None
|
384
394
|
|
385
395
|
if actor_network is None or (
|
386
|
-
critic_network is None and
|
396
|
+
critic_network is None and critic_coeff not in (None, 0.0)
|
387
397
|
):
|
388
398
|
raise TypeError(
|
389
399
|
"Missing positional arguments actor_network or critic_network."
|
@@ -431,13 +441,21 @@ class PPOLoss(LossModule):
|
|
431
441
|
torch, "get_default_device", lambda: torch.device("cpu")
|
432
442
|
)()
|
433
443
|
|
434
|
-
# Handle deprecated
|
435
|
-
if "
|
444
|
+
# Handle deprecated entropy_coef argument
|
445
|
+
if "entropy_coef" in kwargs:
|
446
|
+
if entropy_coeff is not None: # Check if entropy_coeff was explicitly set
|
447
|
+
raise ValueError(
|
448
|
+
"Cannot specify both 'entropy_coef' and 'entropy_coeff'"
|
449
|
+
)
|
436
450
|
warnings.warn(
|
437
|
-
"'
|
451
|
+
"'entropy_coef' is deprecated and will be removed in torchrl v0.11. Please use 'entropy_coeff' instead.",
|
438
452
|
DeprecationWarning,
|
439
453
|
)
|
440
|
-
entropy_coeff = kwargs.pop("
|
454
|
+
entropy_coeff = kwargs.pop("entropy_coef")
|
455
|
+
|
456
|
+
# Set default value if None
|
457
|
+
if entropy_coeff is None:
|
458
|
+
entropy_coeff = 0.01
|
441
459
|
|
442
460
|
if isinstance(entropy_coeff, Mapping):
|
443
461
|
# Store the mapping for per-head coefficients
|
@@ -457,13 +475,13 @@ class PPOLoss(LossModule):
|
|
457
475
|
self._entropy_coeff_map = None
|
458
476
|
else:
|
459
477
|
raise TypeError("entropy_coeff must be a float or a Mapping[str, float]")
|
460
|
-
if
|
478
|
+
if critic_coeff is not None:
|
461
479
|
self.register_buffer(
|
462
|
-
"
|
480
|
+
"critic_coeff", torch.tensor(critic_coeff, device=device)
|
463
481
|
)
|
464
482
|
else:
|
465
|
-
self.
|
466
|
-
self._has_critic = bool(self.
|
483
|
+
self.critic_coeff = None
|
484
|
+
self._has_critic = bool(self.critic_coeff is not None and self.critic_coeff > 0)
|
467
485
|
self.loss_critic_type = loss_critic_type
|
468
486
|
self.normalize_advantage = normalize_advantage
|
469
487
|
self.normalize_advantage_exclude_dims = normalize_advantage_exclude_dims
|
@@ -692,7 +710,7 @@ class PPOLoss(LossModule):
|
|
692
710
|
def loss_critic(
|
693
711
|
self, tensordict: TensorDictBase
|
694
712
|
) -> tuple[torch.Tensor | TensorDict, ...]:
|
695
|
-
"""Returns the critic loss multiplied by ``
|
713
|
+
"""Returns the critic loss multiplied by ``critic_coeff``, if it is not ``None``."""
|
696
714
|
# TODO: if the advantage is gathered by forward, this introduces an
|
697
715
|
# overhead that we could easily reduce.
|
698
716
|
if self.separate_losses:
|
@@ -766,7 +784,7 @@ class PPOLoss(LossModule):
|
|
766
784
|
"target_critic_network_params",
|
767
785
|
)
|
768
786
|
if self._has_critic:
|
769
|
-
return self.
|
787
|
+
return self.critic_coeff * loss_value, clip_fraction, explained_variance
|
770
788
|
return loss_value, clip_fraction, explained_variance
|
771
789
|
|
772
790
|
@property
|
@@ -954,10 +972,10 @@ class ClipPPOLoss(PPOLoss):
|
|
954
972
|
Defaults to ``1``.
|
955
973
|
entropy_coeff: (scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
|
956
974
|
* **Scalar**: one value applied to the summed entropy of every action head.
|
957
|
-
* **Mapping** ``{head_name:
|
975
|
+
* **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
|
958
976
|
Defaults to ``0.01``.
|
959
|
-
|
960
|
-
loss. Defaults to ``1.0``. Set ``
|
977
|
+
critic_coeff (scalar, optional): critic loss multiplier when computing the total
|
978
|
+
loss. Defaults to ``1.0``. Set ``critic_coeff`` to ``None`` to exclude the value
|
961
979
|
loss from the forward outputs.
|
962
980
|
loss_critic_type (str, optional): loss function for the value discrepancy.
|
963
981
|
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
|
@@ -1057,8 +1075,8 @@ class ClipPPOLoss(PPOLoss):
|
|
1057
1075
|
clip_epsilon: float = 0.2,
|
1058
1076
|
entropy_bonus: bool = True,
|
1059
1077
|
samples_mc_entropy: int = 1,
|
1060
|
-
entropy_coeff: float | Mapping[str, float] =
|
1061
|
-
|
1078
|
+
entropy_coeff: float | Mapping[str, float] | None = None,
|
1079
|
+
critic_coeff: float | None = None,
|
1062
1080
|
loss_critic_type: str = "smooth_l1",
|
1063
1081
|
normalize_advantage: bool = False,
|
1064
1082
|
normalize_advantage_exclude_dims: tuple[int] = (),
|
@@ -1079,7 +1097,7 @@ class ClipPPOLoss(PPOLoss):
|
|
1079
1097
|
entropy_bonus=entropy_bonus,
|
1080
1098
|
samples_mc_entropy=samples_mc_entropy,
|
1081
1099
|
entropy_coeff=entropy_coeff,
|
1082
|
-
|
1100
|
+
critic_coeff=critic_coeff,
|
1083
1101
|
loss_critic_type=loss_critic_type,
|
1084
1102
|
normalize_advantage=normalize_advantage,
|
1085
1103
|
normalize_advantage_exclude_dims=normalize_advantage_exclude_dims,
|
@@ -1247,9 +1265,9 @@ class KLPENPPOLoss(PPOLoss):
|
|
1247
1265
|
Defaults to ``1``.
|
1248
1266
|
entropy_coeff: scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
|
1249
1267
|
* **Scalar**: one value applied to the summed entropy of every action head.
|
1250
|
-
* **Mapping** ``{head_name:
|
1268
|
+
* **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
|
1251
1269
|
Defaults to ``0.01``.
|
1252
|
-
|
1270
|
+
critic_coeff (scalar, optional): critic loss multiplier when computing the total
|
1253
1271
|
loss. Defaults to ``1.0``.
|
1254
1272
|
loss_critic_type (str, optional): loss function for the value discrepancy.
|
1255
1273
|
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
|
@@ -1351,8 +1369,8 @@ class KLPENPPOLoss(PPOLoss):
|
|
1351
1369
|
samples_mc_kl: int = 1,
|
1352
1370
|
entropy_bonus: bool = True,
|
1353
1371
|
samples_mc_entropy: int = 1,
|
1354
|
-
entropy_coeff: float | Mapping[str, float] =
|
1355
|
-
|
1372
|
+
entropy_coeff: float | Mapping[str, float] | None = None,
|
1373
|
+
critic_coeff: float | None = None,
|
1356
1374
|
loss_critic_type: str = "smooth_l1",
|
1357
1375
|
normalize_advantage: bool = False,
|
1358
1376
|
normalize_advantage_exclude_dims: tuple[int] = (),
|
@@ -1369,7 +1387,7 @@ class KLPENPPOLoss(PPOLoss):
|
|
1369
1387
|
entropy_bonus=entropy_bonus,
|
1370
1388
|
samples_mc_entropy=samples_mc_entropy,
|
1371
1389
|
entropy_coeff=entropy_coeff,
|
1372
|
-
|
1390
|
+
critic_coeff=critic_coeff,
|
1373
1391
|
loss_critic_type=loss_critic_type,
|
1374
1392
|
normalize_advantage=normalize_advantage,
|
1375
1393
|
normalize_advantage_exclude_dims=normalize_advantage_exclude_dims,
|
@@ -86,7 +86,7 @@ class A2CLossConfig:
|
|
86
86
|
# Decay factor for return computation. Default=0.99.
|
87
87
|
entropy_coeff: float = 1e-3
|
88
88
|
# Entropy factor for the A2C loss
|
89
|
-
|
89
|
+
critic_coeff: float = 1.0
|
90
90
|
# Critic factor for the A2C loss
|
91
91
|
critic_loss_function: str = "smooth_l1"
|
92
92
|
# loss function for the value network. Either one of l1, l2 or smooth_l1 (default).
|
@@ -112,7 +112,7 @@ class PPOLossConfig:
|
|
112
112
|
# Number of samples to use for a Monte-Carlo estimate if the policy distribution has not closed formula.
|
113
113
|
loss_function: str = "smooth_l1"
|
114
114
|
# loss function for the value network. Either one of l1, l2 or smooth_l1 (default).
|
115
|
-
|
115
|
+
critic_coeff: float = 1.0
|
116
116
|
# Critic loss multiplier when computing the total loss.
|
117
117
|
|
118
118
|
# ClipPPOLoss parameters:
|
torchrl/version.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1
|
-
__version__ = '2025.7.
|
2
|
-
git_version = '
|
1
|
+
__version__ = '2025.7.18'
|
2
|
+
git_version = '4001d9cb73cea4498b0fdfe420effc58a5a336be'
|
@@ -3,11 +3,11 @@ build_tools/setup_helpers/__init__.py,sha256=7l8TvVqxKezgzKCLuRv20mvGLloprFVZYm8
|
|
3
3
|
build_tools/setup_helpers/extension.py,sha256=4-PDLr-pw40bJnd9SfxnTaSjUyuXU_Tg8yOg69Kl0o4,5914
|
4
4
|
torchrl/__init__.py,sha256=mhDBx2UIuBKc0gmi8dVNHokQ6tCbIovruZmyAxcSsy8,2938
|
5
5
|
torchrl/_extension.py,sha256=z7wQ8i1iYWYcnygq_j0nq9sT-koY13tfHhTLNbMk17Q,2353
|
6
|
-
torchrl/_torchrl.cpython-313-darwin.so,sha256=
|
6
|
+
torchrl/_torchrl.cpython-313-darwin.so,sha256=xjRDwIjMdFaN3InficNo2b0v9ju7g-ONAgZ0DcGbk38,1692464
|
7
7
|
torchrl/_utils.py,sha256=Cw5EG6x5oSZF1iE3YCs1a32VUKp0rTXIs2u67q9zKUI,41078
|
8
|
-
torchrl/version.py,sha256=
|
8
|
+
torchrl/version.py,sha256=MHs4CxNjQupYI_f84bY7dOAAfPSU9yN6TOyxiS7tS8c,83
|
9
9
|
torchrl/collectors/__init__.py,sha256=hJ3JD6shRku0BL6SzJQq44FZ5Q1RGR8LealFyU3FRn4,799
|
10
|
-
torchrl/collectors/collectors.py,sha256=
|
10
|
+
torchrl/collectors/collectors.py,sha256=HpaW-y0bQOaOql8_7VyEPJ084CulrVwn6iBpGYoHyH4,178287
|
11
11
|
torchrl/collectors/utils.py,sha256=MlXrkYuDmV0Em-tVNQiLL32FWgPNDgceYYG_GgpiviA,11320
|
12
12
|
torchrl/collectors/weight_update.py,sha256=nSIfs8ALsfggLoC2ylg1oOAqdGku1tt4e-50JCZJBww,21073
|
13
13
|
torchrl/collectors/distributed/__init__.py,sha256=_24P0ALFunLhL-ls7EsssGUhJkZ_m3nw7krfMTwPqS0,705
|
@@ -25,7 +25,7 @@ torchrl/collectors/llm/weight_update/__init__.py,sha256=bKjvD7yZG5VnHgvYc4EmKI1s
|
|
25
25
|
torchrl/collectors/llm/weight_update/vllm.py,sha256=slKUmrIo4eL6R4J1oEnmlP6Q7Zer09p92JU8zbIHFUM,11515
|
26
26
|
torchrl/data/__init__.py,sha256=oowsio6ZUOZnJV8JV43xgs17B37XO1yKAYIQPdk8yt0,4819
|
27
27
|
torchrl/data/rlhf.py,sha256=JUmdYBWgkN229DwpXuDrhy9ddjduNvU2kyHzHR6MoA0,963
|
28
|
-
torchrl/data/tensor_specs.py,sha256=
|
28
|
+
torchrl/data/tensor_specs.py,sha256=RlMckj6PJo9MQMzneHzbcVe9xUyMB_n7pnSz0jytB9s,253907
|
29
29
|
torchrl/data/utils.py,sha256=attuNwzfgjszyp0lJSrV06f2peX3r0qTjRZWEwfl6Yg,12108
|
30
30
|
torchrl/data/datasets/__init__.py,sha256=NQpXsHecbZmza8AocX9mkqQQNkdFzeUrMTZoi6hbbU4,733
|
31
31
|
torchrl/data/datasets/atari_dqn.py,sha256=3ij6-UGfKev-QJuUEhZEEmn_3yL210CqKJALaFvlc5M,40739
|
@@ -41,7 +41,7 @@ torchrl/data/datasets/utils.py,sha256=nAFDTlBIPyEoPoJC-Hc_fcOhzE7UZQE4BwKxq15Vhv
|
|
41
41
|
torchrl/data/datasets/vd4rl.py,sha256=z90MqrxKzod8TPGK0uzkC6vw5wQIE4cgrDAC4e72jyk,18262
|
42
42
|
torchrl/data/llm/__init__.py,sha256=B4Ekok-w5PMiWcfmAGXaseaN6hWdNOr4WebeLrHfBVQ,975
|
43
43
|
torchrl/data/llm/dataset.py,sha256=t-41hAzQcjrdoKwpHIMbcrT7pRcQ7DHl2a1-lr6E7W4,20703
|
44
|
-
torchrl/data/llm/history.py,sha256=
|
44
|
+
torchrl/data/llm/history.py,sha256=l9JSxIO5eLUFwHH5IZkANSrByYa8BGmtxMlNXYf2fbs,59640
|
45
45
|
torchrl/data/llm/prompt.py,sha256=bg5LzJfwOq5Ns72KQMciIprMWAmDDinzdopwdopU04c,8380
|
46
46
|
torchrl/data/llm/reward.py,sha256=FbPchNXG3smJV9NCbB5Yk4grsCa2Se4KZ_tojVLKWQM,8404
|
47
47
|
torchrl/data/llm/topk.py,sha256=mYXCgJS4TuEVLZfTNccQd6kmC858AAh2Ygy0q_K1hlY,8365
|
@@ -131,7 +131,7 @@ torchrl/envs/transforms/llm.py,sha256=rQDzuut807wvFpSPCm5tynt8-cMKTgVKVjSVu9D99P
|
|
131
131
|
torchrl/envs/transforms/r3m.py,sha256=sdTVLpnxHfzFVo5rO8WnXf2uUg9cr4LBOLBsWaFgGT8,13478
|
132
132
|
torchrl/envs/transforms/rb_transforms.py,sha256=6ohnKXHHAEh2Hz3Seaw6eDrcFMu-1IVQrT7RVywh3YQ,7447
|
133
133
|
torchrl/envs/transforms/rlhf.py,sha256=lOVXYqQaoDfm4_n77Dxw_wjicBpMtDvavKmBIK2N3lU,628
|
134
|
-
torchrl/envs/transforms/transforms.py,sha256=
|
134
|
+
torchrl/envs/transforms/transforms.py,sha256=cDv_NxElzTOW8qQO-2krvOBmlKVGPOKMfqM6XyuLckU,482882
|
135
135
|
torchrl/envs/transforms/utils.py,sha256=7ToVFnD4-DkOMtML91g4bqXeY0bZ-gmCaSLxC93oaKM,3264
|
136
136
|
torchrl/envs/transforms/vc1.py,sha256=mho5BvdAK-f9hD9t-iah52wT2B06qPmaJO7chrfIOWY,10534
|
137
137
|
torchrl/envs/transforms/vecnorm.py,sha256=XahMcWvK3zjOB6EACSZtJ6UMP3yQ2zD9xf87UEB37Eg,34047
|
@@ -139,7 +139,7 @@ torchrl/envs/transforms/vip.py,sha256=kmygbenw75rEYsKRq4X1hzEH_CRe1406NZZ8Hg2R_V
|
|
139
139
|
torchrl/modules/__init__.py,sha256=XlAO0hulhDQNcKhbu3cFi8KJOHXNiAgmXeTfny0WBqE,4157
|
140
140
|
torchrl/modules/distributions/__init__.py,sha256=RDFoYD9IX1FhwXk5R4M8khq42gdTOcVnUnKHfWCTZBQ,1597
|
141
141
|
torchrl/modules/distributions/continuous.py,sha256=VPBugDuavJmyZ-RzemyLIFA02UCMLsm-rzBQrKcTlIA,25667
|
142
|
-
torchrl/modules/distributions/discrete.py,sha256=
|
142
|
+
torchrl/modules/distributions/discrete.py,sha256=7UE6X8LeTZkaTRFvKNcFSOoug_tOcD_u-FOh-39ZSC4,35581
|
143
143
|
torchrl/modules/distributions/truncated_normal.py,sha256=-qM8vwxTzv3VsWphZwcueDQpHQ67IRnkDFKlTDkQQnY,5937
|
144
144
|
torchrl/modules/distributions/utils.py,sha256=kXRvNHeKUePIgKgn7DnKqbhQ6ImFGgkFVRxITX2dwNU,7567
|
145
145
|
torchrl/modules/llm/__init__.py,sha256=BTkn-8QKp_8sW_NTKP02yoWSJUsX0XL6L9chTJl6epc,737
|
@@ -147,9 +147,9 @@ torchrl/modules/llm/utils.py,sha256=gf_F-4bEMwkcI3jLQM7ifB7nsjRctGebB5E2c-AznO0,
|
|
147
147
|
torchrl/modules/llm/backends/__init__.py,sha256=WdVy9EdiAfk8i5zFa49TEkRvcUd0L4Un4v6wqWBy8l8,438
|
148
148
|
torchrl/modules/llm/backends/vllm.py,sha256=x57Xop1xd5ZShicsh47ZFmz4VpfZ3eCzVx7k0COvpqQ,9387
|
149
149
|
torchrl/modules/llm/policies/__init__.py,sha256=nfZ2mcVuucxnY3WCuzIQrTLIf1yEd36k8-AlvwnSa8Y,545
|
150
|
-
torchrl/modules/llm/policies/common.py,sha256=
|
151
|
-
torchrl/modules/llm/policies/transformers_wrapper.py,sha256=
|
152
|
-
torchrl/modules/llm/policies/vllm_wrapper.py,sha256=
|
150
|
+
torchrl/modules/llm/policies/common.py,sha256=Kvn1cJQbp1EZtxWpAQ50TzZkwVtLAmryqiBHH2nK_wM,39112
|
151
|
+
torchrl/modules/llm/policies/transformers_wrapper.py,sha256=oi-2KALM0pkH-u-Kd6WlnxfH9eGV2GzBqM410ANpPeM,75777
|
152
|
+
torchrl/modules/llm/policies/vllm_wrapper.py,sha256=ReBvi2M9IAiwwBAR7GpDLSQhX0aC-dXPnHYb082Q0To,79632
|
153
153
|
torchrl/modules/models/__init__.py,sha256=DrOG-7hynjjUh_tc2EqysiUiNMRiDR0WLtZql9TPNcI,1743
|
154
154
|
torchrl/modules/models/batchrenorm.py,sha256=TojpTUluIcFdTSemIVRLGtB2O5q54mRHy3vJP6DuI5I,4750
|
155
155
|
torchrl/modules/models/decision_transformer.py,sha256=Lttf_wZMNqXbB_vpxMYgEp18gEzOvm3NvMnxQkHkH4M,6604
|
@@ -176,7 +176,7 @@ torchrl/modules/utils/__init__.py,sha256=KXaF_xEghKSPsNg0JyfxChK6KWHFRy0lwkL2Rip
|
|
176
176
|
torchrl/modules/utils/mappings.py,sha256=VMYrPxDk1ywgl2l_f6HXZaRsVOKcYR7VF5DNkmi3lHk,362
|
177
177
|
torchrl/modules/utils/utils.py,sha256=WPfcE-AoemnrP7Ny4FxJ-_LoQsBnX-y77Zb7MnZjXV0,2916
|
178
178
|
torchrl/objectives/__init__.py,sha256=pnprzIXA6E9Ph7isYgNLh4SFTU0pxIQg4oUNcaQ6doc,2148
|
179
|
-
torchrl/objectives/a2c.py,sha256=
|
179
|
+
torchrl/objectives/a2c.py,sha256=_xdp8D2ErOPyHwpxqPHtUr-EvZw7MqcuhhK9Isnewgo,28791
|
180
180
|
torchrl/objectives/common.py,sha256=40inZ0z3bFdQUkXuup3PWP_KmCx1m13cKTksjOp_b6I,28571
|
181
181
|
torchrl/objectives/cql.py,sha256=8faIZmA9e65NQ39HAi6torMofr98bkngjtBXm0UbnVM,54925
|
182
182
|
torchrl/objectives/crossq.py,sha256=a_vAjET5GG-2U7zZDgMnA0QP1iPCtv2ho6q-XvvLsnc,28858
|
@@ -188,7 +188,7 @@ torchrl/objectives/dreamer.py,sha256=vIJQN91oPXYnPubDFQpaF5d3fR_WwIYuIVYtoCvw0TY
|
|
188
188
|
torchrl/objectives/functional.py,sha256=ZaglBjEGuOTNGeFA-Ox-ugZVcNegQMUj--KWHDRBmaU,2106
|
189
189
|
torchrl/objectives/gail.py,sha256=0m34XmcN-EDk5OfNIo5bKYbKKZfATsYRv4zQe3v2UwA,9576
|
190
190
|
torchrl/objectives/iql.py,sha256=1jvlSznWke6NZSwfuYyHVnVBE7Cz3q169GnCRC7iel4,42991
|
191
|
-
torchrl/objectives/ppo.py,sha256=
|
191
|
+
torchrl/objectives/ppo.py,sha256=0soC2aiCOFNM5hCL20-99LX_NZi6XIXDmG2IkGEHSek,76082
|
192
192
|
torchrl/objectives/redq.py,sha256=4usM-nG2UWujeL-VEqzf7-uOwRFx6itkKCeitKuJhtw,28507
|
193
193
|
torchrl/objectives/reinforce.py,sha256=ySXLp5C-OOUYayqjrf4taQmL8LgRvMgPCgHDsle8JDc,22339
|
194
194
|
torchrl/objectives/sac.py,sha256=Oq9Iq90s9KFbnM4KSRUd2onU1JfW6aW80LWGdtO0CY8,63993
|
@@ -219,12 +219,12 @@ torchrl/trainers/helpers/__init__.py,sha256=HhDB2Ubq2gZodV-hB6xw4ZgCgwaZKUoZgOfV
|
|
219
219
|
torchrl/trainers/helpers/collectors.py,sha256=NjMMvGWEe4TWkVXzx7AlJ_Qa_AxEzMl6EUmEgUzHkoE,18715
|
220
220
|
torchrl/trainers/helpers/envs.py,sha256=1yqJZgz7mc5wa58HmSDGpPQINeDHFZB0_KTgwdKm9QE,22084
|
221
221
|
torchrl/trainers/helpers/logger.py,sha256=FtuEiLnK4NmxVVNyEEWaoCu3nG7WbNpHP3UYGQRJmgo,1278
|
222
|
-
torchrl/trainers/helpers/losses.py,sha256=
|
222
|
+
torchrl/trainers/helpers/losses.py,sha256=sHlJqjh02t8cKN73X35Azd_OoWGurohLuviB8Yeo4JQ,5272
|
223
223
|
torchrl/trainers/helpers/models.py,sha256=ihTERG2c96E8cS3Tnul6a_ys6iDEEJmHh05p9blQTW8,21807
|
224
224
|
torchrl/trainers/helpers/replay_buffer.py,sha256=ZUZHOa0TILyeWJ3iahzTJ6UvMl_0FdxuZfJEja94Bn8,2001
|
225
225
|
torchrl/trainers/helpers/trainers.py,sha256=j6B5XA7_FFHMQeOIQwjNcO0CGE_4mZKUC9_jH_iqqh4,12071
|
226
|
-
torchrl_nightly-2025.7.
|
227
|
-
torchrl_nightly-2025.7.
|
228
|
-
torchrl_nightly-2025.7.
|
229
|
-
torchrl_nightly-2025.7.
|
230
|
-
torchrl_nightly-2025.7.
|
226
|
+
torchrl_nightly-2025.7.18.dist-info/licenses/LICENSE,sha256=xdjS4_xk-IwnLuIFCvTYTl9Y8aXRejqpmke3dGam_nI,1098
|
227
|
+
torchrl_nightly-2025.7.18.dist-info/METADATA,sha256=K_Nmn84sw1xeD28lqIPqdhLjdaFchSMXuG2vjAajTn0,42990
|
228
|
+
torchrl_nightly-2025.7.18.dist-info/WHEEL,sha256=A6iggJuFsuu67bHdjxJADhwSEJmqwgO3xFoNCIwjOxc,115
|
229
|
+
torchrl_nightly-2025.7.18.dist-info/top_level.txt,sha256=JeTJ1jV7QJwLcUS1nr21aPn_wb-XlAZ9c-z_EH472JA,20
|
230
|
+
torchrl_nightly-2025.7.18.dist-info/RECORD,,
|
File without changes
|
{torchrl_nightly-2025.7.15.dist-info → torchrl_nightly-2025.7.18.dist-info}/licenses/LICENSE
RENAMED
File without changes
|
File without changes
|