torchrl-nightly 2025.4.21__cp311-cp311-win_amd64.whl → 2025.4.22__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Binary file
@@ -765,7 +765,7 @@ class KLRewardTransform(Transform):
765
765
  kl = curr_log_prob - log_prob
766
766
  if reward is None:
767
767
  reward = 0
768
- next_tensordict.set(self.out_keys[0], reward + self.coef * kl)
768
+ next_tensordict.set(self.out_keys[0], reward - self.coef * kl)
769
769
  return next_tensordict
770
770
 
771
771
  def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
@@ -159,23 +159,29 @@ class EGreedyModule(TensorDictModuleBase):
159
159
  action_tensordict = tensordict
160
160
  action_key = self.action_key
161
161
 
162
- out = action_tensordict.get(action_key)
162
+ action = action_tensordict.get(action_key)
163
163
  eps = self.eps
164
- cond = torch.rand(action_tensordict.shape, device=out.device) < eps
164
+ device = eps.device
165
+ action_device = action.device
166
+ if action_device is not None and action_device != device:
167
+ raise RuntimeError(
168
+ f"Expected action and e-greedy module to be on the same device, but got {action.device=} and e-greedy device={device}."
169
+ )
170
+ cond = torch.rand(action_tensordict.shape, device=device) < eps
165
171
  # cond = torch.zeros(action_tensordict.shape, device=out.device, dtype=torch.bool).bernoulli_(eps)
166
- cond = expand_as_right(cond, out)
172
+ cond = expand_as_right(cond, action)
167
173
  spec = self.spec
168
174
  if spec is not None:
169
175
  if isinstance(spec, Composite):
170
176
  spec = spec[self.action_key]
171
- if spec.shape != out.shape:
177
+ if spec.shape != action.shape:
172
178
  # In batched envs if the spec is passed unbatched, the rand() will not
173
179
  # cover all batched dims
174
180
  if (
175
181
  not len(spec.shape)
176
- or out.shape[-len(spec.shape) :] == spec.shape
182
+ or action.shape[-len(spec.shape) :] == spec.shape
177
183
  ):
178
- spec = spec.expand(out.shape)
184
+ spec = spec.expand(action.shape)
179
185
  else:
180
186
  raise ValueError(
181
187
  "Action spec shape does not match the action shape"
@@ -188,12 +194,12 @@ class EGreedyModule(TensorDictModuleBase):
188
194
  )
189
195
  spec.update_mask(action_mask)
190
196
  r = spec.rand()
191
- if r.device != out.device:
192
- r = r.to(out.device)
193
- out = torch.where(cond, r, out)
197
+ if r.device != device:
198
+ r = r.to(device)
199
+ action = torch.where(cond, r, action)
194
200
  else:
195
201
  raise RuntimeError("spec must be provided to the exploration wrapper.")
196
- action_tensordict.set(action_key, out)
202
+ action_tensordict.set(action_key, action)
197
203
  return tensordict
198
204
 
199
205
 
torchrl/version.py CHANGED
@@ -1,2 +1,2 @@
1
- __version__ = '2025.4.21'
2
- git_version = '3a9f244dea03f117c9e97b49a82c23894ebb3e34'
1
+ __version__ = '2025.4.22'
2
+ git_version = '382430db3c457312366fce4ea42330a656337419'
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: torchrl-nightly
3
- Version: 2025.4.21
3
+ Version: 2025.4.22
4
4
  Home-page: https://github.com/pytorch/rl
5
5
  Author: torchrl contributors
6
6
  Author-email: vmoens@fb.com
@@ -3,9 +3,9 @@ build_tools/setup_helpers/__init__.py,sha256=l9zlK7Nm5bT7P_onQx-hZeIGzKKyCFm1PFk
3
3
  build_tools/setup_helpers/extension.py,sha256=ihV8jz8kqOvpqzuD006XqF1oNX5ukKGlwIOJRb1Vd-o,6075
4
4
  torchrl/__init__.py,sha256=jm6WKJgsKzm6zSTABU4LfHfwSxL47cdIHOpvAsSJLa0,2976
5
5
  torchrl/_extension.py,sha256=x6Nqj2brF3VhlEwxmNA2fYbmpxq1HHGrHMnP0YnQwdc,2412
6
- torchrl/_torchrl.cp311-win_amd64.pyd,sha256=Yfjcsj3OdVYYkHD4PMEsLw7XfcQvF0P2dw3DooI_V2g,454656
6
+ torchrl/_torchrl.cp311-win_amd64.pyd,sha256=lrWSY3skCpa25hPVGF9wDlgA1Uo-w2JlFz2WQu5Cagg,454656
7
7
  torchrl/_utils.py,sha256=pCDGUvEmx-z8Ksnu5vfAHrPmqY0W7PzCqEKU7yLVxs0,38436
8
- torchrl/version.py,sha256=UVQ_PqMmhmd_Jpt1wvGXR1HHZKDQDT6WJOt2LmhPeOc,85
8
+ torchrl/version.py,sha256=GQsx-gRVU7gZZCedJSS2cyF3-S8CFcBKQyk6YlNlRIE,85
9
9
  torchrl/collectors/__init__.py,sha256=A76nzBSeZT4QmLidVT8IrZ41utvdMFvSc0kOPYgiiPU,904
10
10
  torchrl/collectors/collectors.py,sha256=_3a9aHW0tXL5CZVOHaIDx5pFTpZrTQgSZ4BKh4Y6Anw,167253
11
11
  torchrl/collectors/llm.py,sha256=7SY9EYrNb8zhYZKIQtOjAOk8VMen-2zeG5wIN-45gVc,16796
