torchrl-nightly 2025.7.19__cp312-cp312-macosx_10_13_universal2.whl → 2025.7.21__cp312-cp312-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrl/_extension.py +9 -5
- torchrl/_torchrl.cpython-312-darwin.so +0 -0
- torchrl/objectives/sac.py +152 -8
- torchrl/version.py +3 -2
- {torchrl_nightly-2025.7.19.dist-info → torchrl_nightly-2025.7.21.dist-info}/METADATA +1 -1
- {torchrl_nightly-2025.7.19.dist-info → torchrl_nightly-2025.7.21.dist-info}/RECORD +9 -9
- {torchrl_nightly-2025.7.19.dist-info → torchrl_nightly-2025.7.21.dist-info}/WHEEL +0 -0
- {torchrl_nightly-2025.7.19.dist-info → torchrl_nightly-2025.7.21.dist-info}/licenses/LICENSE +0 -0
- {torchrl_nightly-2025.7.19.dist-info → torchrl_nightly-2025.7.21.dist-info}/top_level.txt +0 -0
torchrl/_extension.py
CHANGED
@@ -10,9 +10,10 @@ import warnings
 from packaging.version import parse
 
 try:
-    from .version import __version__
+    from .version import __version__, pytorch_version
 except ImportError:
     __version__ = None
+    pytorch_version = "unknown"
 
 
 def is_module_available(*modules: str) -> bool:
@@ -47,13 +48,16 @@ if _is_nightly(__version__):
         " - make sure ninja and cmake were installed\n"
         " - make sure you ran `python setup.py clean && python setup.py develop` and that no error was raised\n"
         " - make sure the version of PyTorch you are using matches the one that was present in your virtual env during "
-        "setup."
+        f"setup. This package was built with PyTorch {pytorch_version}."
     )
 
 else:
     EXTENSION_WARNING = (
         "Failed to import torchrl C++ binaries. Some modules (eg, prioritized replay buffers) may not work with your installation. "
-        "This is likely due to a discrepancy between your package version and the PyTorch version.
-        "
-        "
+        "This is likely due to a discrepancy between your package version and the PyTorch version. "
+        "TorchRL does not tightly pin PyTorch versions to give users freedom, but the trade-off is that C++ extensions like "
+        "prioritized replay buffers can only be used with the PyTorch version they were built against. "
+        f"This package was built with PyTorch {pytorch_version}. "
+        "Workarounds include: (1) upgrading/downgrading PyTorch or TorchRL to compatible versions, "
+        "or (2) making a local install using `pip install git+https://github.com/pytorch/rl.git@<version>`."
    )
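
The expanded warning above records the PyTorch build the wheel was compiled against, now exposed as `pytorch_version` in `torchrl/version.py` (with an `"unknown"` fallback). As a hedged illustration that is not part of this diff, a downstream project could run the same comparison up front; only `torchrl.version.pytorch_version` and its fallback come from this release, the rest is an assumed sketch:

# Sketch only: compare the PyTorch build torchrl was compiled against with the
# locally installed torch; names other than torchrl.version.pytorch_version are
# standard PyTorch APIs, nothing here is taken from the package itself.
import torch

try:
    from torchrl.version import pytorch_version  # added in 2025.7.21
except ImportError:  # older torchrl wheels do not expose it
    pytorch_version = "unknown"

built = pytorch_version.split("+")[0]
installed = torch.__version__.split("+")[0]
if built != "unknown" and built != installed:
    print(
        f"torchrl was built against PyTorch {pytorch_version} but {torch.__version__} "
        "is installed; C++ extensions such as prioritized replay buffers may not import."
    )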
torchrl/_torchrl.cpython-312-darwin.so
CHANGED
Binary file (no diff shown)
torchrl/objectives/sac.py
CHANGED
@@ -687,6 +687,102 @@ class SACLoss(LossModule):
         )
         return self._alpha * log_prob - min_q_logprob, {"log_prob": log_prob.detach()}
 
+    @dispatch
+    def actor_loss(
+        self, tensordict: TensorDictBase
+    ) -> tuple[Tensor, dict[str, Tensor]]:
+        """Compute the actor loss for SAC.
+
+        This method computes the actor loss which encourages the policy to maximize
+        the expected Q-value while maintaining high entropy.
+
+        Args:
+            tensordict (TensorDictBase): A tensordict containing the data needed for
+                computing the actor loss. Should contain the observation and other
+                required keys for the actor network.
+
+        Returns:
+            A tuple containing:
+                - The actor loss tensor
+                - A dictionary with metadata including the log probability of actions
+        """
+        return self._actor_loss(tensordict)
+
+    @dispatch
+    def qvalue_loss(
+        self, tensordict: TensorDictBase
+    ) -> tuple[Tensor, dict[str, Tensor]]:
+        """Compute the Q-value loss for SAC.
+
+        This method computes the Q-value loss which trains the Q-networks to estimate
+        the expected return for state-action pairs.
+
+        Args:
+            tensordict (TensorDictBase): A tensordict containing the data needed for
+                computing the Q-value loss. Should contain the observation, action,
+                reward, done, and terminated keys.
+
+        Returns:
+            A tuple containing:
+                - The Q-value loss tensor
+                - A dictionary with metadata including the TD error
+        """
+        if self._version == 1:
+            return self._qvalue_v1_loss(tensordict)
+        else:
+            return self._qvalue_v2_loss(tensordict)
+
+    @dispatch
+    def value_loss(
+        self, tensordict: TensorDictBase
+    ) -> tuple[Tensor, dict[str, Tensor]]:
+        """Compute the value loss for SAC (version 1 only).
+
+        This method computes the value loss which trains the value network to estimate
+        the expected return for states. This is only used in SAC version 1.
+
+        Args:
+            tensordict (TensorDictBase): A tensordict containing the data needed for
+                computing the value loss. Should contain the observation and other
+                required keys for the value network.
+
+        Returns:
+            A tuple containing:
+                - The value loss tensor
+                - An empty dictionary (no metadata for value loss)
+
+        Raises:
+            RuntimeError: If called on SAC version 2 (which doesn't use a value network)
+        """
+        if self._version != 1:
+            raise RuntimeError(
+                "Value loss is only available in SAC version 1. "
+                "SAC version 2 doesn't use a separate value network."
+            )
+        return self._value_loss(tensordict)
+
+    def alpha_loss(self, log_prob: Tensor) -> Tensor:
+        """Compute the alpha loss for SAC.
+
+        This method computes the alpha loss which adapts the entropy coefficient
+        to maintain the target entropy level.
+
+        Args:
+            log_prob (Tensor): The log probability of actions from the actor network.
+
+        Returns:
+            The alpha loss tensor
+        """
+        return self._alpha_loss(log_prob)
+
+    @property
+    def _alpha(self):
+        if self.min_log_alpha is not None or self.max_log_alpha is not None:
+            self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
+        with torch.no_grad():
+            alpha = self.log_alpha.exp()
+        return alpha
+
     @property
     @_cache_values
     def _cached_target_params_actor_value(self):
@@ -882,14 +978,6 @@
             alpha_loss = torch.zeros_like(log_prob)
         return alpha_loss
 
-    @property
-    def _alpha(self):
-        if self.min_log_alpha is not None or self.max_log_alpha is not None:
-            self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
-        with torch.no_grad():
-            alpha = self.log_alpha.exp()
-        return alpha
-
 
 class DiscreteSACLoss(LossModule):
     """Discrete SAC Loss module.
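
The two hunks above give `SACLoss` public, documented entry points for its individual loss terms (previously only reachable through the private `_actor_loss`, `_qvalue_v*_loss`, `_value_loss` and `_alpha_loss` helpers) and move the `_alpha` property next to them. A minimal usage sketch follows; `loss_module` (an already-constructed `SACLoss`) and `batch` (a sampled TensorDict with the keys listed in the docstrings) are placeholders, not part of the diff:

# Sketch only: `loss_module` is an existing SACLoss, `batch` a sampled
# TensorDict with observation, action, reward, done and terminated entries.
loss_actor, metadata = loss_module.actor_loss(batch)    # metadata carries "log_prob"
loss_qvalue, q_info = loss_module.qvalue_loss(batch)    # metadata carries the TD error
loss_alpha = loss_module.alpha_loss(metadata["log_prob"])

# .mean() is a no-op if the terms are already reduced to scalars.
total = loss_actor.mean() + loss_qvalue.mean() + loss_alpha.mean()
total.backward()

Compared with calling the module's forward() and reading the aggregated loss keys, the per-term methods make it easier to step separate optimizers or to log each component on its own.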
@@ -1352,6 +1440,48 @@ class DiscreteSACLoss(LossModule):
         target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
         return target_value
 
+    @dispatch
+    def actor_loss(
+        self, tensordict: TensorDictBase
+    ) -> tuple[Tensor, dict[str, Tensor]]:
+        """Compute the actor loss for discrete SAC.
+
+        This method computes the actor loss which encourages the policy to maximize
+        the expected Q-value while maintaining high entropy for discrete actions.
+
+        Args:
+            tensordict (TensorDictBase): A tensordict containing the data needed for
+                computing the actor loss. Should contain the observation and other
+                required keys for the actor network.
+
+        Returns:
+            A tuple containing:
+                - The actor loss tensor
+                - A dictionary with metadata including the log probability of actions
+        """
+        return self._actor_loss(tensordict)
+
+    @dispatch
+    def qvalue_loss(
+        self, tensordict: TensorDictBase
+    ) -> tuple[Tensor, dict[str, Tensor]]:
+        """Compute the Q-value loss for discrete SAC.
+
+        This method computes the Q-value loss which trains the Q-networks to estimate
+        the expected return for state-action pairs in discrete action spaces.
+
+        Args:
+            tensordict (TensorDictBase): A tensordict containing the data needed for
+                computing the Q-value loss. Should contain the observation, action,
+                reward, done, and terminated keys.
+
+        Returns:
+            A tuple containing:
+                - The Q-value loss tensor
+                - A dictionary with metadata including the TD error
+        """
+        return self._value_loss(tensordict)
+
     def _value_loss(
         self, tensordict: TensorDictBase
     ) -> tuple[Tensor, dict[str, Tensor]]:
@@ -1427,6 +1557,20 @@
             alpha_loss = torch.zeros_like(log_prob)
         return alpha_loss
 
+    def alpha_loss(self, log_prob: Tensor) -> Tensor:
+        """Compute the alpha loss for discrete SAC.
+
+        This method computes the alpha loss which adapts the entropy coefficient
+        to maintain the target entropy level for discrete actions.
+
+        Args:
+            log_prob (Tensor): The log probability of actions from the actor network.
+
+        Returns:
+            The alpha loss tensor
+        """
+        return self._alpha_loss(log_prob)
+
     @property
     def _alpha(self):
         if self.min_log_alpha is not None or self.max_log_alpha is not None:
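
`DiscreteSACLoss` gains the same public decomposition, which allows each loss term to drive its own optimizer, for example when the actor and the Q-networks are updated at different frequencies. A hedged sketch under assumed names (`loss_module`, `batch`, and the three optimizers are placeholders; the `"log_prob"` metadata key mirrors the continuous-action case above):

# Sketch only: all names below are placeholders for an existing DiscreteSACLoss,
# a sampled batch, and three torch.optim optimizers over the respective parameters.
loss_actor, metadata = loss_module.actor_loss(batch)
optim_actor.zero_grad()
loss_actor.mean().backward()
optim_actor.step()

loss_q, _ = loss_module.qvalue_loss(batch)
optim_critic.zero_grad()
loss_q.mean().backward()
optim_critic.step()

loss_alpha = loss_module.alpha_loss(metadata["log_prob"])
optim_alpha.zero_grad()
loss_alpha.mean().backward()
optim_alpha.step()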
torchrl/version.py
CHANGED
@@ -1,2 +1,3 @@
-__version__ = '2025.7.19'
-git_version = '
+__version__ = '2025.7.21'
+git_version = '851a0410a7c72ad656ab3712395450ae482ce3f6'
+pytorch_version = '2.9.0.dev20250721'
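
With the new `pytorch_version` field the wheel is self-describing. A small hedged sketch, not taken from the package, of printing the three fields for a bug report alongside the locally installed torch build:

# Sketch only: reads the attributes shown in the diff above.
import torch
import torchrl.version as trl_version

print(f"torchrl {trl_version.__version__} (git {trl_version.git_version[:8]})")
print(f"built against PyTorch {trl_version.pytorch_version}, running {torch.__version__}")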
{torchrl_nightly-2025.7.19.dist-info → torchrl_nightly-2025.7.21.dist-info}/RECORD
RENAMED
@@ -2,10 +2,10 @@ build_tools/__init__.py,sha256=D9ECsap08rcUtHyaR111nTjArlKkUjQwP8usXwXAAVQ,179
 build_tools/setup_helpers/__init__.py,sha256=7l8TvVqxKezgzKCLuRv20mvGLloprFVZYm8CWS8a6CU,276
 build_tools/setup_helpers/extension.py,sha256=4-PDLr-pw40bJnd9SfxnTaSjUyuXU_Tg8yOg69Kl0o4,5914
 torchrl/__init__.py,sha256=mhDBx2UIuBKc0gmi8dVNHokQ6tCbIovruZmyAxcSsy8,2938
-torchrl/_extension.py,sha256=
-torchrl/_torchrl.cpython-312-darwin.so,sha256=
+torchrl/_extension.py,sha256=YmBlmR6Kt0P1rFLSqA9IbkQlhHuJlzySY-ZlUpckxBY,2736
+torchrl/_torchrl.cpython-312-darwin.so,sha256=D48AeTXfz_agfVAfGuv00nykEj2aTD_tIxcyYUhDbF8,1692224
 torchrl/_utils.py,sha256=Cw5EG6x5oSZF1iE3YCs1a32VUKp0rTXIs2u67q9zKUI,41078
-torchrl/version.py,sha256=
+torchrl/version.py,sha256=r9rIVz69C9oNJra3D2dU2cL9iEKutdEWy1fSsSIEcVI,121
 torchrl/collectors/__init__.py,sha256=hJ3JD6shRku0BL6SzJQq44FZ5Q1RGR8LealFyU3FRn4,799
 torchrl/collectors/collectors.py,sha256=HpaW-y0bQOaOql8_7VyEPJ084CulrVwn6iBpGYoHyH4,178287
 torchrl/collectors/utils.py,sha256=MlXrkYuDmV0Em-tVNQiLL32FWgPNDgceYYG_GgpiviA,11320
@@ -191,7 +191,7 @@ torchrl/objectives/iql.py,sha256=1jvlSznWke6NZSwfuYyHVnVBE7Cz3q169GnCRC7iel4,429
 torchrl/objectives/ppo.py,sha256=0soC2aiCOFNM5hCL20-99LX_NZi6XIXDmG2IkGEHSek,76082
 torchrl/objectives/redq.py,sha256=4usM-nG2UWujeL-VEqzf7-uOwRFx6itkKCeitKuJhtw,28507
 torchrl/objectives/reinforce.py,sha256=ySXLp5C-OOUYayqjrf4taQmL8LgRvMgPCgHDsle8JDc,22339
-torchrl/objectives/sac.py,sha256=
+torchrl/objectives/sac.py,sha256=21AR3GC9tcvXAlWE4m2abp1UZgDKyLegEDiD8-EjJbQ,69232
 torchrl/objectives/td3.py,sha256=RnlkGzBBTY0KrfRKytsFbNyoVUy2HLfwSL4_9YQRep8,23190
 torchrl/objectives/td3_bc.py,sha256=jHGwCzPuCbN37zAxsiDQIe92yR1UE7rjcnJoy8b_NjE,25950
 torchrl/objectives/utils.py,sha256=qDP_ZCjsJyPQwBqGf5Vzrdw6HtNyWqQeAEb2QpvwW1Q,24739
@@ -223,8 +223,8 @@ torchrl/trainers/helpers/losses.py,sha256=sHlJqjh02t8cKN73X35Azd_OoWGurohLuviB8Y
 torchrl/trainers/helpers/models.py,sha256=ihTERG2c96E8cS3Tnul6a_ys6iDEEJmHh05p9blQTW8,21807
 torchrl/trainers/helpers/replay_buffer.py,sha256=ZUZHOa0TILyeWJ3iahzTJ6UvMl_0FdxuZfJEja94Bn8,2001
 torchrl/trainers/helpers/trainers.py,sha256=j6B5XA7_FFHMQeOIQwjNcO0CGE_4mZKUC9_jH_iqqh4,12071
-torchrl_nightly-2025.7.
-torchrl_nightly-2025.7.
-torchrl_nightly-2025.7.
-torchrl_nightly-2025.7.
-torchrl_nightly-2025.7.
+torchrl_nightly-2025.7.21.dist-info/licenses/LICENSE,sha256=xdjS4_xk-IwnLuIFCvTYTl9Y8aXRejqpmke3dGam_nI,1098
+torchrl_nightly-2025.7.21.dist-info/METADATA,sha256=fmyatCEgcxBi17W8uVg1hZ-9fslugseRebMHAvwZq8I,42990
+torchrl_nightly-2025.7.21.dist-info/WHEEL,sha256=9_3tTSxMJq-dgdzMiScNvtT5eTBVd3l6RgHS7HwTzpA,115
+torchrl_nightly-2025.7.21.dist-info/top_level.txt,sha256=JeTJ1jV7QJwLcUS1nr21aPn_wb-XlAZ9c-z_EH472JA,20
+torchrl_nightly-2025.7.21.dist-info/RECORD,,
{torchrl_nightly-2025.7.19.dist-info → torchrl_nightly-2025.7.21.dist-info}/WHEEL
RENAMED
File without changes

{torchrl_nightly-2025.7.19.dist-info → torchrl_nightly-2025.7.21.dist-info}/licenses/LICENSE
RENAMED
File without changes

{torchrl_nightly-2025.7.19.dist-info → torchrl_nightly-2025.7.21.dist-info}/top_level.txt
RENAMED
File without changes