torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,429 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import importlib
|
|
8
|
+
import os
|
|
9
|
+
import warnings
|
|
10
|
+
from copy import copy
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import torch
|
|
15
|
+
from tensordict import TensorDict
|
|
16
|
+
|
|
17
|
+
from torchrl.data.tensor_specs import Unbounded
|
|
18
|
+
from torchrl.envs.common import _maybe_unlock
|
|
19
|
+
from torchrl.envs.libs.gym import (
|
|
20
|
+
_gym_to_torchrl_spec_transform,
|
|
21
|
+
_GymAsyncMeta,
|
|
22
|
+
gym_backend,
|
|
23
|
+
GymEnv,
|
|
24
|
+
)
|
|
25
|
+
from torchrl.envs.utils import _classproperty, make_composite_from_td
|
|
26
|
+
|
|
27
|
+
# A gym-compatible backend is considered available as soon as either ``gym``
# or ``gymnasium`` can be located on the import path.
_has_gym = any(
    importlib.util.find_spec(_backend_name) is not None
    for _backend_name in ("gym", "gymnasium")
)
# RoboHive is only usable when the package itself and a gym backend are found.
_has_robohive = importlib.util.find_spec("robohive") is not None and _has_gym

# Provide a default value for the ``sim_backend`` environment variable while
# leaving any value already set by the user untouched.
if _has_robohive:
    os.environ.setdefault("sim_backend", "MUJOCO")
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class set_directory:
    """Sets the cwd within the context.

    The instance can be used as a context manager (``with set_directory(p):``)
    or as a decorator (``@set_directory(p)``). On exit, the working directory
    is set back to the directory that was current when the instance was
    created.

    Args:
        path (Path): The path to the cwd
    """

    def __init__(self, path: Path):
        self.path = path
        # Directory to go back to once the context is left.
        self.origin = Path().absolute()

    def __enter__(self):
        os.chdir(self.path)
        # Returning the instance makes ``with set_directory(p) as ctx`` usable.
        return self

    def __exit__(self, *args, **kwargs):
        # Nothing is returned, so exceptions raised in the body propagate.
        os.chdir(self.origin)

    def __call__(self, fun):
        # Local import: keeps this block self-contained without touching the
        # module-level import section.
        import functools

        @functools.wraps(fun)  # keep the name/docstring of the wrapped function
        def new_fun(*args, **kwargs):
            # A new context is built on every call so that the directory to
            # restore is the one that is current when the function runs.
            with set_directory(Path(self.path)):
                return fun(*args, **kwargs)

        return new_fun
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class _RoboHiveBuild(_GymAsyncMeta):
    """Metaclass that calls ``_refine_specs`` on every freshly built instance."""

    def __call__(self, *args, **kwargs):
        # Build the environment through the parent metaclass first, then let
        # the new instance refine its specs before handing it to the caller.
        env: RoboHiveEnv = super().__call__(*args, **kwargs)
        env._refine_specs()
        return env
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class RoboHiveEnv(GymEnv, metaclass=_RoboHiveBuild):
    """A wrapper for RoboHive gym environments.

    RoboHive is a collection of environments/tasks simulated with the MuJoCo physics engine exposed using the OpenAI-Gym API.

    Github: https://github.com/vikashplus/robohive/

    Doc: https://github.com/vikashplus/robohive/wiki

    Paper: https://arxiv.org/abs/2310.06828

    .. warning::
        RoboHive requires gym 0.13.

    Args:
        env_name (str): the environment name to build. Must be one of :attr:`.available_envs`
        categorical_action_encoding (bool, optional): if ``True``, categorical
            specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
            otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
            Defaults to ``False``.

    Keyword Args:
        from_pixels (bool, optional): if ``True``, an attempt to return the pixel
            observations from the env will be performed. By default, these observations
            will be written under the ``"pixels"`` entry.
            The method being used varies
            depending on the gym version and may involve a ``wrappers.pixel_observation.PixelObservationWrapper``.
            Defaults to ``False``.
        pixels_only (bool, optional): if ``True``, only the pixel observations will
            be returned (by default under the ``"pixels"`` entry in the output tensordict).
            If ``False``, observations (eg, states) and pixels will be returned
            whenever ``from_pixels=True``. Defaults to ``True``.
            NOTE(review): the ``_build_env`` signature of this class declares
            ``pixels_only=False``; confirm which default is effective.
        from_depths (bool, optional): if ``True``, an attempt to return the depth
            observations from the env will be performed. By default, these observations
            will be written under the ``"depths"`` entry. Requires ``from_pixels`` to be ``True``.
            Defaults to ``False``.
        cameras (sequence of str, optional): names of the cameras to register when
            ``from_pixels=True``. If omitted, a warning is issued and all the cameras
            returned by :meth:`get_available_cams` are registered. Passing a non-empty
            value with ``from_pixels=False`` raises a ``RuntimeError``.
        read_info (bool, optional): if ``True``, :meth:`read_info` is set as the
            info-dict reader of the environment. Defaults to ``True``.
        frame_skip (int, optional): if provided, indicates for how many steps the
            same action is to be repeated. The observation returned will be the
            last observation of the sequence, whereas the reward will be the sum
            of rewards across steps.
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``torch.device("cpu")``.
        batch_size (torch.Size, optional): Only ``torch.Size([])`` will work with
            ``RoboHiveEnv`` since vectorized environments are not supported within the
            class. To execute more than one environment at a time, see :class:`~torchrl.envs.ParallelEnv`.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``False``.

    Attributes:
        available_envs (list): a list of available envs to build.

    Examples:
        >>> from torchrl.envs import RoboHiveEnv
        >>> env = RoboHiveEnv(RoboHiveEnv.available_envs[0])
        >>> env.rollout(3)

    """

    # Class-level list of environment names; it is extended in place by
    # ``register_envs`` and ``register_visual_env``.
    env_list = []