@@ -95,7 +95,7 @@ torchrl/envs/model_based/dreamer.py,sha256=ZTT5wCln-YvN2ZL4dCbIEXXwskJEv3M2Mthmd
95
95
  torchrl/envs/transforms/__init__.py,sha256=j1iMhwoTS4-6wTGMeQwjDPLZcmCebLyaFpOSD3oVn8g,3374
96
96
  torchrl/envs/transforms/functional.py,sha256=U8eTymhG_88oqPHdlQhJTsNHyY0fk9ouIMzr1OSgeyM,1486
97
97
  torchrl/envs/transforms/gym_transforms.py,sha256=MS-PpSLLzdUFcvlF-yQkepwyBnmi7m9ivrDb2wSsVjM,10085
98
- torchrl/envs/transforms/llm.py,sha256=6hEotgHsLwELzjJ-jovFGiJQmi49IgWo5fUO8GOBbzA,38857
98
+ torchrl/envs/transforms/llm.py,sha256=ovwDOcHekQDNu98yBiLoPFoO2MGlKTuT03SblbFgDNU,38857
99
99
  torchrl/envs/transforms/r3m.py,sha256=uC1jhCytTvYIjfiPExS9-MgSg0C7aw7r6AI0rba0bxo,13848
100
100
  torchrl/envs/transforms/rb_transforms.py,sha256=66zBM9RL_yYzIb4IKbLw9nDGtl_rHdMcOBn8fY7-1KY,7662
101
101
  torchrl/envs/transforms/rlhf.py,sha256=6PUb9SvF2cvyHUD_DmxQ4UM35VJ87qHit8PFHLWwLAk,653
@@ -131,7 +131,7 @@ torchrl/modules/planners/mppi.py,sha256=LM7wEsFBipednMcVlBls-vMW_aPHVh1cXDeGHmZm
131
131
  torchrl/modules/tensordict_module/__init__.py,sha256=iTz8iCBmxt661GrGJRBfw4tBoTuiSIQLiQxMZCszf24,2383
132
132
  torchrl/modules/tensordict_module/actors.py,sha256=bo-C4zQz2uKMQrCOST3Sgym_QDNMFhI1GHMKZF7obzQ,111160
133
133
  torchrl/modules/tensordict_module/common.py,sha256=k_Ic-HI0KUI1_kEisugJPnY7MgXKyHgmfpUofYiD1Oc,22235
134
- torchrl/modules/tensordict_module/exploration.py,sha256=UAf6uQaBla-ryzBDKIKdDy063sXy7vKhny1d7_FuA8E,29942
134
+ torchrl/modules/tensordict_module/exploration.py,sha256=6OpcmmDaKIcL3ktD1bJsqfKJQQHfLnGAE7VNyuCpvM8,30298
135
135
  torchrl/modules/tensordict_module/probabilistic.py,sha256=ocx_QN7s0u89Gsn2YAzs21dz02TwJ8VwMI7Nn8X5zrM,16703
136
136
  torchrl/modules/tensordict_module/rnn.py,sha256=c2mw6BDnh93bDDx062GX3uesIfctymrJjtLWkBO7Vgs,71635
137
137
  torchrl/modules/tensordict_module/sequence.py,sha256=UMyPgl1XhgbL70orZISjM1spQ1NT1bejNpATAlODcu8,6143
@@ -184,8 +184,8 @@ torchrl/trainers/helpers/losses.py,sha256=7ANhLBKMEGhEI5erxggp8UJg8n_IRXHC1V9t6O
184
184
  torchrl/trainers/helpers/models.py,sha256=JkIr28JiVAtAago0n1bSBcWl_oxr2XJcFo5WxNugi1E,22463
185
185
  torchrl/trainers/helpers/replay_buffer.py,sha256=RaZqXnHimmadiibvDBcLbtIhpPaVMTPhYMOBvX4v3CA,2060
186
186
  torchrl/trainers/helpers/trainers.py,sha256=hB1FtHtP-S0PBQ4LF6WPy37caaLpacyaLThj1BNl5Ho,12372
187
- torchrl_nightly-2025.4.21.dist-info/LICENSE,sha256=PGO-oZsq4EzhE1-WQS2xGiEF3UCVb9YawfQ09cIMV_8,1119
188
- torchrl_nightly-2025.4.21.dist-info/METADATA,sha256=pxcg8JOHegIAMb9_l_rWH-Ypjdq29u6YM7OJ0rAFDtM,40854
189
- torchrl_nightly-2025.4.21.dist-info/WHEEL,sha256=y4n9_669c4ZQLyT56MHjc_JUbnwtaZfMVMycweN557o,102
190
- torchrl_nightly-2025.4.21.dist-info/top_level.txt,sha256=JeTJ1jV7QJwLcUS1nr21aPn_wb-XlAZ9c-z_EH472JA,20
191
- torchrl_nightly-2025.4.21.dist-info/RECORD,,
187
+ torchrl_nightly-2025.4.22.dist-info/LICENSE,sha256=PGO-oZsq4EzhE1-WQS2xGiEF3UCVb9YawfQ09cIMV_8,1119
188
+ torchrl_nightly-2025.4.22.dist-info/METADATA,sha256=2pxCJVBjjNzIwPjpIgIT1eA96xiS229JZdGA3dYq4DI,40854
189
+ torchrl_nightly-2025.4.22.dist-info/WHEEL,sha256=y4n9_669c4ZQLyT56MHjc_JUbnwtaZfMVMycweN557o,102
190
+ torchrl_nightly-2025.4.22.dist-info/top_level.txt,sha256=JeTJ1jV7QJwLcUS1nr21aPn_wb-XlAZ9c-z_EH472JA,20
191
+ torchrl_nightly-2025.4.22.dist-info/RECORD,,