torchrl-nightly 2025.4.20__cp311-cp311-win_amd64.whl → 2025.4.22__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrl/_torchrl.cp311-win_amd64.pyd +0 -0
- torchrl/envs/transforms/llm.py +1 -1
- torchrl/modules/tensordict_module/exploration.py +16 -10
- torchrl/version.py +2 -2
- {torchrl_nightly-2025.4.20.dist-info → torchrl_nightly-2025.4.22.dist-info}/METADATA +1 -1
- {torchrl_nightly-2025.4.20.dist-info → torchrl_nightly-2025.4.22.dist-info}/RECORD +9 -9
- {torchrl_nightly-2025.4.20.dist-info → torchrl_nightly-2025.4.22.dist-info}/LICENSE +0 -0
- {torchrl_nightly-2025.4.20.dist-info → torchrl_nightly-2025.4.22.dist-info}/WHEEL +0 -0
- {torchrl_nightly-2025.4.20.dist-info → torchrl_nightly-2025.4.22.dist-info}/top_level.txt +0 -0
Binary file
|
torchrl/envs/transforms/llm.py
CHANGED
@@ -765,7 +765,7 @@ class KLRewardTransform(Transform):
|
|
765
765
|
kl = curr_log_prob - log_prob
|
766
766
|
if reward is None:
|
767
767
|
reward = 0
|
768
|
-
next_tensordict.set(self.out_keys[0], reward
|
768
|
+
next_tensordict.set(self.out_keys[0], reward - self.coef * kl)
|
769
769
|
return next_tensordict
|
770
770
|
|
771
771
|
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
@@ -159,23 +159,29 @@ class EGreedyModule(TensorDictModuleBase):
|
|
159
159
|
action_tensordict = tensordict
|
160
160
|
action_key = self.action_key
|
161
161
|
|
162
|
-
|
162
|
+
action = action_tensordict.get(action_key)
|
163
163
|
eps = self.eps
|
164
|
-
|
164
|
+
device = eps.device
|
165
|
+
action_device = action.device
|
166
|
+
if action_device is not None and action_device != device:
|
167
|
+
raise RuntimeError(
|
168
|
+
f"Expected action and e-greedy module to be on the same device, but got {action.device=} and e-greedy device={device}."
|
169
|
+
)
|
170
|
+
cond = torch.rand(action_tensordict.shape, device=device) < eps
|
165
171
|
# cond = torch.zeros(action_tensordict.shape, device=out.device, dtype=torch.bool).bernoulli_(eps)
|
166
|
-
cond = expand_as_right(cond,
|
172
|
+
cond = expand_as_right(cond, action)
|
167
173
|
spec = self.spec
|
168
174
|
if spec is not None:
|
169
175
|
if isinstance(spec, Composite):
|
170
176
|
spec = spec[self.action_key]
|
171
|
-
if spec.shape !=
|
177
|
+
if spec.shape != action.shape:
|
172
178
|
# In batched envs if the spec is passed unbatched, the rand() will not
|
173
179
|
# cover all batched dims
|
174
180
|
if (
|
175
181
|
not len(spec.shape)
|
176
|
-
or
|
182
|
+
or action.shape[-len(spec.shape) :] == spec.shape
|
177
183
|
):
|
178
|
-
spec = spec.expand(
|
184
|
+
spec = spec.expand(action.shape)
|
179
185
|
else:
|
180
186
|
raise ValueError(
|
181
187
|
"Action spec shape does not match the action shape"
|
@@ -188,12 +194,12 @@ class EGreedyModule(TensorDictModuleBase):
|
|
188
194
|
)
|
189
195
|
spec.update_mask(action_mask)
|
190
196
|
r = spec.rand()
|
191
|
-
if r.device !=
|
192
|
-
r = r.to(
|
193
|
-
|
197
|
+
if r.device != device:
|
198
|
+
r = r.to(device)
|
199
|
+
action = torch.where(cond, r, action)
|
194
200
|
else:
|
195
201
|
raise RuntimeError("spec must be provided to the exploration wrapper.")
|
196
|
-
action_tensordict.set(action_key,
|
202
|
+
action_tensordict.set(action_key, action)
|
197
203
|
return tensordict
|
198
204
|
|
199
205
|
|
torchrl/version.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1
|
-
__version__ = '2025.4.
|
2
|
-
git_version = '
|
1
|
+
__version__ = '2025.4.22'
|
2
|
+
git_version = '382430db3c457312366fce4ea42330a656337419'
|
@@ -3,9 +3,9 @@ build_tools/setup_helpers/__init__.py,sha256=l9zlK7Nm5bT7P_onQx-hZeIGzKKyCFm1PFk
|
|
3
3
|
build_tools/setup_helpers/extension.py,sha256=ihV8jz8kqOvpqzuD006XqF1oNX5ukKGlwIOJRb1Vd-o,6075
|
4
4
|
torchrl/__init__.py,sha256=jm6WKJgsKzm6zSTABU4LfHfwSxL47cdIHOpvAsSJLa0,2976
|
5
5
|
torchrl/_extension.py,sha256=x6Nqj2brF3VhlEwxmNA2fYbmpxq1HHGrHMnP0YnQwdc,2412
|
6
|
-
torchrl/_torchrl.cp311-win_amd64.pyd,sha256=
|
6
|
+
torchrl/_torchrl.cp311-win_amd64.pyd,sha256=lrWSY3skCpa25hPVGF9wDlgA1Uo-w2JlFz2WQu5Cagg,454656
|
7
7
|
torchrl/_utils.py,sha256=pCDGUvEmx-z8Ksnu5vfAHrPmqY0W7PzCqEKU7yLVxs0,38436
|
8
|
-
torchrl/version.py,sha256=
|
8
|
+
torchrl/version.py,sha256=GQsx-gRVU7gZZCedJSS2cyF3-S8CFcBKQyk6YlNlRIE,85
|
9
9
|
torchrl/collectors/__init__.py,sha256=A76nzBSeZT4QmLidVT8IrZ41utvdMFvSc0kOPYgiiPU,904
|
10
10
|
torchrl/collectors/collectors.py,sha256=_3a9aHW0tXL5CZVOHaIDx5pFTpZrTQgSZ4BKh4Y6Anw,167253
|
11
11
|
torchrl/collectors/llm.py,sha256=7SY9EYrNb8zhYZKIQtOjAOk8VMen-2zeG5wIN-45gVc,16796
|
@@ -95,7 +95,7 @@ torchrl/envs/model_based/dreamer.py,sha256=ZTT5wCln-YvN2ZL4dCbIEXXwskJEv3M2Mthmd
|
|
95
95
|
torchrl/envs/transforms/__init__.py,sha256=j1iMhwoTS4-6wTGMeQwjDPLZcmCebLyaFpOSD3oVn8g,3374
|
96
96
|
torchrl/envs/transforms/functional.py,sha256=U8eTymhG_88oqPHdlQhJTsNHyY0fk9ouIMzr1OSgeyM,1486
|
97
97
|
torchrl/envs/transforms/gym_transforms.py,sha256=MS-PpSLLzdUFcvlF-yQkepwyBnmi7m9ivrDb2wSsVjM,10085
|
98
|
-
torchrl/envs/transforms/llm.py,sha256=
|
98
|
+
torchrl/envs/transforms/llm.py,sha256=ovwDOcHekQDNu98yBiLoPFoO2MGlKTuT03SblbFgDNU,38857
|
99
99
|
torchrl/envs/transforms/r3m.py,sha256=uC1jhCytTvYIjfiPExS9-MgSg0C7aw7r6AI0rba0bxo,13848
|
100
100
|
torchrl/envs/transforms/rb_transforms.py,sha256=66zBM9RL_yYzIb4IKbLw9nDGtl_rHdMcOBn8fY7-1KY,7662
|
101
101
|
torchrl/envs/transforms/rlhf.py,sha256=6PUb9SvF2cvyHUD_DmxQ4UM35VJ87qHit8PFHLWwLAk,653
|
@@ -131,7 +131,7 @@ torchrl/modules/planners/mppi.py,sha256=LM7wEsFBipednMcVlBls-vMW_aPHVh1cXDeGHmZm
|
|
131
131
|
torchrl/modules/tensordict_module/__init__.py,sha256=iTz8iCBmxt661GrGJRBfw4tBoTuiSIQLiQxMZCszf24,2383
|
132
132
|
torchrl/modules/tensordict_module/actors.py,sha256=bo-C4zQz2uKMQrCOST3Sgym_QDNMFhI1GHMKZF7obzQ,111160
|
133
133
|
torchrl/modules/tensordict_module/common.py,sha256=k_Ic-HI0KUI1_kEisugJPnY7MgXKyHgmfpUofYiD1Oc,22235
|
134
|
-
torchrl/modules/tensordict_module/exploration.py,sha256=
|
134
|
+
torchrl/modules/tensordict_module/exploration.py,sha256=6OpcmmDaKIcL3ktD1bJsqfKJQQHfLnGAE7VNyuCpvM8,30298
|
135
135
|
torchrl/modules/tensordict_module/probabilistic.py,sha256=ocx_QN7s0u89Gsn2YAzs21dz02TwJ8VwMI7Nn8X5zrM,16703
|
136
136
|
torchrl/modules/tensordict_module/rnn.py,sha256=c2mw6BDnh93bDDx062GX3uesIfctymrJjtLWkBO7Vgs,71635
|
137
137
|
torchrl/modules/tensordict_module/sequence.py,sha256=UMyPgl1XhgbL70orZISjM1spQ1NT1bejNpATAlODcu8,6143
|
@@ -184,8 +184,8 @@ torchrl/trainers/helpers/losses.py,sha256=7ANhLBKMEGhEI5erxggp8UJg8n_IRXHC1V9t6O
|
|
184
184
|
torchrl/trainers/helpers/models.py,sha256=JkIr28JiVAtAago0n1bSBcWl_oxr2XJcFo5WxNugi1E,22463
|
185
185
|
torchrl/trainers/helpers/replay_buffer.py,sha256=RaZqXnHimmadiibvDBcLbtIhpPaVMTPhYMOBvX4v3CA,2060
|
186
186
|
torchrl/trainers/helpers/trainers.py,sha256=hB1FtHtP-S0PBQ4LF6WPy37caaLpacyaLThj1BNl5Ho,12372
|
187
|
-
torchrl_nightly-2025.4.
|
188
|
-
torchrl_nightly-2025.4.
|
189
|
-
torchrl_nightly-2025.4.
|
190
|
-
torchrl_nightly-2025.4.
|
191
|
-
torchrl_nightly-2025.4.
|
187
|
+
torchrl_nightly-2025.4.22.dist-info/LICENSE,sha256=PGO-oZsq4EzhE1-WQS2xGiEF3UCVb9YawfQ09cIMV_8,1119
|
188
|
+
torchrl_nightly-2025.4.22.dist-info/METADATA,sha256=2pxCJVBjjNzIwPjpIgIT1eA96xiS229JZdGA3dYq4DI,40854
|
189
|
+
torchrl_nightly-2025.4.22.dist-info/WHEEL,sha256=y4n9_669c4ZQLyT56MHjc_JUbnwtaZfMVMycweN557o,102
|
190
|
+
torchrl_nightly-2025.4.22.dist-info/top_level.txt,sha256=JeTJ1jV7QJwLcUS1nr21aPn_wb-XlAZ9c-z_EH472JA,20
|
191
|
+
torchrl_nightly-2025.4.22.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|