torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/libs/dm_control.py
@@ -0,0 +1,544 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import collections
import importlib.util
import os
from typing import Any

import numpy as np
import torch

from torchrl._utils import logger as torchrl_logger, VERBOSE
from torchrl.data.tensor_specs import (
    Bounded,
    Categorical,
    Composite,
    OneHot,
    TensorSpec,
    Unbounded,
)
from torchrl.data.utils import DEVICE_TYPING, numpy_to_torch_dtype_dict
from torchrl.envs.common import _EnvPostInit
from torchrl.envs.gym_like import GymLikeEnv
from torchrl.envs.utils import _classproperty

if torch.cuda.device_count() > 1:
    n = torch.cuda.device_count() - 1
    os.environ["EGL_DEVICE_ID"] = str(1 + (os.getpid() % n))
    if VERBOSE:
        torchrl_logger.info(f"EGL_DEVICE_ID: {os.environ['EGL_DEVICE_ID']}")

_has_dmc = _has_dm_control = importlib.util.find_spec("dm_control") is not None

__all__ = ["DMControlEnv", "DMControlWrapper"]

def _dmcontrol_to_torchrl_spec_transform(
    spec,
    dtype: torch.dtype | None = None,
    device: DEVICE_TYPING = None,
    categorical_discrete_encoding: bool = False,
) -> TensorSpec:
    import dm_env

    if isinstance(spec, (collections.OrderedDict, dict)):
        spec = {
            k: _dmcontrol_to_torchrl_spec_transform(
                item,
                device=device,
                categorical_discrete_encoding=categorical_discrete_encoding,
            )
            for k, item in spec.items()
        }
        return Composite(**spec)
    elif isinstance(spec, dm_env.specs.DiscreteArray):
        # DiscreteArray is a type of BoundedArray so this block needs to go first
        action_space_cls = Categorical if categorical_discrete_encoding else OneHot
        if dtype is None:
            dtype = (
                numpy_to_torch_dtype_dict[spec.dtype]
                if categorical_discrete_encoding
                else torch.long
            )
        return action_space_cls(spec.num_values, device=device, dtype=dtype)
    elif isinstance(spec, dm_env.specs.BoundedArray):
        if dtype is None:
            dtype = numpy_to_torch_dtype_dict[spec.dtype]
        shape = spec.shape
        if not len(shape):
            shape = torch.Size([1])
        return Bounded(
            shape=shape,
            low=spec.minimum,
            high=spec.maximum,
            dtype=dtype,
            device=device,
        )
    elif isinstance(spec, dm_env.specs.Array):
        shape = spec.shape
        if not len(shape):
            shape = torch.Size([1])
        if dtype is None:
            dtype = numpy_to_torch_dtype_dict[spec.dtype]
        # Plain Arrays carry no bounds: floating-point and integer dtypes alike
        # map to an Unbounded spec.
        return Unbounded(shape=shape, dtype=dtype, device=device)
    else:
        raise NotImplementedError(type(spec))

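# For illustration, the mapping performed above (assuming dm_control/dm_env are
# installed; dm_env.specs on the left, torchrl.data.tensor_specs on the right):
#
#     BoundedArray(shape=(6,), dtype=np.float64, minimum=-1.0, maximum=1.0)
#         -> Bounded(shape=torch.Size([6]), low=-1.0, high=1.0, dtype=torch.float64)
#     DiscreteArray(num_values=3)
#         -> OneHot(3, dtype=torch.long), or Categorical(3) when
#            categorical_discrete_encoding=True
#     Array(shape=(9,), dtype=np.float64)
#         -> Unbounded(shape=torch.Size([9]), dtype=torch.float64)
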
def _get_envs(to_dict: bool = True) -> dict[str, Any]:
    if not _has_dm_control:
        raise ImportError("Cannot find dm_control in virtual environment.")
    from dm_control import suite

    if not to_dict:
        return tuple(suite.BENCHMARKING) + tuple(suite.EXTRA)
    d = {}
    for tup in suite.BENCHMARKING:
        env_name = tup[0]
        d.setdefault(env_name, []).append(tup[1])
    for tup in suite.EXTRA:
        env_name = tup[0]
        d.setdefault(env_name, []).append(tup[1])
    return d.items()


def _robust_to_tensor(array: float | np.ndarray) -> torch.Tensor:
    if isinstance(array, np.ndarray):
        return torch.as_tensor(array.copy())
    else:
        return torch.as_tensor(array)

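# For reference, _get_envs() yields (env_name, [task_name, ...]) pairs such as
# ('cheetah', ['run']) or ('walker', ['stand', 'walk', 'run']); this is what
# DMControlWrapper.available_envs exposes below.
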
class _DMControlMeta(_EnvPostInit):
    """Metaclass for DMControlEnv that returns a lazy ParallelEnv when num_workers > 1.

    When ``DMControlEnv(..., num_workers=4)`` is called, this metaclass intercepts the
    call and returns a :class:`~torchrl.envs.ParallelEnv` instead. The returned
    ParallelEnv is lazy: workers are not started until the environment is actually used
    (e.g., via :meth:`torchrl.envs.batched_envs.BatchedEnvBase.reset` /
    :meth:`torchrl.envs.batched_envs.BatchedEnvBase.step` or when accessing specs).

    Users can call :meth:`torchrl.envs.batched_envs.BatchedEnvBase.configure_parallel`
    to set ParallelEnv parameters before the environment starts.
    """

    def __call__(cls, *args, num_workers: int | None = None, **kwargs):
        # Extract num_workers from the explicit keyword argument or the kwargs dict
        if num_workers is None:
            num_workers = kwargs.pop("num_workers", 1)
        else:
            kwargs.pop("num_workers", None)

        num_workers = int(num_workers) if num_workers is not None else 1
        if cls.__name__ == "DMControlEnv" and num_workers > 1:
            from torchrl.envs import ParallelEnv

            # Extract env_name and task_name from args
            env_name = args[0] if len(args) >= 1 else kwargs.get("env_name")
            task_name = args[1] if len(args) >= 2 else kwargs.get("task_name")

            # Remove env_name and task_name from kwargs if they were there
            # (they'll be passed positionally to the env creator)
            env_kwargs = {
                k: v for k, v in kwargs.items() if k not in ("env_name", "task_name")
            }

            # Create a factory function that builds single DMControlEnv instances
            def make_env(_env_name=env_name, _task_name=task_name, _kwargs=env_kwargs):
                return cls(_env_name, _task_name, num_workers=1, **_kwargs)

            # Return a lazy ParallelEnv (workers not started yet)
            return ParallelEnv(num_workers, make_env)

        return super().__call__(*args, **kwargs)

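# For illustration (nothing here is executed on import): DMControlEnv("cheetah",
# "run") builds a single environment, while DMControlEnv("cheetah", "run",
# num_workers=4) returns a lazy ParallelEnv whose four workers each construct
# DMControlEnv("cheetah", "run", num_workers=1) on first use.
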
class DMControlWrapper(GymLikeEnv):
    """DeepMind Control lab environment wrapper.

    The DeepMind control library can be found here: https://github.com/deepmind/dm_control.

    Paper: https://arxiv.org/abs/2006.12983

    Args:
        env (dm_control.suite env): :class:`~dm_control.suite.base.Task`
            environment instance.

    Keyword Args:
        from_pixels (bool, optional): if ``True``, an attempt to return the pixel
            observations from the env will be performed. By default, these
            observations will be written under the ``"pixels"`` entry.
            Defaults to ``False``.
        pixels_only (bool, optional): if ``True``, only the pixel observations will
            be returned (by default under the ``"pixels"`` entry in the output
            tensordict). If ``False``, observations (e.g., states) and pixels will
            be returned whenever ``from_pixels=True``. Defaults to ``True``.
        frame_skip (int, optional): if provided, indicates for how many steps the
            same action is to be repeated. The observation returned will be the
            last observation of the sequence, whereas the reward will be the sum
            of rewards across steps.
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``torch.device("cpu")``.
        batch_size (torch.Size, optional): the batch size of the environment.
            Should match the leading dimensions of all observations, done states,
            rewards, actions and infos. Defaults to ``torch.Size([])``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``False``.

    Attributes:
        available_envs (list): a list of ``Tuple[str, List[str]]`` representing the
            environment / task pairs available.

    Examples:
        >>> from dm_control import suite
        >>> from torchrl.envs import DMControlWrapper
        >>> env = suite.load("cheetah", "run")
        >>> env = DMControlWrapper(env,
        ...    from_pixels=True, frame_skip=4)
        >>> td = env.rand_step()
        >>> print(td)
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([6]), device=cpu, dtype=torch.float64, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        pixels: Tensor(shape=torch.Size([240, 320, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                        position: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float64, is_shared=False),
                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float64, is_shared=False),
                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        velocity: Tensor(shape=torch.Size([9]), device=cpu, dtype=torch.float64, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=cpu,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)
        >>> print(env.available_envs)
        [('acrobot', ['swingup', 'swingup_sparse']), ('ball_in_cup', ['catch']), ('cartpole', ['balance', 'balance_sparse', 'swingup', 'swingup_sparse', 'three_poles', 'two_poles']), ('cheetah', ['run']), ('finger', ['spin', 'turn_easy', 'turn_hard']), ('fish', ['upright', 'swim']), ('hopper', ['stand', 'hop']), ('humanoid', ['stand', 'walk', 'run', 'run_pure_state']), ('manipulator', ['bring_ball', 'bring_peg', 'insert_ball', 'insert_peg']), ('pendulum', ['swingup']), ('point_mass', ['easy', 'hard']), ('reacher', ['easy', 'hard']), ('swimmer', ['swimmer6', 'swimmer15']), ('walker', ['stand', 'walk', 'run']), ('dog', ['fetch', 'run', 'stand', 'trot', 'walk']), ('humanoid_CMU', ['run', 'stand', 'walk']), ('lqr', ['lqr_2_1', 'lqr_6_2']), ('quadruped', ['escape', 'fetch', 'run', 'walk']), ('stacker', ['stack_2', 'stack_4'])]

    """

    git_url = "https://github.com/deepmind/dm_control"
    libname = "dm_control"

    @_classproperty
    def available_envs(cls):
        if not _has_dm_control:
            return []
        return list(_get_envs())

    @property
    def lib(self):
        import dm_control

        return dm_control

    def __init__(self, env=None, **kwargs):
        if env is not None:
            kwargs["env"] = env
        super().__init__(**kwargs)

    def _build_env(
        self,
        env,
        _seed: int | None = None,
        from_pixels: bool = False,
        render_kwargs: dict | None = None,
        pixels_only: bool = False,
        camera_id: int | str = 0,
        **kwargs,
    ):
        self.from_pixels = from_pixels
        self.pixels_only = pixels_only

        if from_pixels:
            from dm_control.suite.wrappers import pixels

            self._set_egl_device(self.device)
            self.render_kwargs = {"camera_id": camera_id}
            if render_kwargs is not None:
                self.render_kwargs.update(render_kwargs)
            env = pixels.Wrapper(
                env,
                pixels_only=self.pixels_only,
                render_kwargs=self.render_kwargs,
            )
        return env

    def _make_specs(self, env: gym.Env) -> None:  # noqa: F821
        # specs are defined when first called
        self.observation_spec = _dmcontrol_to_torchrl_spec_transform(
            self._env.observation_spec(), device=self.device
        )
        reward_spec = _dmcontrol_to_torchrl_spec_transform(
            self._env.reward_spec(), device=self.device
        )
        if len(reward_spec.shape) == 0:
            reward_spec.shape = torch.Size([1])
        self.reward_spec = reward_spec
        # populate the default done spec
        done_spec = Categorical(
            n=2, shape=(*self.batch_size, 1), dtype=torch.bool, device=self.device
        )
        self.done_spec = Composite(
            done=done_spec.clone(),
            truncated=done_spec.clone(),
            terminated=done_spec.clone(),
            device=self.device,
        )
        self.action_spec = _dmcontrol_to_torchrl_spec_transform(
            self._env.action_spec(), device=self.device
        )

    def _check_kwargs(self, kwargs: dict):
        dm_control = self.lib
        from dm_control.suite.wrappers import pixels

        if "env" not in kwargs:
            raise TypeError("Could not find environment key 'env' in kwargs.")
        env = kwargs["env"]
        if not isinstance(env, (dm_control.rl.control.Environment, pixels.Wrapper)):
            raise TypeError(
                "env is not of type 'dm_control.rl.control.Environment' or "
                "'dm_control.suite.wrappers.pixels.Wrapper'."
            )

    def _set_egl_device(self, device: DEVICE_TYPING):
        # Deprecated as it led to unreliable rendering.
        # The EGL device needs to be set before importing the mujoco bindings: in
        # distributed settings, it'll be easy to tell which cuda device to use.
        # In mp settings, we'll need to use mp.Pool with a specific init function
        # that defines the EGL device before importing libraries. For now, we'll
        # just use a common EGL_DEVICE_ID environment variable for all processes.
        return

    def to(self, device: DEVICE_TYPING) -> DMControlEnv:
        super().to(device)
        self._set_egl_device(self.device)
        return self

    def _init_env(self, seed: int | None = None) -> int | None:
        seed = self.set_seed(seed)
        return seed

    def _set_seed(self, _seed: int | None) -> None:
        from dm_control.suite.wrappers import pixels

        if _seed is None:
            return None
        random_state = np.random.RandomState(_seed)
        if isinstance(self._env, pixels.Wrapper):
            if not hasattr(self._env._env.task, "_random"):
                raise RuntimeError("self._env._env.task._random does not exist")
            self._env._env.task._random = random_state
        else:
            if not hasattr(self._env.task, "_random"):
                raise RuntimeError("self._env.task._random does not exist")
            self._env.task._random = random_state
        self.reset()

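    # Seeding sketch: set_seed(seed) routes to _set_seed above, which builds
    # np.random.RandomState(seed), swaps it into the underlying task's
    # ``_random`` attribute and resets the env, making subsequent episodes
    # reproducible.
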
    def _output_transform(
        self, timestep_tuple: tuple[TimeStep]  # noqa: F821
    ) -> tuple[np.ndarray, float, bool, bool, bool, dict]:
        from dm_env import StepType

        if type(timestep_tuple) is not tuple:
            timestep_tuple = (timestep_tuple,)
        reward = timestep_tuple[0].reward

        truncated = terminated = False
        if timestep_tuple[0].step_type == StepType.LAST:
            if np.isclose(timestep_tuple[0].discount, 1):
                truncated = True
            else:
                terminated = True
        done = truncated or terminated

        observation = timestep_tuple[0].observation
        info = {}

        return observation, reward, terminated, truncated, done, info

    def _reset_output_transform(self, reset_data):
        (
            observation,
            reward,
            terminated,
            truncated,
            done,
            info,
        ) = self._output_transform(reset_data)
        return observation, info

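    # Note on the convention in _output_transform: dm_env signals episode end
    # with StepType.LAST and distinguishes a time-limit truncation
    # (discount == 1.0, so bootstrapping remains valid) from a true termination
    # (discount < 1.0, typically 0.0). The branch above maps these onto
    # TorchRL's separate ``truncated`` and ``terminated`` entries.
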
    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(env={self._env}, batch_size={self.batch_size})"
        )

class DMControlEnv(DMControlWrapper, metaclass=_DMControlMeta):
    """DeepMind Control lab environment wrapper.

    The DeepMind control library can be found here: https://github.com/deepmind/dm_control.

    Paper: https://arxiv.org/abs/2006.12983

    Args:
        env_name (str): name of the environment.
        task_name (str): name of the task.
        num_workers (int, optional): number of parallel environments. Defaults to 1.
            When ``num_workers > 1``, a lazy :class:`~torchrl.envs.ParallelEnv` is
            returned instead of a single environment. The parallel environment
            is not started until it is actually used (e.g., via reset/step or
            accessing specs). Use :meth:`~torchrl.envs.BatchedEnvBase.configure_parallel`
            to set parallel execution parameters before the environment starts.

    Keyword Args:
        from_pixels (bool, optional): if ``True``, an attempt to return the pixel
            observations from the env will be performed. By default, these
            observations will be written under the ``"pixels"`` entry.
            Defaults to ``False``.
        pixels_only (bool, optional): if ``True``, only the pixel observations will
            be returned (by default under the ``"pixels"`` entry in the output
            tensordict). If ``False``, observations (e.g., states) and pixels will
            be returned whenever ``from_pixels=True``. Defaults to ``True``.
        frame_skip (int, optional): if provided, indicates for how many steps the
            same action is to be repeated. The observation returned will be the
            last observation of the sequence, whereas the reward will be the sum
            of rewards across steps.
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``torch.device("cpu")``.
        batch_size (torch.Size, optional): the batch size of the environment.
            Should match the leading dimensions of all observations, done states,
            rewards, actions and infos. Defaults to ``torch.Size([])``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``False``.

    Attributes:
        available_envs (list): a list of ``Tuple[str, List[str]]`` representing the
            environment / task pairs available.

    Examples:
        >>> from torchrl.envs import DMControlEnv
        >>> env = DMControlEnv(env_name="cheetah", task_name="run",
        ...    from_pixels=True, frame_skip=4)
        >>> td = env.rand_step()
        >>> print(td)
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([6]), device=cpu, dtype=torch.float64, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        pixels: Tensor(shape=torch.Size([240, 320, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                        position: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float64, is_shared=False),
                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float64, is_shared=False),
                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        velocity: Tensor(shape=torch.Size([9]), device=cpu, dtype=torch.float64, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=cpu,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)
        >>> print(env.available_envs)
        [('acrobot', ['swingup', 'swingup_sparse']), ...]
        >>> # For running multiple envs in parallel (returns a lazy ParallelEnv)
        >>> env = DMControlEnv("cheetah", "run", num_workers=4)
        >>> # Configure parallel parameters before the env starts
        >>> env.configure_parallel(use_buffers=True, num_threads=2)
        >>> # The environment starts when first used
        >>> env.reset()
    """

    def __init__(self, env_name, task_name, **kwargs):
        if not _has_dmc:
            raise ImportError(
                "dm_control python package was not found. Please install this dependency."
            )

        kwargs["env_name"] = env_name
        kwargs["task_name"] = task_name

        super().__init__(**kwargs)

    def _build_env(
        self,
        env_name: str,
        task_name: str,
        _seed: int | None = None,
        **kwargs,
    ):
        from dm_control import suite

        self.env_name = env_name
        self.task_name = task_name

        from_pixels = kwargs.get("from_pixels")
        if "from_pixels" in kwargs:
            del kwargs["from_pixels"]
        pixels_only = kwargs.get("pixels_only")
        if "pixels_only" in kwargs:
            del kwargs["pixels_only"]

        if not _has_dmc:
            raise ImportError(
                f"dm_control not found, unable to create {env_name}:"
                f" {task_name}. Consider downloading and installing "
                f"dm_control from {self.git_url}"
            )

        camera_id = kwargs.pop("camera_id", 0)
        if _seed is not None:
            random_state = np.random.RandomState(_seed)
            kwargs["random"] = random_state
        env = suite.load(env_name, task_name, task_kwargs=kwargs)
        return super()._build_env(
            env,
            _seed=_seed,
            from_pixels=from_pixels,
            pixels_only=pixels_only,
            camera_id=camera_id,
            **kwargs,
        )

    def rebuild_with_kwargs(self, **new_kwargs):
        self._constructor_kwargs.update(new_kwargs)
        # rebuild from the updated constructor kwargs
        self._env = self._build_env(**self._constructor_kwargs)
        self._make_specs(self._env)

    def _check_kwargs(self, kwargs: dict):
        if "env_name" in kwargs:
            env_name = kwargs["env_name"]
            if "task_name" in kwargs:
                task_name = kwargs["task_name"]
                available_envs = dict(self.available_envs)
                if (
                    env_name not in available_envs
                    or task_name not in available_envs[env_name]
                ):
                    raise RuntimeError(
                        f"{env_name} with task {task_name} is unknown in {self.libname}"
                    )
            else:
                raise TypeError("dm_control requires task_name to be specified")
        else:
            raise TypeError("dm_control requires env_name to be specified")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(env={self.env_name}, task={self.task_name}, batch_size={self.batch_size})"
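
# A minimal end-to-end sketch (assuming dm_control and a working MuJoCo
# installation; kept as a comment so nothing runs at import time):
#
#     from torchrl.envs import DMControlEnv
#     env = DMControlEnv("cheetah", "run", frame_skip=4)
#     env.set_seed(0)                      # seeds the underlying task RNG
#     td = env.reset()                     # TensorDict with position/velocity
#     td = env.rand_step(td)               # random action, one (frame-skipped) step
#     print(td["next", "reward"])          # shape [1], dtype torch.float64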