|
|
129
|
+
|
|
130
|
+
@_classproperty
def CURR_DIR(cls):
    """Value of ``robohive.envs.multi_task.substeps1.CURR_DIR``.

    Returns ``None`` when RoboHive (or a gym backend) is not available.
    """
    if not _has_robohive:
        return None

    # Imported lazily so that the module can be loaded without RoboHive.
    import robohive.envs.multi_task.substeps1

    return robohive.envs.multi_task.substeps1.CURR_DIR
|
|
138
|
+
|
|
139
|
+
@_classproperty
def available_envs(cls):
    """List of environment names that can be built.

    Triggers :meth:`register_envs` on access; an empty list is returned when
    RoboHive is not available.
    """
    if _has_robohive:
        cls.register_envs()
        return cls.env_list
    return []
|
|
145
|
+
|
|
146
|
+
@classmethod
def register_envs(cls):
    """Adds the names of the RoboHive environment suite to ``env_list``.

    RoboHive's prompt verbosity is set to ``Prompt.WARN`` as a side effect.

    Raises:
        ImportError: if RoboHive (or a gym backend) cannot be found.
        RuntimeError: if the RoboHive suite does not contain any environment.
    """
    if not _has_robohive:
        raise ImportError(
            "Cannot load robohive from the current virtual environment."
        )
    from robohive import robohive_env_suite as robohive_envs
    from robohive.utils.prompt_utils import Prompt, set_prompt_verbosity

    set_prompt_verbosity(Prompt.WARN)
    if not len(robohive_envs):
        raise RuntimeError("did not load any environment.")
    # This method runs on every ``available_envs`` access: only the names that
    # are not registered yet are appended, otherwise the list would grow with
    # duplicates on each call. ``dict.fromkeys`` keeps the suite order.
    known = set(cls.env_list)
    cls.env_list += [
        name for name in dict.fromkeys(robohive_envs) if name not in known
    ]
