torchrl-nightly 2025.7.20__cp310-cp310-win_amd64.whl → 2025.7.21__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
torchrl/_extension.py CHANGED
@@ -10,9 +10,10 @@ import warnings
 from packaging.version import parse
 
 try:
-    from .version import __version__
+    from .version import __version__, pytorch_version
 except ImportError:
     __version__ = None
+    pytorch_version = "unknown"
 
 
 def is_module_available(*modules: str) -> bool:
@@ -47,13 +48,16 @@ if _is_nightly(__version__):
         " - make sure ninja and cmake were installed\n"
         " - make sure you ran `python setup.py clean && python setup.py develop` and that no error was raised\n"
         " - make sure the version of PyTorch you are using matches the one that was present in your virtual env during "
-        "setup."
+        f"setup. This package was built with PyTorch {pytorch_version}."
     )
 
 else:
     EXTENSION_WARNING = (
         "Failed to import torchrl C++ binaries. Some modules (eg, prioritized replay buffers) may not work with your installation. "
-        "This is likely due to a discrepancy between your package version and the PyTorch version. Make sure both are compatible. "
-        "Usually, torchrl majors follow the pytorch majors within a few days around the release. "
-        "For instance, TorchRL 0.5 requires PyTorch 2.4.0, and TorchRL 0.6 requires PyTorch 2.5.0."
+        "This is likely due to a discrepancy between your package version and the PyTorch version. "
+        "TorchRL does not tightly pin PyTorch versions to give users freedom, but the trade-off is that C++ extensions like "
+        "prioritized replay buffers can only be used with the PyTorch version they were built against. "
+        f"This package was built with PyTorch {pytorch_version}. "
+        "Workarounds include: (1) upgrading/downgrading PyTorch or TorchRL to compatible versions, "
+        "or (2) making a local install using `pip install git+https://github.com/pytorch/rl.git@<version>`."
    )
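
For readers debugging this warning, the sketch below (not part of the diff) shows how the newly recorded build metadata could be compared against the installed PyTorch. `pytorch_version` is the attribute this release adds to `torchrl/version.py`; the `"unknown"` fallback mirrors the `ImportError` branch above.

```python
# Minimal sketch, assuming torchrl and torch are importable in the current environment.
import torch
import torchrl.version as trl_version

built_with = getattr(trl_version, "pytorch_version", "unknown")  # added in 2025.7.21
installed = torch.__version__

print(f"TorchRL built against PyTorch: {built_with}")
print(f"PyTorch currently installed:   {installed}")
if built_with not in ("unknown", installed):
    print("Mismatch: C++ extensions (e.g. prioritized replay buffers) may fail to import.")
```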
torchrl/_torchrl.cp310-win_amd64.pyd CHANGED (binary file; no text diff shown)
torchrl/objectives/sac.py CHANGED
@@ -687,6 +687,102 @@ class SACLoss(LossModule):
             )
         return self._alpha * log_prob - min_q_logprob, {"log_prob": log_prob.detach()}
 
+    @dispatch
+    def actor_loss(
+        self, tensordict: TensorDictBase
+    ) -> tuple[Tensor, dict[str, Tensor]]:
+        """Compute the actor loss for SAC.
+
+        This method computes the actor loss which encourages the policy to maximize
+        the expected Q-value while maintaining high entropy.
+
+        Args:
+            tensordict (TensorDictBase): A tensordict containing the data needed for
+                computing the actor loss. Should contain the observation and other
+                required keys for the actor network.
+
+        Returns:
+            A tuple containing:
+                - The actor loss tensor
+                - A dictionary with metadata including the log probability of actions
+        """
+        return self._actor_loss(tensordict)
+
+    @dispatch
+    def qvalue_loss(
+        self, tensordict: TensorDictBase
+    ) -> tuple[Tensor, dict[str, Tensor]]:
+        """Compute the Q-value loss for SAC.
+
+        This method computes the Q-value loss which trains the Q-networks to estimate
+        the expected return for state-action pairs.
+
+        Args:
+            tensordict (TensorDictBase): A tensordict containing the data needed for
+                computing the Q-value loss. Should contain the observation, action,
+                reward, done, and terminated keys.
+
+        Returns:
+            A tuple containing:
+                - The Q-value loss tensor
+                - A dictionary with metadata including the TD error
+        """
+        if self._version == 1:
+            return self._qvalue_v1_loss(tensordict)
+        else:
+            return self._qvalue_v2_loss(tensordict)
+
+    @dispatch
+    def value_loss(
+        self, tensordict: TensorDictBase
+    ) -> tuple[Tensor, dict[str, Tensor]]:
+        """Compute the value loss for SAC (version 1 only).
+
+        This method computes the value loss which trains the value network to estimate
+        the expected return for states. This is only used in SAC version 1.
+
+        Args:
+            tensordict (TensorDictBase): A tensordict containing the data needed for
+                computing the value loss. Should contain the observation and other
+                required keys for the value network.
+
+        Returns:
+            A tuple containing:
+                - The value loss tensor
+                - An empty dictionary (no metadata for value loss)
+
+        Raises:
+            RuntimeError: If called on SAC version 2 (which doesn't use a value network)
+        """
+        if self._version != 1:
+            raise RuntimeError(
+                "Value loss is only available in SAC version 1. "
+                "SAC version 2 doesn't use a separate value network."
+            )
+        return self._value_loss(tensordict)
+
+    def alpha_loss(self, log_prob: Tensor) -> Tensor:
+        """Compute the alpha loss for SAC.
+
+        This method computes the alpha loss which adapts the entropy coefficient
+        to maintain the target entropy level.
+
+        Args:
+            log_prob (Tensor): The log probability of actions from the actor network.
+
+        Returns:
+            The alpha loss tensor
+        """
+        return self._alpha_loss(log_prob)
+
+    @property
+    def _alpha(self):
+        if self.min_log_alpha is not None or self.max_log_alpha is not None:
+            self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
+        with torch.no_grad():
+            alpha = self.log_alpha.exp()
+        return alpha
+
     @property
     @_cache_values
     def _cached_target_params_actor_value(self):
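
The hunk above only forwards to existing private helpers (`_actor_loss`, `_qvalue_v1_loss`/`_qvalue_v2_loss`, `_value_loss`, `_alpha_loss`), so the new public methods slot into the usual `SACLoss` setup. The sketch below is not taken from the package; it follows the standard TorchRL SAC construction, and the network sizes, key names, and random batch are illustrative assumptions about how the exposed losses might be called individually instead of through `forward()`.

```python
# Minimal sketch, assuming a standard version-2 SACLoss setup; shapes and data are toy values.
import torch
from tensordict import TensorDict
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torchrl.data import Bounded
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import SACLoss

n_obs, n_act, batch = 3, 4, 2
spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))

# Actor: emits loc/scale for a TanhNormal policy.
actor_net = TensorDictModule(
    torch.nn.Sequential(torch.nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()),
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)
actor = ProbabilisticActor(
    actor_net, in_keys=["loc", "scale"], spec=spec, distribution_class=TanhNormal
)

# Q-value network: maps (observation, action) to a scalar state-action value.
class QNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(n_obs + n_act, 1)

    def forward(self, obs, act):
        return self.linear(torch.cat([obs, act], -1))

qvalue = ValueOperator(QNet(), in_keys=["observation", "action"])
loss_module = SACLoss(actor, qvalue)  # version-2 SAC: no separate value network

data = TensorDict(
    {
        "observation": torch.randn(batch, n_obs),
        "action": spec.rand((batch,)),
        ("next", "observation"): torch.randn(batch, n_obs),
        ("next", "reward"): torch.randn(batch, 1),
        ("next", "done"): torch.zeros(batch, 1, dtype=torch.bool),
        ("next", "terminated"): torch.zeros(batch, 1, dtype=torch.bool),
    },
    batch_size=[batch],
)

actor_loss, actor_md = loss_module.actor_loss(data)   # newly exposed in this diff
qvalue_loss, q_md = loss_module.qvalue_loss(data)     # newly exposed in this diff
alpha_loss = loss_module.alpha_loss(actor_md["log_prob"])
total = actor_loss.mean() + qvalue_loss.mean() + alpha_loss.mean()
total.backward()
```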
@@ -882,14 +978,6 @@ class SACLoss(LossModule):
             alpha_loss = torch.zeros_like(log_prob)
         return alpha_loss
 
-    @property
-    def _alpha(self):
-        if self.min_log_alpha is not None or self.max_log_alpha is not None:
-            self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
-        with torch.no_grad():
-            alpha = self.log_alpha.exp()
-        return alpha
-
 
 class DiscreteSACLoss(LossModule):
     """Discrete SAC Loss module.
@@ -1352,6 +1440,48 @@ class DiscreteSACLoss(LossModule):
         target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
         return target_value
 
+    @dispatch
+    def actor_loss(
+        self, tensordict: TensorDictBase
+    ) -> tuple[Tensor, dict[str, Tensor]]:
+        """Compute the actor loss for discrete SAC.
+
+        This method computes the actor loss which encourages the policy to maximize
+        the expected Q-value while maintaining high entropy for discrete actions.
+
+        Args:
+            tensordict (TensorDictBase): A tensordict containing the data needed for
+                computing the actor loss. Should contain the observation and other
+                required keys for the actor network.
+
+        Returns:
+            A tuple containing:
+                - The actor loss tensor
+                - A dictionary with metadata including the log probability of actions
+        """
+        return self._actor_loss(tensordict)
+
+    @dispatch
+    def qvalue_loss(
+        self, tensordict: TensorDictBase
+    ) -> tuple[Tensor, dict[str, Tensor]]:
+        """Compute the Q-value loss for discrete SAC.
+
+        This method computes the Q-value loss which trains the Q-networks to estimate
+        the expected return for state-action pairs in discrete action spaces.
+
+        Args:
+            tensordict (TensorDictBase): A tensordict containing the data needed for
+                computing the Q-value loss. Should contain the observation, action,
+                reward, done, and terminated keys.
+
+        Returns:
+            A tuple containing:
+                - The Q-value loss tensor
+                - A dictionary with metadata including the TD error
+        """
+        return self._value_loss(tensordict)
+
     def _value_loss(
         self, tensordict: TensorDictBase
     ) -> tuple[Tensor, dict[str, Tensor]]:
@@ -1427,6 +1557,20 @@ class DiscreteSACLoss(LossModule):
             alpha_loss = torch.zeros_like(log_prob)
         return alpha_loss
 
+    def alpha_loss(self, log_prob: Tensor) -> Tensor:
+        """Compute the alpha loss for discrete SAC.
+
+        This method computes the alpha loss which adapts the entropy coefficient
+        to maintain the target entropy level for discrete actions.
+
+        Args:
+            log_prob (Tensor): The log probability of actions from the actor network.
+
+        Returns:
+            The alpha loss tensor
+        """
+        return self._alpha_loss(log_prob)
+
     @property
     def _alpha(self):
         if self.min_log_alpha is not None or self.max_log_alpha is not None:
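
Both loss classes read the entropy coefficient through the `_alpha` property shown above: `log_alpha` is optionally clamped in place, then exponentiated without tracking gradients. The standalone sketch below mirrors that step outside the class; the initial value and the [-5, 5] bounds are illustrative assumptions, not defaults taken from the package.

```python
# Standalone sketch of the clamp-then-exponentiate step used by `_alpha` above;
# the initial log-alpha and the bounds are illustrative assumptions.
import torch

log_alpha = torch.nn.Parameter(torch.tensor(0.5))
min_log_alpha, max_log_alpha = -5.0, 5.0

if min_log_alpha is not None or max_log_alpha is not None:
    log_alpha.data.clamp_(min_log_alpha, max_log_alpha)
with torch.no_grad():
    alpha = log_alpha.exp()  # weight applied to the policy's log-probabilities

print(alpha)  # tensor(1.6487)
```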
torchrl/version.py CHANGED
@@ -1,2 +1,3 @@
-__version__ = '2025.7.20'
-git_version = '163e23f9c930b58990287206958e2e036af5b902'
+__version__ = '2025.7.21'
+git_version = '851a0410a7c72ad656ab3712395450ae482ce3f6'
+pytorch_version = '2.9.0.dev20250721+cpu'
torchrl_nightly-2025.7.20.dist-info/METADATA → torchrl_nightly-2025.7.21.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: torchrl-nightly
-Version: 2025.7.20
+Version: 2025.7.21
 Summary: UNKNOWN
 Home-page: https://github.com/pytorch/rl
 Author: torchrl contributors
torchrl_nightly-2025.7.20.dist-info/RECORD → torchrl_nightly-2025.7.21.dist-info/RECORD CHANGED
@@ -2,10 +2,10 @@ build_tools/__init__.py,sha256=5xRGSM4YMr794wxfVj_SQly2cyHNyhXTdXdWUZJoM2M,183
 build_tools/setup_helpers/__init__.py,sha256=l9zlK7Nm5bT7P_onQx-hZeIGzKKyCFm1PFkFApuRxnM,284
 build_tools/setup_helpers/extension.py,sha256=ihV8jz8kqOvpqzuD006XqF1oNX5ukKGlwIOJRb1Vd-o,6075
 torchrl/__init__.py,sha256=76lKYwYKmAKORhyVt2tURmYAIRTifxxO3gWsskrHAXU,3054
-torchrl/_extension.py,sha256=x6Nqj2brF3VhlEwxmNA2fYbmpxq1HHGrHMnP0YnQwdc,2412
-torchrl/_torchrl.cp310-win_amd64.pyd,sha256=1woR6IR2RqSkShhHE1uh-ar6bWpJb6HwCX8aTXZ8gmc,440832
+torchrl/_extension.py,sha256=K3QwxgCWZIG35Z72VvqshHua1rLQ639wVDREzOhptxk,2799
+torchrl/_torchrl.cp310-win_amd64.pyd,sha256=4oj3psQjrTQDlIVI4dov_VGS71Ep9QEa2H0gphXzSEw,440832
 torchrl/_utils.py,sha256=2N35rdD65U1khMi5gVIz8-nMjlZsoVq0kCiQftVRSxw,42297
-torchrl/version.py,sha256=OZbhjLHH625CpOnLlZQqBfCIDgGNww73f_9eKa8-6x0,85
+torchrl/version.py,sha256=N40bks6QsRHErJ5Ngg0vjMoDLvCjZsSEpyORZwjkSJ4,128
 torchrl/collectors/__init__.py,sha256=LzTyfxmkNGPSa5-3rS5unQK7HfT5ZEdr2NV291rAOlU,832
 torchrl/collectors/collectors.py,sha256=UbXtDMHrXVQ-cd95TBpss2SIbNFKze2HTPYjFz2cPNQ,182146
 torchrl/collectors/utils.py,sha256=aBmBLpphhfplqQjRCyn1jtWWJ-Wtc7TWvM0rOBN8SsE,11579
@@ -191,7 +191,7 @@ torchrl/objectives/iql.py,sha256=VgCjfjKu91WCMCOK78vuc4k4kg1G3hPgOmjiKpozRM0,439
 torchrl/objectives/ppo.py,sha256=n0MXf9_e9u_u7amt41bLApr7G5gX1372HmjGfPa7EyE,77666
 torchrl/objectives/redq.py,sha256=qRN5WyA6YHh7GcKX9n5GinXyETssAXJkiH0HuOx6Uss,29177
 torchrl/objectives/reinforce.py,sha256=EnUjqDSiTla3CuHg9rspQlvecd-VXZrPZxg4rGECZ8w,22861
-torchrl/objectives/sac.py,sha256=gKOgCU399miKgpgu7Bmzs1bkIF8JTm_lybHn8V4wDuk,65478
+torchrl/objectives/sac.py,sha256=wKpfdm2y8Udp100PVp4bC0ljkdPwPQQNlYg5ZqvVO1M,70861
 torchrl/objectives/td3.py,sha256=Rq2q5gXo3AMuHm2OjRZvpfvKsAl1lIK5ALh2_sZM1ZE,23743
 torchrl/objectives/td3_bc.py,sha256=1pjB8mjCT2CLvQzjnqwAfZoc7yhjMB9UQjuJ5wZfTUY,26558
 torchrl/objectives/utils.py,sha256=M397jw5eA7IkXfw0_gOBXbxPTQNnQdC5N7-dj91scsY,25461
@@ -223,8 +223,8 @@ torchrl/trainers/helpers/losses.py,sha256=HwrovwbMOhY-5-hlOz-YHclKnoJhMijVjDNuAT
 torchrl/trainers/helpers/models.py,sha256=VujBq9H92sEzpCtU1iTrJQNlwvyOO-Rho4bzsMonX6s,22465
 torchrl/trainers/helpers/replay_buffer.py,sha256=RaZqXnHimmadiibvDBcLbtIhpPaVMTPhYMOBvX4v3CA,2060
 torchrl/trainers/helpers/trainers.py,sha256=hB1FtHtP-S0PBQ4LF6WPy37caaLpacyaLThj1BNl5Ho,12372
-torchrl_nightly-2025.7.20.dist-info/LICENSE,sha256=PGO-oZsq4EzhE1-WQS2xGiEF3UCVb9YawfQ09cIMV_8,1119
-torchrl_nightly-2025.7.20.dist-info/METADATA,sha256=Y6Bqw_PYQFrHGyyJWSheGXeD7benjQDUlvVqXWRL7ZU,44000
-torchrl_nightly-2025.7.20.dist-info/WHEEL,sha256=NVXpD7b4Gxps0cd2ds5rr5TG8W4ApEwx_i5J99qMZ5E,102
-torchrl_nightly-2025.7.20.dist-info/top_level.txt,sha256=JeTJ1jV7QJwLcUS1nr21aPn_wb-XlAZ9c-z_EH472JA,20
-torchrl_nightly-2025.7.20.dist-info/RECORD,,
+torchrl_nightly-2025.7.21.dist-info/LICENSE,sha256=PGO-oZsq4EzhE1-WQS2xGiEF3UCVb9YawfQ09cIMV_8,1119
+torchrl_nightly-2025.7.21.dist-info/METADATA,sha256=QAXDqpa4pe97bCPhDHquyjTx7K8y6bXLF2fnt2sAr3o,44000
+torchrl_nightly-2025.7.21.dist-info/WHEEL,sha256=NVXpD7b4Gxps0cd2ds5rr5TG8W4ApEwx_i5J99qMZ5E,102
+torchrl_nightly-2025.7.21.dist-info/top_level.txt,sha256=JeTJ1jV7QJwLcUS1nr21aPn_wb-XlAZ9c-z_EH472JA,20
+torchrl_nightly-2025.7.21.dist-info/RECORD,,