torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import functools
|
|
8
|
+
import importlib.util
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from torchrl._utils import _make_ordinal_device
|
|
12
|
+
from torchrl.data.utils import DEVICE_TYPING
|
|
13
|
+
from torchrl.envs.common import EnvBase
|
|
14
|
+
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
|
|
15
|
+
from torchrl.envs.utils import _classproperty
|
|
16
|
+
|
|
17
|
+
_has_habitat = importlib.util.find_spec("habitat") is not None
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _wrap_import_error(fun):
|
|
21
|
+
@functools.wraps(fun)
|
|
22
|
+
def new_fun(*args, **kwargs):
|
|
23
|
+
if not _has_habitat:
|
|
24
|
+
raise ImportError(
|
|
25
|
+
"Habitat could not be loaded. Consider installing "
|
|
26
|
+
"it or solving the import bugs (see attached error message). "
|
|
27
|
+
"Refer to TorchRL's knowledge base in the documentation to "
|
|
28
|
+
"debug habitat installation."
|
|
29
|
+
)
|
|
30
|
+
return fun(*args, **kwargs)
|
|
31
|
+
|
|
32
|
+
return new_fun
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@_wrap_import_error
def _get_available_envs():
    """Lazily yield registered gym env names whose name starts with ``"Habitat"``.

    This is a generator: ``GymEnv.available_envs`` is only read once iteration
    begins.
    """
    yield from (
        name for name in GymEnv.available_envs if name.startswith("Habitat")
    )
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class HabitatEnv(GymEnv):
    """A wrapper for habitat envs.

    This class currently serves as a placeholder, kept for compatibility.
    It behaves like the GymEnv wrapper, except that the GPU used by the habitat
    simulator is selected through the ``device`` keyword argument.

    Doc: https://aihabitat.org/docs/

    GitHub: https://github.com/facebookresearch/habitat-lab

    URL: https://aihabitat.org/habitat3/

    Paper: https://ai.meta.com/static-resource/habitat3

    Args:
        env_name (str): The environment to execute.
        categorical_action_encoding (bool, optional): if ``True``, categorical
            specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
            otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
            Defaults to ``False``.

    Keyword Args:
        from_pixels (bool, optional): if ``True``, an attempt to return the pixel
            observations from the env will be performed. By default, these observations
            will be written under the ``"pixels"`` entry.
            The method being used varies
            depending on the gym version and may involve a ``wrappers.pixel_observation.PixelObservationWrapper``.
            Defaults to ``False``.
        pixels_only (bool, optional): if ``True``, only the pixel observations will
            be returned (by default under the ``"pixels"`` entry in the output tensordict).
            If ``False``, observations (eg, states) and pixels will be returned
            whenever ``from_pixels=True``. Defaults to ``True``.
        frame_skip (int, optional): if provided, indicates for how many steps the
            same action is to be repeated. The observation returned will be the
            last observation of the sequence, whereas the reward will be the sum
            of rewards across steps.
        device (torch.device, optional): if provided, the device on which the simulation
            will occur. Its index is forwarded to habitat as
            ``habitat.simulator.habitat_sim_v0.gpu_device_id``; a CUDA device given
            without an index resolves to the current CUDA device.
            Defaults to ``torch.device("cuda:0")``.
        batch_size (torch.Size, optional): the batch size of the environment.
            Should match the leading dimensions of all observations, done states,
            rewards, actions and infos.
            Defaults to ``torch.Size([])``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``False``.

    Attributes:
        available_envs (List[str]): a list of environments to build.

    Examples:
        >>> from torchrl.envs import HabitatEnv
        >>> env = HabitatEnv("HabitatRenderPick-v0", from_pixels=True)
        >>> env.rollout(3)

    """

    @_wrap_import_error
    @set_gym_backend("gym")
    def __init__(self, env_name, **kwargs):
        import habitat  # noqa
        import habitat.gym  # noqa

        # Resolve an index-less device (e.g. "cuda") to a concrete ordinal, as
        # :meth:`to` does; otherwise the override below would read
        # "gpu_device_id=None".
        device_num = _make_ordinal_device(
            torch.device(kwargs.pop("device", 0))
        ).index
        # NOTE(review): ``device`` is consumed here and not forwarded to the
        # parent constructor, and any caller-supplied ``override_options`` are
        # replaced by the list below.
        kwargs["override_options"] = [
            f"habitat.simulator.habitat_sim_v0.gpu_device_id={device_num}",
            "habitat.simulator.concur_render=False",
        ]
        super().__init__(env_name=env_name, **kwargs)

    @_classproperty
    def available_envs(cls):
        """Names of the registered ``Habitat*`` gym envs; empty if habitat is missing."""
        if not _has_habitat:
            return []
        return list(_get_available_envs())

    def _build_gym_env(self, env, pixels_only):
        # When pixels are requested, the env is reset once before the parent
        # builds it (presumably so habitat can render a first frame — TODO confirm).
        if self.from_pixels:
            env.reset()
        return super()._build_gym_env(env, pixels_only)

    def to(self, device: DEVICE_TYPING) -> EnvBase:
        """Move the env to another CUDA device, rebuilding the habitat simulator on it.

        Args:
            device: the target device. It must be a CUDA device; an index-less
                CUDA device resolves to the current CUDA device.

        Returns:
            the result of the parent's ``to`` call.

        Raises:
            ValueError: if ``device`` is not a CUDA device.
        """
        device = _make_ordinal_device(torch.device(device))
        if device.type != "cuda":
            raise ValueError("The device must be of type cuda for Habitat.")
        gpu_option = "habitat.simulator.habitat_sim_v0.gpu_device_id"
        # Keep the constructor's override options, retargeting only the GPU id.
        override_options = [
            f"{gpu_option}={device.index}" if arg.startswith(gpu_option) else arg
            for arg in self._constructor_kwargs.get("override_options", [])
        ]

        # The running simulator is closed and rebuilt with the new options.
        self._env.close()
        del self._env
        self.rebuild_with_kwargs(override_options=override_options)
        return super().to(device)
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from torchrl.envs.libs.gym import GymWrapper
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class IsaacLabWrapper(GymWrapper):
    """A wrapper for IsaacLab environments.

    Args:
        env (isaaclab.envs.ManagerBasedRLEnv or equivalent): the environment instance to wrap.

    Keyword Args:
        categorical_action_encoding (bool, optional): if ``True``, categorical
            specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
            otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
            Defaults to ``False``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``True``.
        convert_actions_to_numpy (bool, optional): forwarded to
            :class:`torchrl.envs.GymWrapper`. Defaults to ``False``.
        device (torch.device, optional): the device of the environment.
            When ``None`` (the default), ``torch.device("cuda:0")`` is used.

    For other arguments, see the :class:`torchrl.envs.GymWrapper` documentation.

    Refer to `the Isaac Lab doc for installation instructions <https://isaac-sim.github.io/IsaacLab/main/source/setup/installation/pip_installation.html>`_.

    Example:
        >>> # This code block ensures that the Isaac app is started in headless mode
        >>> from isaaclab.app import AppLauncher
        >>> import argparse

        >>> parser = argparse.ArgumentParser(description="Train an RL agent with TorchRL.")
        >>> AppLauncher.add_app_launcher_args(parser)
        >>> args_cli, hydra_args = parser.parse_known_args(["--headless"])
        >>> app_launcher = AppLauncher(args_cli)

        >>> # Imports and env
        >>> import gymnasium as gym
        >>> import isaaclab_tasks  # noqa: F401
        >>> from isaaclab_tasks.manager_based.classic.ant.ant_env_cfg import AntEnvCfg
        >>> from torchrl.envs.libs.isaac_lab import IsaacLabWrapper

        >>> env = gym.make("Isaac-Ant-v0", cfg=AntEnvCfg())
        >>> env = IsaacLabWrapper(env)

    """

    def __init__(
        self,
        env: isaaclab.envs.ManagerBasedRLEnv,  # noqa: F821
        *,
        categorical_action_encoding: bool = False,
        allow_done_after_reset: bool = True,
        convert_actions_to_numpy: bool = False,
        device: torch.device | None = None,
        **kwargs,
    ):
        if device is None:
            device = torch.device("cuda:0")
        super().__init__(
            env,
            device=device,
            categorical_action_encoding=categorical_action_encoding,
            allow_done_after_reset=allow_done_after_reset,
            convert_actions_to_numpy=convert_actions_to_numpy,
            **kwargs,
        )

    def seed(self, seed: int | None):
        """Forward ``seed`` to :meth:`_set_seed`."""
        self._set_seed(seed)

    def _output_transform(self, step_outputs_tuple):  # noqa: F811
        # IsaacLab will modify the `terminated` and `truncated` tensors
        # in-place. We clone them here to make sure data doesn't inadvertently get modified.
        # The variable naming follows torchrl's convention here.
        observations, reward, terminated, truncated, info = step_outputs_tuple
        terminated = terminated.clone()
        truncated = truncated.clone()
        # ``|`` allocates a fresh tensor, so ``done`` needs no extra clone.
        done = terminated | truncated
        reward = reward.unsqueeze(-1)  # to get to (num_envs, 1)
        return (
            observations,
            reward,
            terminated,
            truncated,
            done,
            info,
        )