|
|
159
|
+
|
|
160
|
+
def _build_env(  # noqa: F811
    self,
    env_name: str,
    from_pixels: bool = False,
    pixels_only: bool = False,
    from_depths: bool = False,
    **kwargs,
) -> gym.core.Env:  # noqa: F821
    """Builds the RoboHive environment through ``self.lib.make``.

    Args:
        env_name (str): name of the environment to build.
        from_pixels (bool, optional): if ``True``, a visual variant of the
            environment is registered (see :meth:`register_visual_env`) and
            built instead of ``env_name``. Defaults to ``False``.
        pixels_only (bool, optional): stored as ``self.pixels_only``.
            Defaults to ``False``.
        from_depths (bool, optional): forwarded to
            :meth:`register_visual_env`. Defaults to ``False``.
        **kwargs: remaining keyword arguments are passed to ``self.lib.make``.
            ``cameras`` (consumed here) lists the cameras to register when
            ``from_pixels=True``; ``read_info`` (defaults to ``True``) controls
            whether :meth:`read_info` is set as the info-dict reader.

    Returns:
        the environment returned by ``self.lib.make``.

    Raises:
        ImportError: if RoboHive (or a gym backend) is not available.
        RuntimeError: if cameras are given while ``from_pixels`` is ``False``.
    """
    # Fail early: everything below (camera lookup, variant registration,
    # ``make``) needs RoboHive.
    if not _has_robohive:
        raise ImportError(
            f"gym/robohive not found, unable to create {env_name}. "
            f"Consider downloading and installing robohive from"
            f" https://github.com/vikashplus/robohive/"
        )

    if from_pixels:
        if "cameras" not in kwargs:
            warnings.warn(
                "from_pixels=True will lead to a registration of ALL available cameras, "
                "which may lead to performance issue. "
                "Consider passing only the needed cameras through cameras=list_of_cameras. "
                "The list of available cameras for a specific environment can be obtained via "
                "RoboHiveEnv.get_available_cams(env_name)."
            )
            kwargs["cameras"] = self.get_available_cams(env_name)
        cams = list(kwargs.pop("cameras"))
        # The visual variant is registered under a new name, which is the one
        # that gets built below.
        env_name = self.register_visual_env(
            cams=cams, env_name=env_name, from_depths=from_depths
        )
    elif kwargs.get("cameras"):
        raise RuntimeError("Got a list of cameras but from_pixels is set to False.")

    self.pixels_only = pixels_only
    # Use the full index that follows the last ":" of the device string
    # ("cuda:12" -> 12). Devices without an index ("cpu", "cuda", None) fall
    # back to 0.
    try:
        render_device = int(str(self.device).rpartition(":")[2])
    except ValueError:
        render_device = 0

    try:
        env = self.lib.make(
            env_name,
            frameskip=self.frame_skip,
            device_id=render_device,
            return_dict=True,
            **kwargs,
        )
        # The frame skip is handled by the environment itself.
        self.wrapper_frame_skip = 1
    except TypeError as err:
        if "unexpected keyword argument 'frameskip" not in str(err):
            raise
        # ``frameskip`` was passed explicitly above, so it is normally absent
        # from ``kwargs``; the tolerant pop never raises.
        kwargs.pop("frameskip", None)
        env = self.lib.make(
            env_name, return_dict=True, device_id=render_device, **kwargs
        )
        # The environment does not support ``frameskip``: record the requested
        # value on the wrapper instead.
        self.wrapper_frame_skip = self.frame_skip

    self.from_pixels = from_pixels
    self.from_depths = from_depths
    self.render_device = render_device
    if kwargs.get("read_info", True):
        self.set_info_dict_reader(self.read_info)
    return env
|
|
223
|
+
|
|
224
|
+
def _make_specs(self, env: gym.Env, batch_size=None) -> None:  # noqa: F821
    """Builds the specs through the parent class and adds an ``"info"`` entry.

    The environment is reset and stepped once with a sampled action; the info
    returned by that step is passed through :meth:`read_info` and used to
    create ``self.observation_spec["info"]``.

    Returns:
        whatever the parent ``_make_specs`` returns.
    """
    parent_result = super()._make_specs(env=env, batch_size=batch_size)
    self.env.reset()
    sampled_action = self.env.action_space.sample()
    # Only the last element of the step output (the info) is needed.
    *_, raw_info = self.env.step(sampled_action)
    info_td = self.read_info(raw_info, TensorDict()).get("info")
    self.observation_spec["info"] = make_composite_from_td(info_td)
    return parent_result
|
|
232
|
+
|
|
233
|
+
@classmethod
def register_visual_env(cls, env_name, cams, from_depths):
    """Registers a visual variant of ``env_name`` and returns its name.

    Args:
        env_name (str): name of the base environment.
        cams (sequence of str): camera names; must not be empty.
        from_depths (bool): if ``True``, depth keys are requested in addition
            to the rgb keys.

    Returns:
        the name of the variant, which is also appended to ``env_list``.

    Raises:
        RuntimeError: if ``cams`` is empty.
    """
    with set_directory(cls.CURR_DIR):
        from robohive.envs.env_variants import register_env_variant

        if len(cams) == 0:
            raise RuntimeError("Cannot create a visual envs without cameras.")

        ordered_cams = sorted(cams)
        # The variant name is made of the camera names ("A:" replaced by "A_",
        # last three characters dropped) followed by the base env name.
        name_parts = [cam.replace("A:", "A_")[:-3] for cam in ordered_cams]
        name_parts.append(env_name)
        variant_name = "-".join(name_parts)

        visual_keys = [f"rgb:{cam}:224x224:2d" for cam in ordered_cams]
        if from_depths:
            visual_keys += [f"d:{cam}:224x224:2d" for cam in ordered_cams]

        register_env_variant(
            env_name,
            variants={"visual_keys": visual_keys},
            variant_id=variant_name,
        )
        cls.env_list += [variant_name]
        return variant_name
