torchrl-nightly 2025.8.12__cp311-cp311-manylinux1_x86_64.whl → 2025.8.13__cp311-cp311-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrl/envs/libs/jax_utils.py +5 -3
- {torchrl_nightly-2025.8.12.dist-info → torchrl_nightly-2025.8.13.dist-info}/METADATA +4 -1
- {torchrl_nightly-2025.8.12.dist-info → torchrl_nightly-2025.8.13.dist-info}/RECORD +6 -6
- {torchrl_nightly-2025.8.12.dist-info → torchrl_nightly-2025.8.13.dist-info}/LICENSE +0 -0
- {torchrl_nightly-2025.8.12.dist-info → torchrl_nightly-2025.8.13.dist-info}/WHEEL +0 -0
- {torchrl_nightly-2025.8.12.dist-info → torchrl_nightly-2025.8.13.dist-info}/top_level.txt +0 -0
torchrl/envs/libs/jax_utils.py
CHANGED
@@ -42,13 +42,13 @@ _dtype_conversion = {
|
|
42
42
|
|
43
43
|
|
44
44
|
def _ndarray_to_tensor(value: jnp.ndarray | np.ndarray) -> torch.Tensor: # noqa: F821
|
45
|
-
from jax import
|
45
|
+
from jax import numpy as jnp
|
46
46
|
|
47
47
|
# JAX arrays generated by jax.vmap would have Numpy dtypes.
|
48
48
|
if value.dtype in _dtype_conversion:
|
49
49
|
value = value.view(_dtype_conversion[value.dtype])
|
50
50
|
if isinstance(value, jnp.ndarray):
|
51
|
-
dlpack_tensor =
|
51
|
+
dlpack_tensor = value.__dlpack__()
|
52
52
|
elif isinstance(value, np.ndarray):
|
53
53
|
dlpack_tensor = value.__dlpack__()
|
54
54
|
else:
|
@@ -61,7 +61,9 @@ def _ndarray_to_tensor(value: jnp.ndarray | np.ndarray) -> torch.Tensor: # noqa
|
|
61
61
|
def _tensor_to_ndarray(value: torch.Tensor) -> jnp.ndarray: # noqa: F821
|
62
62
|
from jax import dlpack as jax_dlpack
|
63
63
|
|
64
|
-
|
64
|
+
# Detach the tensor to remove gradients before converting to DLPack
|
65
|
+
value = value.contiguous().detach()
|
66
|
+
return jax_dlpack.from_dlpack(value)
|
65
67
|
|
66
68
|
|
67
69
|
def _get_object_fields(obj) -> dict:
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: torchrl-nightly
|
3
|
-
Version: 2025.8.
|
3
|
+
Version: 2025.8.13
|
4
4
|
Summary: A modular, primitive-first, python-first PyTorch library for Reinforcement Learning
|
5
5
|
Author-email: torchrl contributors <vmoens@fb.com>
|
6
6
|
Maintainer-email: torchrl contributors <vmoens@fb.com>
|
@@ -33,6 +33,9 @@ Requires-Dist: cloudpickle
|
|
33
33
|
Requires-Dist: tensordict-nightly
|
34
34
|
Provides-Extra: atari
|
35
35
|
Requires-Dist: gymnasium[atari]; extra == "atari"
|
36
|
+
Provides-Extra: brax
|
37
|
+
Requires-Dist: jax[cuda12]>=0.7.0; extra == "brax"
|
38
|
+
Requires-Dist: brax; extra == "brax"
|
36
39
|
Provides-Extra: checkpointing
|
37
40
|
Requires-Dist: torchsnapshot; extra == "checkpointing"
|
38
41
|
Provides-Extra: dev
|
@@ -180,7 +180,7 @@ torchrl/envs/libs/gym.py,sha256=b01bW-ostmOdcLBIl03Pcv4aWcJ2u4brYdyHezXZKDI,8045
|
|
180
180
|
torchrl/envs/libs/habitat.py,sha256=PYvVqS8AGwplSY9r1x6RSLPejSKvxEyZgcDXFZvCYqA,5387
|
181
181
|
torchrl/envs/libs/isaac_lab.py,sha256=4QiCB3tVGQ_D-gOngHyhtVrivy76B32fKGvpt-0hwxo,3436
|
182
182
|
torchrl/envs/libs/isaacgym.py,sha256=GqXwbs9Iq7A5PHr1Sa-7UgQF42DLW7BFcu8lWYhBVJ8,7133
|
183
|
-
torchrl/envs/libs/jax_utils.py,sha256=
|
183
|
+
torchrl/envs/libs/jax_utils.py,sha256=MAm0dvsG6Uk73Obh5ablkGIgkdUgGZnOqqniST0-lZk,6081
|
184
184
|
torchrl/envs/libs/jumanji.py,sha256=XS2dw3ekt8621VmMbASlUTyqHhLpzVgmuMyLgj07lks,40271
|
185
185
|
torchrl/envs/libs/meltingpot.py,sha256=nd8P3JW_1D6fm_-eZbgmZle20fBoN6M_08SH5ydNS3Y,26096
|
186
186
|
torchrl/envs/libs/openml.py,sha256=cuTWhedmXiDYjz8O-wsAuCKPfxaRGYaOffopGzidUw8,5713
|
@@ -322,8 +322,8 @@ torchrl/trainers/helpers/losses.py,sha256=sHlJqjh02t8cKN73X35Azd_OoWGurohLuviB8Y
|
|
322
322
|
torchrl/trainers/helpers/models.py,sha256=ihTERG2c96E8cS3Tnul6a_ys6iDEEJmHh05p9blQTW8,21807
|
323
323
|
torchrl/trainers/helpers/replay_buffer.py,sha256=ZUZHOa0TILyeWJ3iahzTJ6UvMl_0FdxuZfJEja94Bn8,2001
|
324
324
|
torchrl/trainers/helpers/trainers.py,sha256=j6B5XA7_FFHMQeOIQwjNcO0CGE_4mZKUC9_jH_iqqh4,12071
|
325
|
-
torchrl_nightly-2025.8.
|
326
|
-
torchrl_nightly-2025.8.
|
327
|
-
torchrl_nightly-2025.8.
|
328
|
-
torchrl_nightly-2025.8.
|
329
|
-
torchrl_nightly-2025.8.
|
325
|
+
torchrl_nightly-2025.8.13.dist-info/LICENSE,sha256=xdjS4_xk-IwnLuIFCvTYTl9Y8aXRejqpmke3dGam_nI,1098
|
326
|
+
torchrl_nightly-2025.8.13.dist-info/METADATA,sha256=0zoKgbG9IK7SgQb_1SWxF3PQUs5NV_GuxafkJwjUrF8,41499
|
327
|
+
torchrl_nightly-2025.8.13.dist-info/WHEEL,sha256=e3VbkNOSuK0uEGKey5iz4S8FvWrQAw-zWtlYJiG5LyY,105
|
328
|
+
torchrl_nightly-2025.8.13.dist-info/top_level.txt,sha256=-5FcSdmJ9DwdHF8aOIaofsPbz4Gm8G1eo7r7Sc2CHgE,59
|
329
|
+
torchrl_nightly-2025.8.13.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|