|
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import importlib.util
|
|
8
|
+
import itertools
|
|
9
|
+
import warnings
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import torch
|
|
14
|
+
from tensordict import TensorDictBase
|
|
15
|
+
from torchrl.data.tensor_specs import Composite
|
|
16
|
+
from torchrl.envs.libs.gym import GymWrapper
|
|
17
|
+
from torchrl.envs.utils import _classproperty, make_composite_from_td
|
|
18
|
+
|
|
19
|
+
_has_isaac = importlib.util.find_spec("isaacgym") is not None
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class IsaacGymWrapper(GymWrapper):
|
|
23
|
+
"""Wrapper for IsaacGymEnvs environments.
|
|
24
|
+
|
|
25
|
+
The original library can be found `here <https://github.com/NVIDIA-Omniverse/IsaacGymEnvs>`_
|
|
26
|
+
and is based on IsaacGym which can be downloaded `through NVIDIA's webpage <https://developer.nvidia.com/isaac-gym>_`.
|
|
27
|
+
|
|
28
|
+
.. note:: IsaacGym environments cannot be executed consecutively, ie. instantiating one
|
|
29
|
+
environment after another (even if it has been cleared) will cause
|
|
30
|
+
CUDA memory issues. We recommend creating one environment per process only.
|
|
31
|
+
If you need more than one environment, the best way to achieve that is
|
|
32
|
+
to spawn them across processes.
|
|
33
|
+
|
|
34
|
+
.. note:: IsaacGym works on CUDA devices by essence. Make sure your machine
|
|
35
|
+
has GPUs available and the required setup for IsaacGym (eg, Ubuntu 20.04).
|
|
36
|
+
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
@property
def lib(self):
    """The ``isaacgym`` module, imported on access."""
    import isaacgym as isaacgym_module

    return isaacgym_module
|
|
44
|
+
|
|
45
|
+
def __init__(
    self, env: isaacgymenvs.tasks.base.vec_task.Env, **kwargs  # noqa: F821
):
    """Wrap an already-instantiated IsaacGymEnvs environment.

    Args:
        env: the IsaacGymEnvs environment. Its ``device`` attribute is used as
            the TorchRL device unless ``device`` is passed explicitly.
        **kwargs: forwarded to :class:`~torchrl.envs.GymWrapper`. ``batch_size``
            is always set to ``torch.Size([])`` by this wrapper.
    """
    warnings.warn(
        "IsaacGym environment support is an experimental feature that may change in the future."
    )
    # Pass the device by keyword: positionally it would bind to GymWrapper's
    # second parameter (``categorical_action_encoding``) instead of ``device``.
    kwargs.setdefault("device", torch.device(env.device))
    super().__init__(env, batch_size=torch.Size([]), **kwargs)
    if not hasattr(self, "task"):
        # by convention in IsaacGymEnvs; fall back to the class name when the
        # env instance carries no ``__name__``.
        self.task = getattr(env, "__name__", type(env).__name__)
|
|
57
|
+
|
|
58
|
+
def _make_specs(self, env: gym.Env) -> None:  # noqa: F821
    """Build the parent wrapper's specs, then adapt them to IsaacGym.

    Done specs have their last dimension squeezed, ``"observation"`` is
    renamed to ``"obs"``, and a short rollout is used to infer specs for the
    remaining output entries, which are merged into the observation spec.
    """
    super()._make_specs(env, batch_size=self.batch_size)
    # Squeeze the last dimension of every leaf done spec.
    self.full_done_spec = Composite(
        {
            key: spec.squeeze(-1)
            for key, spec in self.full_done_spec.items(True, True)
        },
        shape=self.batch_size,
    )

    # Rename the "observation" entry to "obs".
    self.observation_spec["obs"] = self.observation_spec["observation"]
    del self.observation_spec["observation"]

    # Run a 3-step rollout and keep the "next" data at index 0 of the last
    # dimension (presumably the time dimension — TODO confirm).
    data = self.rollout(3).get("next")[..., 0]
    # Remove reward and done entries so only the other outputs get specs.
    del data[self.reward_key]
    for done_key in self.done_keys:
        try:
            del data[done_key]
        except KeyError:
            # A done key may be absent from the rollout data; skip it.
            continue
    specs = make_composite_from_td(data)

    # Unlock the observation spec to merge the inferred specs, then re-lock it.
    obs_spec = self.observation_spec
    obs_spec.unlock_(recurse=True)
    obs_spec.update(specs)
    obs_spec.lock_(recurse=True)
|
|
84
|
+
|
|
85
|
+
def _output_transform(self, output):
|
|
86
|
+
obs, reward, done, info = output
|
|
87
|
+
if self.from_pixels:
|
|
88
|
+
obs["pixels"] = self._env.render(mode="rgb_array")
|
|
89
|
+
return obs, reward, done ^ done, done, done, info
|
|
90
|
+
|
|
91
|
+
def _reset_output_transform(self, reset_data):
|
|
92
|
+
reset_data.pop("reward", None)
|
|
93
|
+
if self.from_pixels:
|
|
94
|
+
reset_data["pixels"] = self._env.render(mode="rgb_array")
|
|
95
|
+
return reset_data, {}
|
|
96
|
+
|
|
97
|
+
@classmethod
def _make_envs(cls, *, task, num_envs, device, seed=None, headless=False, **kwargs):
    """Create IsaacGymEnvs environments through ``isaacgymenvs.make``.

    Args:
        task: task name forwarded to ``isaacgymenvs.make``.
        num_envs: number of environments forwarded to ``isaacgymenvs.make``.
        device: converted with ``str()`` and used as both ``sim_device`` and
            ``rl_device``.
        seed: forwarded to ``isaacgymenvs.make``.
        headless: forwarded to ``isaacgymenvs.make``.
        **kwargs: forwarded to ``isaacgymenvs.make`` after removing any
            ``from_pixels`` entry.

    Returns:
        whatever ``isaacgymenvs.make`` returns.
    """
    import isaacgym  # noqa
    import isaacgymenvs  # noqa

    # ``from_pixels`` is a TorchRL-side option and is not passed to IsaacGymEnvs.
    kwargs.pop("from_pixels", None)
    device_name = str(device)
    return isaacgymenvs.make(
        seed=seed,
        task=task,
        num_envs=num_envs,
        sim_device=device_name,
        rl_device=device_name,
        headless=headless,
        **kwargs,
    )
|
|
113
|
+
|
|
114
|
+
def _set_seed(self, seed: int | None) -> None:
|
|
115
|
+
# as of #665c32170d84b4be66722eea405a1e08b6e7f761 the seed points nowhere in gym.make for IsaacGymEnvs
|
|
116
|
+
...
|
|
117
|
+
|
|
118
|
+
def read_action(self, action):
    """Pass the action through untouched.

    The contained environment consumes TorchRL actions directly, so no
    conversion is needed.

    Args:
        action (Tensor or TensorDict): an action to be taken in the environment

    Returns:
        the very same ``action`` object.

    """
    return action
|
|
128
|
+
|
|
129
|
+
def read_done(
    self,
    terminated: torch.Tensor | None = None,
    truncated: torch.Tensor | None = None,
    done: torch.Tensor | None = None,
) -> tuple[
    torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | bool
]:
    """Cast the done signals to boolean and compute an "any done" flag.

    Args:
        terminated: terminated signal, or ``None``.
        truncated: truncated signal, or ``None``.
        done: done signal, or ``None``.

    Returns:
        A 4-tuple ``(terminated, truncated, done, any_done)``. Each signal that
        is not ``None`` is converted with ``.bool()``. ``any_done`` is
        ``done.any()``, or ``False`` when ``done`` is ``None``.
    """
    if terminated is not None:
        terminated = terminated.bool()
    if truncated is not None:
        truncated = truncated.bool()
    if done is None:
        # ``done.any()`` is undefined without a done signal: report "not done".
        return terminated, truncated, done, False
    done = done.bool()
    return terminated, truncated, done, done.any()
|
|
142
|
+
|
|
143
|
+
def read_reward(self, total_reward):
    """Return the reward exactly as produced by the environment (no reshaping or casting)."""
    return total_reward
|
|
145
|
+
|
|
146
|
+
def read_obs(
    self, observations: dict[str, Any] | torch.Tensor | np.ndarray
) -> dict[str, Any]:
    """Map a raw observation to entries for the output TensorDict.

    - A ``dict`` has its ``"state"`` entry renamed (in place) to ``"observation"``
      unless an ``"observation"`` entry already exists, and is then returned.
    - A ``TensorDictBase`` is returned unchanged.
    - Anything else is wrapped as ``{key: observations}``, where ``key`` is the
      first key yielded by ``self.observation_spec.keys(True, True)``.

    Args:
        observations (observation under a format dictated by the inner env): observation to be read.

    """
    if isinstance(observations, dict):
        if "state" in observations and "observation" not in observations:
            # "observation" is the conventional torchrl name for a single
            # observation; keeping "state" would make the state vector's key
            # differ depending on whether pixels are requested.
            observations["observation"] = observations.pop("state")
        return observations
    if isinstance(observations, TensorDictBase):
        return observations
    (first_key,) = itertools.islice(self.observation_spec.keys(True, True), 1)
    return {first_key: observations}
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
class IsaacGymEnv(IsaacGymWrapper):
    """A TorchRL Env interface for IsaacGym environments.

    See :class:`~.IsaacGymWrapper` for more information.

    Args:
        task (str, optional): the IsaacGymEnvs task to build. Exactly one of
            ``task`` or ``env`` must be provided.
        env (str, optional): alias for ``task``.

    Keyword Args:
        num_envs (int): number of environments, forwarded to ``isaacgymenvs.make``.
        device: device used for both simulation and RL tensors.
        from_pixels (bool, optional): whether to add rendered frames to the
            observations. Defaults to ``False``.
        **kwargs: forwarded to the environment builder and to the parent wrapper.

    Examples:
        >>> env = IsaacGymEnv(task="Ant", num_envs=2000, device="cuda:0")
        >>> rollout = env.rollout(3)
        >>> assert env.batch_size == (2000,)

    """

    @_classproperty
    def available_envs(cls):
        if not _has_isaac:
            return []

        import isaacgymenvs  # noqa

        return list(isaacgymenvs.tasks.isaacgym_task_map.keys())

    def __init__(self, task=None, *, env=None, num_envs, device, **kwargs):
        # ``env`` is an alias for ``task``: exactly one of them must be given.
        if env is not None and task is not None:
            raise RuntimeError("Cannot provide both `task` and `env` arguments.")
        elif env is not None:
            task = env
        elif task is None:
            raise RuntimeError(
                "One of the `task` or `env` arguments must be provided."
            )
        # ``from_pixels`` is handled by the TorchRL wrapper, not by IsaacGymEnvs.
        from_pixels = kwargs.pop("from_pixels", False)
        envs = self._make_envs(
            task=task,
            num_envs=num_envs,
            device=device,
            virtual_screen_capture=False,
            **kwargs,
        )
        # Set before the parent __init__, which only derives ``task`` from the
        # env when the attribute is missing.
        self.task = task
        super().__init__(envs, from_pixels=from_pixels, **kwargs)
|
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import dataclasses
|
|
8
|
+
import importlib.util
|
|
9
|
+
|
|
10
|
+
# import jax
|
|
11
|
+
import numpy as np
|
|
12
|
+
import torch
|
|
13
|
+
|
|
14
|
+
# from jax import dlpack as jax_dlpack, numpy as jnp
|
|
15
|
+
from tensordict import make_tensordict, TensorDictBase
|
|
16
|
+
from torch.utils import dlpack as torch_dlpack
|
|
17
|
+
from torchrl.data.tensor_specs import Composite, TensorSpec, Unbounded
|
|
18
|
+
from torchrl.data.utils import numpy_to_torch_dtype_dict
|
|
19
|
+
|
|
20
|
+
_has_jax = importlib.util.find_spec("jax") is not None
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _tree_reshape(x, batch_size: torch.Size):
|
|
24
|
+
import jax
|
|
25
|
+
|
|
26
|
+
shape, n = batch_size, 1
|
|
27
|
+
return jax.tree_util.tree_map(lambda x: x.reshape(shape + x.shape[n:]), x)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _tree_flatten(x, batch_size: torch.Size):
|
|
31
|
+
import jax
|
|
32
|
+
|
|
33
|
+
shape, n = (batch_size.numel(),), len(batch_size)
|
|
34
|
+
return jax.tree_util.tree_map(lambda x: x.reshape(shape + x.shape[n:]), x)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
_dtype_conversion = {
|
|
38
|
+
np.dtype("uint16"): np.int16,
|
|
39
|
+
np.dtype("uint32"): np.int32,
|
|
40
|
+
np.dtype("uint64"): np.int64,
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _ndarray_to_tensor(value: jnp.ndarray | np.ndarray) -> torch.Tensor:  # noqa: F821
    """Convert a JAX or NumPy array to a torch tensor through DLPack.

    Arrays whose dtype is a key of ``_dtype_conversion`` (uint16/32/64) are first
    bit-reinterpreted with ``.view`` as the signed dtype of the same width.

    Args:
        value: the JAX or NumPy array to convert.

    Returns:
        A torch tensor cast to ``numpy_to_torch_dtype_dict[value.dtype]``, where
        ``value.dtype`` is the (possibly reinterpreted) array dtype.

    Raises:
        NotImplementedError: if ``value`` is neither a JAX nor a NumPy array.
    """
    from jax import numpy as jnp

    # Check the type first so unsupported inputs get NotImplementedError instead
    # of an AttributeError on ``value.dtype``.
    if not isinstance(value, (jnp.ndarray, np.ndarray)):
        raise NotImplementedError(f"unsupported data type {type(value)}")
    # JAX arrays generated by jax.vmap would have Numpy dtypes.
    if value.dtype in _dtype_conversion:
        value = value.view(_dtype_conversion[value.dtype])
    out = torch_dlpack.from_dlpack(value.__dlpack__())
    # dtype can be messed up by dlpack
    return out.to(numpy_to_torch_dtype_dict[value.dtype])
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _tensor_to_ndarray(value: torch.Tensor) -> jnp.ndarray: # noqa: F821
|
|
62
|
+
from jax import dlpack as jax_dlpack
|
|
63
|
+
|
|
64
|
+
# Detach the tensor to remove gradients before converting to DLPack
|
|
65
|
+
value = value.contiguous().detach()
|
|
66
|
+
return jax_dlpack.from_dlpack(value)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _get_object_fields(obj) -> dict:
|
|
70
|
+
"""Converts an object (named tuple or dataclass or dict) to a dict."""
|
|
71
|
+
if isinstance(obj, tuple) and hasattr(obj, "_fields"): # named tuple
|
|
72
|
+
return dict(zip(obj._fields, obj))
|
|
73
|
+
elif dataclasses.is_dataclass(obj):
|
|
74
|
+
return {
|
|
75
|
+
field.name: getattr(obj, field.name) for field in dataclasses.fields(obj)
|
|
76
|
+
}
|
|
77
|
+
elif isinstance(obj, dict):
|
|
78
|
+
return obj
|
|
79
|
+
elif obj is None:
|
|
80
|
+
return {}
|
|
81
|
+
else:
|
|
82
|
+
raise NotImplementedError(f"unsupported data type {type(obj)}")
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _object_to_tensordict(obj, device, batch_size) -> TensorDictBase | None:
    """Recursively convert a namedtuple, dataclass or dict to a TensorDict.

    Python/NumPy scalars become 1-element tensors, JAX/NumPy arrays are converted
    with ``_ndarray_to_tensor``, and any other field value is converted
    recursively as a nested container.

    Args:
        obj: a namedtuple, dataclass, dict or ``None``.
        device: device the resulting tensors and TensorDict are placed on.
        batch_size: batch size of the resulting TensorDict (also used for nested ones).

    Returns:
        A TensorDict, or ``None`` when no field produced an entry (empty
        tensordicts are discarded).
    """
    from jax import numpy as jnp

    t = {}
    for name, value in _get_object_fields(obj).items():
        # ``np.bool_`` is not a subclass of ``np.number``: list it explicitly so
        # numpy boolean scalars are converted instead of raising as "nested".
        if isinstance(value, (np.number, np.bool_, int, float)):
            t[name] = _ndarray_to_tensor(np.asarray([value])).to(device)
        elif isinstance(value, (jnp.ndarray, np.ndarray)):
            t[name] = _ndarray_to_tensor(value).to(device)
        else:
            # Nested container; ``None`` or empty containers yield no entry.
            nested = _object_to_tensordict(value, device, batch_size)
            if nested is not None:
                t[name] = nested
    if t:
        return make_tensordict(t, device=device, batch_size=batch_size)
    # discard empty tensordicts
    return None
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _tensordict_to_object(tensordict: TensorDictBase, object_example, batch_size=None):
    """Converts a TensorDict to a namedtuple or a dataclass.

    The result has the type of ``object_example`` and is built as
    ``type(object_example)(**fields)``. The fields of ``object_example``
    (via ``_get_object_fields``) drive the conversion: each one is looked up in
    ``tensordict`` by name, converted to a JAX array through DLPack, and
    viewed/reshaped to match the example leaf's dtype and shape.

    Args:
        tensordict: source data; a plain ``dict`` is also accepted (the recursive
            call below passes ``{}``), since only ``.get(name, None)`` is used.
        object_example: template object (namedtuple, dataclass or dict) whose
            fields, dtypes and shapes are mirrored.
        batch_size: leading dims prepended to each example leaf shape in the
            final reshape; defaults to no leading dims.

    Returns:
        An instance of ``type(object_example)``. Fields missing from
        ``tensordict`` are set to ``None``, or to an empty conversion when the
        example field is a dict.
    """
    from jax import dlpack as jax_dlpack, numpy as jnp

    if batch_size is None:
        batch_size = []
    t = {}
    _fields = _get_object_fields(object_example)
    for name, example in _fields.items():
        value = tensordict.get(name, None)
        if isinstance(value, TensorDictBase):
            # Nested container: recurse with the matching example sub-object.
            t[name] = _tensordict_to_object(value, example, batch_size=batch_size)
        elif value is None:
            if isinstance(example, dict):
                t[name] = _tensordict_to_object({}, example, batch_size=batch_size)
            else:
                t[name] = None
        else:
            # Bool tensors are exported as uint8; they are viewed back to
            # ``example.dtype`` below.
            if value.dtype is torch.bool:
                value = value.to(torch.uint8)
            shape = value.shape
            # We need to flatten to fix https://github.com/pytorch/rl/issues/2184
            value = value.contiguous()
            value = value.detach()
            if value.ndim > 1:
                value = value.flatten().clone()
            else:
                # Need this because otherwise an exception is raised
                # ValueError: INTERNAL: Address of buffer 1 must be a multiple of 10, but was 0x7efccec00824
                value = value.clone()
            value = jax_dlpack.from_dlpack(value)
            if shape.numel() == 1 and not value.shape:
                # 0-d source tensor.
                # NOTE(review): here ``shape`` is also empty, so this loop appears
                # never to iterate; no batch_size reshape is applied on this path.
                while value.shape != shape:
                    value = jnp.expand_dims(value, 0)
                if value.dtype != example.dtype:
                    t[name] = value.view(example.dtype)
                else:
                    t[name] = value
            else:
                # Restore the original (pre-flatten) shape, then reinterpret the
                # dtype and reshape to (*batch_size, *example.shape).
                # Assumes example leaves carry no batch dims — TODO confirm.
                value = jnp.reshape(value, tuple(shape))
                t[name] = value.view(example.dtype).reshape(
                    (*batch_size, *example.shape)
                )
    return type(object_example)(**t)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def _extract_spec(data: torch.Tensor | TensorDictBase, key=None) -> TensorSpec:
    """Build an unbounded spec mirroring the shape, dtype and device of ``data``.

    Args:
        data: a tensor, or a TensorDict whose entries are converted recursively.
        key: name of ``data`` in its parent container. Entries named ``"reward"``
            or ``"done"`` get a trailing singleton dimension appended to their shape.

    Returns:
        An ``Unbounded`` spec for a tensor (regardless of dtype), or a
        ``Composite`` of per-entry specs for a TensorDict.

    Raises:
        TypeError: if ``data`` is neither a tensor nor a TensorDict.
    """
    if isinstance(data, torch.Tensor):
        shape = data.shape
        if key in ("reward", "done"):
            shape = (*shape, 1)
        # Floating and non-floating dtypes map to the same unbounded spec.
        return Unbounded(shape=shape, dtype=data.dtype, device=data.device)
    if isinstance(data, TensorDictBase):
        return Composite(
            {
                sub_key: _extract_spec(value, key=sub_key)
                for sub_key, value in data.items()
            }
        )
    raise TypeError(f"Unsupported data type {type(data)}")
|