|
|
256
|
+
|
|
257
|
+
@_maybe_unlock
def _refine_specs(self) -> None:  # noqa: F821
    """Rebuilds the action, observation and reward specs from sampled data.

    The action spec is converted from the gym action space. The observation
    spec is inferred from three random steps of the wrapped env, then completed
    with the extra entries found in a two-step ``self.rollout``. The reward
    spec is reset to an unbounded spec of shape ``(1,)``.
    """
    env = self._env
    self.action_spec = _gym_to_torchrl_spec_transform(
        env.action_space, device=self.device
    )
    # get a np rollout
    rollout = TensorDict({"done": torch.zeros(3, 1)}, [3])
    env.reset()

    def get_obs():
        # Collects the current observations of the env in a plain dict.
        _dict = {}
        # Shallow copy: the update below must not alter ``env.obs_dict``.
        obs_dict = copy(env.obs_dict)
        if self.from_pixels:
            visual = self.env.get_exteroception()
            obs_dict.update(visual)
        pixel_list, depth_list = [], []
        for obs_key in obs_dict:
            if obs_key.startswith("rgb"):
                pix = obs_dict[obs_key]
                # add a leading dim unless the first dim is already 1
                if not pix.shape[0] == 1:
                    pix = pix[None]
                pixel_list.append(pix)
            elif obs_key.startswith("d:"):
                dep = obs_dict[obs_key]
                # depth entries always get a leading dim
                dep = dep[None]
                depth_list.append(dep)
            elif obs_key in env.obs_keys:
                value = env.obs_dict[obs_key]
                # entries with an empty shape are given one dim
                if not value.shape:
                    value = value[None]
                _dict[obs_key] = value
        # all cameras are stacked along the first dim
        if pixel_list:
            _dict["pixels"] = np.concatenate(pixel_list, 0)
        if depth_list:
            _dict["depths"] = np.concatenate(depth_list, 0)
        return _dict

    for i in range(3):
        _dict = {}
        _dict.update(get_obs())
        _dict["action"] = action = env.action_space.sample()
        _, r, trunc, term, done, _ = self._output_transform(env.step(action))
        _dict[("next", "reward")] = r.reshape(1)
        # placeholder values: these entries are excluded before the spec is built
        _dict[("next", "done")] = [1]
        _dict[("next", "terminated")] = [1]
        _dict[("next", "truncated")] = [1]
        # NOTE(review): "next" is assigned after the ("next", ...) entries; how
        # TensorDict combines both is not visible from here -- confirm.
        _dict["next"] = get_obs()
        rollout[i] = TensorDict(_dict, [])

    # The observation spec is built from the first step of the "next" data.
    observation_spec = make_composite_from_td(
        rollout.get("next").exclude("done", "reward", "terminated", "truncated")[0]
    )
    self.observation_spec = observation_spec

    self.reward_spec = Unbounded(
        shape=(1,),
        device=self.device,
    )  # default

    # Second pass: a rollout through the wrapper, from which everything that
    # is already covered (reward, done keys, known observations) is removed.
    rollout = self.rollout(2, return_contiguous=False).get("next")
    rollout = rollout.exclude(
        self.reward_key, *self.done_keys, *self.observation_spec.keys(True, True)
    )
    rollout = rollout[..., 0]
    spec = make_composite_from_td(rollout)
    self.observation_spec.update(spec)
    self.empty_cache()
|
|
325
|
+
|
|
326
|
+
def _reset_output_transform(self, reset_data):
|
|
327
|
+
if not (isinstance(reset_data, tuple) and len(reset_data) == 2):
|
|
328
|
+
return reset_data, {}
|
|
329
|
+
return reset_data
|
|
330
|
+
|
|
331
|
+
def set_from_pixels(self, from_pixels: bool) -> None:
|
|
332
|
+
"""Sets the from_pixels attribute to an existing environment.
|
|
333
|
+
|
|
334
|
+
Args:
|
|
335
|
+
from_pixels (bool): new value for the from_pixels attribute
|
|
336
|
+
|
|
337
|
+
"""
|
|
338
|
+
if from_pixels is self.from_pixels:
|
|
339
|
+
return
|
|
340
|
+
self.from_pixels = from_pixels
|
|
341
|
+
self._refine_specs()
|
|
342
|
+
|
|
343
|
+
def read_obs(self, observation):
    """Builds the observation dict from the env and passes it to the parent reader.

    The ``observation`` argument is not used: the data is read from
    ``self.env.obs_dict``. That dict is modified in place (its ``"t"`` entry is
    removed and, with ``from_pixels``, the visual entries are merged into it).
    """
    # the info is missing from the reset
    env_obs = self.env.obs_dict
    try:
        del env_obs["t"]
    except KeyError:
        pass

    if self.from_pixels:
        env_obs.update(self.env.get_exteroception())

    state_obs = {}
    rgb_frames = []
    depth_frames = []
    for name in env_obs:
        if name.startswith("rgb"):
            frame = env_obs[name]
            # add a leading dim unless the first dim is already 1
            rgb_frames.append(frame if frame.shape[0] == 1 else frame[None])
        elif name.startswith("d:"):
            depth_frames.append(env_obs[name][None])
        elif name in self._env.obs_keys:
            entry = env_obs[name]
            # entries with an empty shape are given one dim
            state_obs[name] = entry if entry.shape else entry[None]

    # all cameras are stacked along the first dim
    if self.from_pixels:
        state_obs["pixels"] = np.concatenate(rgb_frames, 0)
        if self.from_depths:
            state_obs["depths"] = np.concatenate(depth_frames, 0)
    return super().read_obs(state_obs)
|
|
379
|
+
|
|
380
|
+
def read_info(self, info, tensordict_out):
    """Writes the content of ``info`` under the ``"info"`` entry of ``tensordict_out``.

    With an empty ``info``, the entry is filled with zeros from the ``"info"``
    observation spec if there is one, and left out otherwise. With a non-empty
    ``info``, non-tensor data and a fixed set of keys (``"obs_dict"``,
    ``"done"``, ``"reward"``, ``"act"`` and the env ``obs_keys``) are dropped;
    the remaining values are reshaped to the spec shapes when an ``"info"``
    spec exists, or given one dim when their shape is empty.

    Returns:
        ``tensordict_out``, updated in place.
    """
    if not info:
        fallback_spec = self.observation_spec.get("info", None)
        if fallback_spec is not None:
            tensordict_out.set("info", fallback_spec.zero())
        return tensordict_out

    info_td = TensorDict(info, []).filter_non_tensor_data()
    info_td = info_td.exclude(
        "obs_dict", "done", "reward", *self._env.obs_keys, "act"
    )
    info_td = info_td.apply(lambda x: x, filter_empty=True)

    if "info" in self.observation_spec.keys():
        info_spec = self.observation_spec["info"]

        def _match_spec_shape(name, x):
            # entries without a spec yield None
            if info_spec.get(name, None) is None:
                return None
            return x.reshape(info_spec[name].shape)

        reshaped = info_td.named_apply(
            _match_spec_shape, nested_keys=True, filter_empty=True
        )
    else:

        def _at_least_one_dim(x):
            return x if x.shape else x.reshape((1,))

        reshaped = info_td.apply(_at_least_one_dim, filter_empty=True)

    info_td.update(reshaped)
    tensordict_out.set("info", info_td)
    return tensordict_out
|
|
411
|
+
|
|
412
|
+
def _init_env(self):
|
|
413
|
+
pass
|
|
414
|
+
|
|
415
|
+
def to(self, *args, **kwargs):
    """Moves the environment through the parent ``to`` and refreshes the render device.

    ``_build_env`` is called again with the constructor keyword arguments when
    the device index of the result differs from ``self.render_device``.

    Returns:
        the object returned by the parent ``to``.
    """
    out = super().to(*args, **kwargs)
    # Use the full index that follows the last ":" of the device string
    # ("cuda:12" -> 12). Devices without an index ("cpu", "cuda", None) fall
    # back to 0.
    try:
        render_device = int(str(out.device).rpartition(":")[2])
    except ValueError:
        render_device = 0
    if render_device != self.render_device:
        # NOTE(review): the env returned by ``_build_env`` is discarded here;
        # only the attributes it sets on ``out`` are kept. Confirm whether the
        # rebuilt env should replace the wrapped one.
        out._build_env(**self._constructor_kwargs)
    return out
|
|
424
|
+
|
|
425
|
+
@classmethod
def get_available_cams(cls, env_name):
    """Returns the camera names of ``env_name``.

    A temporary environment is built with the gym backend and the names are
    read from its simulation model, one per camera index.
    """
    env = gym_backend().make(env_name)
    cams = []
    for cam_index in range(env.sim.model.ncam):
        # 7 is presumably MuJoCo's camera object type (mjOBJ_CAMERA) -- confirm
        cams.append(env.sim.model.id2name(cam_index, 7))
    return cams
|