torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/libs/gym.py
ADDED
@@ -0,0 +1,2239 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import collections
import importlib
import warnings
from contextlib import nullcontext
from copy import copy
from functools import partial
from types import ModuleType
from warnings import warn

import numpy as np
import torch
from packaging import version
from tensordict import TensorDict, TensorDictBase
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)

from torchrl._utils import implement_for, logger as torchrl_logger
from torchrl.data.tensor_specs import (
    _minmax_dtype,
    Binary,
    Bounded,
    Categorical,
    Composite,
    MultiCategorical,
    MultiOneHot,
    NonTensor,
    OneHot,
    TensorSpec,
    Unbounded,
)
from torchrl.data.utils import numpy_to_torch_dtype_dict, torch_to_numpy_dtype_dict
from torchrl.envs.common import _EnvPostInit
from torchrl.envs.gym_like import default_info_dict_reader, GymLikeEnv
from torchrl.envs.utils import _classproperty

try:
    from torch.utils._contextlib import _DecoratorContextManager
except ModuleNotFoundError:
    from torchrl._utils import _DecoratorContextManager

DEFAULT_GYM = None
IMPORT_ERROR = None
# check gym presence without importing it
_has_gym = importlib.util.find_spec("gym") is not None
if not _has_gym:
    _has_gym = importlib.util.find_spec("gymnasium") is not None

_has_mo = importlib.util.find_spec("mo_gymnasium") is not None
_has_sb3 = importlib.util.find_spec("stable_baselines3") is not None
_has_isaaclab = importlib.util.find_spec("isaaclab") is not None
_has_minigrid = importlib.util.find_spec("minigrid") is not None


GYMNASIUM_1_ERROR = """RuntimeError: TorchRL does not support gymnasium 1.0 versions due to incompatible
changes in the Gym API.
Using gymnasium 1.0 with TorchRL would require significant modifications to your code and may result in:
* Inaccurate step counting, as the auto-reset feature can cause unpredictable numbers of steps to be executed.
* Potential data corruption, as the environment may require/produce garbage data during reset steps.
* Trajectory overlap during data collection.
* Increased computational overhead, as the library would need to handle the additional complexity of auto-resets.
* Manual filtering and boilerplate code to mitigate these issues, which would compromise the modularity and ease of
  use of TorchRL.
To maintain the integrity and efficiency of our library, we cannot support this version of gymnasium at this time.
If you need to use gymnasium 1.0, we recommend exploring alternative solutions or waiting for future updates
to TorchRL and gymnasium that may address this compatibility issue.
For more information, please refer to discussion https://github.com/pytorch/rl/discussions/2483 in torchrl.
"""


def _minigrid_lib():
    assert _has_minigrid, "minigrid not found"
    import minigrid

    return minigrid


class set_gym_backend(_DecoratorContextManager):
    """Sets the gym-backend to a certain value.

    Args:
        backend (python module, string or callable returning a module): the
            gym backend to use. Use a string or callable whenever you wish to
            avoid importing gym at loading time.

    Examples:
        >>> import gym
        >>> import gymnasium
        >>> with set_gym_backend("gym"):
        ...     assert gym_backend() == gym
        >>> with set_gym_backend(lambda: gym):
        ...     assert gym_backend() == gym
        >>> with set_gym_backend(gym):
        ...     assert gym_backend() == gym
        >>> with set_gym_backend("gymnasium"):
        ...     assert gym_backend() == gymnasium
        >>> with set_gym_backend(lambda: gymnasium):
        ...     assert gym_backend() == gymnasium
        >>> with set_gym_backend(gymnasium):
        ...     assert gym_backend() == gymnasium

    This class can also be used as a function decorator.

    Examples:
        >>> @set_gym_backend("gym")
        ... def fun():
        ...     gym = gym_backend()
        ...     print(gym)
        >>> fun()
        <module 'gym' from '/path/to/env/site-packages/gym/__init__.py'>
        >>> @set_gym_backend("gymnasium")
        ... def fun():
        ...     gym = gym_backend()
        ...     print(gym)
        >>> fun()
        <module 'gymnasium' from '/path/to/env/site-packages/gymnasium/__init__.py'>

    """

    def __init__(self, backend):
        self.backend = backend

    def _call(self):
        """Sets the backend as default."""
        global DEFAULT_GYM
        DEFAULT_GYM = self.backend
        found_setters = collections.defaultdict(bool)
        for setter in copy(implement_for._setters):
            check_module = (
                callable(setter.module_name)
                and setter.module_name.__name__ == self.backend.__name__
            ) or setter.module_name == self.backend.__name__
            check_version = setter.check_version(
                self.backend.__version__, setter.from_version, setter.to_version
            )
            if check_module and check_version:
                setter.module_set()
                found_setter = True
            elif check_module:
                found_setter = False
            else:
                found_setter = None
            if found_setter is not None:
                found_setters[setter.func_name] = (
                    found_setters[setter.func_name] or found_setter
                )
        # we keep only the setters we need. This is safe because a copy is saved under self._setters_saved
        for func_name, found_setter in found_setters.items():
            if not found_setter:
                raise ImportError(
                    f"could not set anything related to gym backend "
                    f"{self.backend.__name__} with version={self.backend.__version__} for the function with name {func_name}. "
                    f"Check that the gym versions match!"
                )

    def set(self):
        """Irreversibly sets the gym backend in the script."""
        self._call()

    def __enter__(self):
        global DEFAULT_GYM
        # Save the current DEFAULT_GYM so we can restore it on exit
        self._default_gym_saved = DEFAULT_GYM
        self._call()

    def __exit__(self, exc_type, exc_val, exc_tb):
        global DEFAULT_GYM
        # Restore the previous DEFAULT_GYM
        saved_gym = self._default_gym_saved
        DEFAULT_GYM = saved_gym
        delattr(self, "_default_gym_saved")
        # Re-activate the implementations for the original backend
        # If saved_gym was None, we need to determine the default backend
        # by calling gym_backend() which will initialize DEFAULT_GYM
        if saved_gym is None:
            # Initialize DEFAULT_GYM with the default backend (gymnasium first, then gym)
            saved_gym = gym_backend()
        # Re-apply the original backend's implementations
        for setter in copy(implement_for._setters):
            check_module = (
                callable(setter.module_name)
                and setter.module_name.__name__ == saved_gym.__name__
            ) or setter.module_name == saved_gym.__name__
            check_version = setter.check_version(
                saved_gym.__version__, setter.from_version, setter.to_version
            )
            if check_module and check_version:
                setter.module_set()

    def clone(self):
        # override this method if your children class takes __init__ parameters
        return self.__class__(self.backend)

    @property
    def backend(self):
        if isinstance(self._backend, str):
            return importlib.import_module(self._backend)
        elif callable(self._backend):
            return self._backend()
        return self._backend

    @backend.setter
    def backend(self, value):
        self._backend = value


def gym_backend(submodule=None):
    """Returns the gym backend, or a sumbodule of it.

    Args:
        submodule (str): the submodule to import. If ``None``, the backend
            itself is returned.

    Examples:
        >>> import mo_gymnasium
        >>> with set_gym_backend("gym"):
        ...     wrappers = gym_backend('wrappers')
        ...     print(wrappers)
        >>> with set_gym_backend("gymnasium"):
        ...     wrappers = gym_backend('wrappers')
        ...     print(wrappers)
    """
    global IMPORT_ERROR
    global DEFAULT_GYM
    if DEFAULT_GYM is None:
        try:
            # rule of thumbs: gymnasium precedes
            import gymnasium as gym
        except ImportError as err:
            IMPORT_ERROR = err
            try:
                import gym as gym
            except ImportError as err:
                IMPORT_ERROR = err
                gym = None
        DEFAULT_GYM = gym
    if submodule is not None:
        if not submodule.startswith("."):
            submodule = "." + submodule
        submodule = importlib.import_module(submodule, package=DEFAULT_GYM.__name__)
        return submodule
    return DEFAULT_GYM


__all__ = ["GymWrapper", "GymEnv"]


# Define a dictionary to store conversion functions for each spec type
class _ConversionRegistry(collections.UserDict):
    def __getitem__(self, cls):
        if cls not in super().keys():
            # We want to find the closest parent
            parents = {}
            for k in self.keys():
                if not isinstance(k, str):
                    parents[k] = k
                    continue
                try:
                    space_cls = gym_backend("spaces")
                    for sbsp in k.split("."):
                        space_cls = getattr(space_cls, sbsp)
                except AttributeError:
                    # Some specs may be too recent
                    continue
                parents[space_cls] = k
            mro = cls.mro()
            for base in mro:
                for p in parents:
                    if issubclass(base, p):
                        return self[parents[p]]
            else:
                raise KeyError(
                    f"No conversion tool could be found with the gym space {cls}. "
                    f"You can register your own with `torchrl.envs.libs.register_gym_spec_conversion.`"
                )
        return super().__getitem__(cls)


_conversion_registry = _ConversionRegistry()


def register_gym_spec_conversion(spec_type):
    """Decorator to register a conversion function for a specific spec type.

    The method must have the following signature:

        >>> @register_gym_spec_conversion("spec.name")
        ... def convert_specname(
        ...     spec,
        ...     dtype=None,
        ...     device=None,
        ...     categorical_action_encoding=None,
        ...     remap_state_to_observation=None,
        ...     batch_size=None,
        ... ):

    where `gym(nasium).spaces.spec.name` is the location of the spec in gym.

    If the spec type is accessible, this will also work:

        >>> @register_gym_spec_conversion(SpecType)
        ... def convert_specname(
        ...     spec,
        ...     dtype=None,
        ...     device=None,
        ...     categorical_action_encoding=None,
        ...     remap_state_to_observation=None,
        ...     batch_size=None,
        ... ):

    ..note:: The wrapped function can be simplified, and unused kwargs can be wrapped in `**kwargs`.

    """

    def decorator(conversion_func):
        _conversion_registry[spec_type] = conversion_func
        return conversion_func

    return decorator


def _gym_to_torchrl_spec_transform(
    spec,
    dtype=None,
    device=None,
    categorical_action_encoding=False,
    remap_state_to_observation: bool = True,
    batch_size: tuple = (),
) -> TensorSpec:
    """Maps the gym specs to the TorchRL specs.

    Args:
        spec (gym.spaces member): the gym space to transform.
        dtype (torch.dtype): a dtype to use for the spec.
            Defaults to`spec.dtype`.
        device (torch.device): the device for the spec.
            Defaults to ``None`` (no device for composite and default device for specs).
        categorical_action_encoding (bool): whether discrete spaces should be mapped to categorical or one-hot.
            Defaults to ``False`` (one-hot).
        remap_state_to_observation (bool): whether to rename the 'state' key of
            Dict specs to "observation". Default is true.
        batch_size (torch.Size): batch size to which expand the spec. Defaults to
            ``torch.Size([])``.
    """
    if batch_size:
        return _gym_to_torchrl_spec_transform(
            spec,
            dtype=dtype,
            device=device,
            categorical_action_encoding=categorical_action_encoding,
            remap_state_to_observation=remap_state_to_observation,
            batch_size=None,
        ).expand(batch_size)

    # Get the conversion function from the registry
    conversion_func = _conversion_registry[type(spec)]
    # Call the conversion function with the provided arguments
    return conversion_func(
        spec,
        dtype=dtype,
        device=device,
        categorical_action_encoding=categorical_action_encoding,
        remap_state_to_observation=remap_state_to_observation,
        batch_size=batch_size,
    )


# Register conversion functions for each spec type
@register_gym_spec_conversion("tuple.Tuple")
def convert_tuple_spec(
    spec,
    dtype=None,
    device=None,
    categorical_action_encoding=None,
    remap_state_to_observation=None,
    batch_size=None,
):
    # Implementation for Tuple spec type
    result = torch.stack(
        [
            _gym_to_torchrl_spec_transform(
                s,
                device=device,
                categorical_action_encoding=categorical_action_encoding,
                remap_state_to_observation=remap_state_to_observation,
            )
            for s in spec
        ],
        dim=0,
    )
    return result


@register_gym_spec_conversion("discrete.Discrete")
def convert_discrete_spec(
    spec,
    dtype=None,
    device=None,
    categorical_action_encoding=None,
    remap_state_to_observation=None,
    batch_size=None,
):
    # Implementation for Discrete spec type
    action_space_cls = Categorical if categorical_action_encoding else OneHot
    dtype = (
        numpy_to_torch_dtype_dict[spec.dtype]
        if categorical_action_encoding
        else torch.long
    )
    return action_space_cls(spec.n, device=device, dtype=dtype)


@register_gym_spec_conversion("multi_binary.MultiBinary")
def convert_multi_binary_spec(
    spec,
    dtype=None,
    device=None,
    categorical_action_encoding=None,
    remap_state_to_observation=None,
    batch_size=None,
):
    # Implementation for MultiBinary spec type
    return Binary(spec.n, device=device, dtype=numpy_to_torch_dtype_dict[spec.dtype])


@register_gym_spec_conversion("multi_discrete.MultiDiscrete")
def convert_multidiscrete_spec(
    spec,
    dtype=None,
    device=None,
    categorical_action_encoding=None,
    remap_state_to_observation=None,
    batch_size=None,
):
    # Only use MultiCategorical/MultiOneHot for heterogeneous nvec (e.g., [3, 5, 7]).
    # Homogeneous nvec like [2, 2] typically represents independent actions
    # (e.g., vectorized envs with same Discrete(n) per env) and should use stacking.
    if len(spec.nvec.shape) == 1 and len(np.unique(spec.nvec)) > 1:
        dtype = (
            numpy_to_torch_dtype_dict[spec.dtype]
            if categorical_action_encoding
            else torch.long
        )
        return (
            MultiCategorical(spec.nvec, device=device, dtype=dtype)
            if categorical_action_encoding
            else MultiOneHot(spec.nvec, device=device, dtype=dtype)
        )

    return torch.stack(
        [
            _gym_to_torchrl_spec_transform(
                spec[i],
                device=device,
                categorical_action_encoding=categorical_action_encoding,
                remap_state_to_observation=remap_state_to_observation,
            )
            for i in range(len(spec.nvec))
        ],
        0,
    )


@register_gym_spec_conversion("Box")
def convert_box_spec(
    spec,
    dtype=None,
    device=None,
    categorical_action_encoding=None,
    remap_state_to_observation=None,
    batch_size=None,
):
    shape = spec.shape
    if not len(shape):
        shape = torch.Size([1])
    if dtype is None:
        dtype = numpy_to_torch_dtype_dict[spec.dtype]
    low = torch.as_tensor(spec.low, device=device, dtype=dtype)
    high = torch.as_tensor(spec.high, device=device, dtype=dtype)
    is_unbounded = low.isinf().all() and high.isinf().all()

    minval, maxval = _minmax_dtype(dtype)
    minval = torch.as_tensor(minval).to(low.device, dtype)
    maxval = torch.as_tensor(maxval).to(low.device, dtype)
    is_unbounded = is_unbounded or (
        torch.isclose(low, torch.as_tensor(minval, dtype=dtype)).all()
        and torch.isclose(high, torch.as_tensor(maxval, dtype=dtype)).all()
    )
    return (
        Unbounded(shape, device=device, dtype=dtype)
        if is_unbounded
        else Bounded(
            low,
            high,
            shape,
            dtype=dtype,
            device=device,
        )
    )


@register_gym_spec_conversion("Sequence")
def convert_sequence_spec(
    spec,
    dtype=None,
    device=None,
    categorical_action_encoding=None,
    remap_state_to_observation=None,
    batch_size=None,
):
    if not hasattr(spec, "stack"):
        # gym does not have a stack attribute in sequence
        raise ValueError(
            "gymnasium should be used whenever a Sequence is present, as it needs to be stacked. "
            "If you need the gym backend at all price, please raise an issue on the TorchRL GitHub repository."
        )
    if not getattr(spec, "stack", False):
        raise ValueError(
            "Sequence spaces must have the stack argument set to ``True``. "
        )
    space = spec.feature_space
    out = _gym_to_torchrl_spec_transform(space, device=device, dtype=dtype)
    out = out.unsqueeze(0)
    out.make_neg_dim(0)
    return out


@register_gym_spec_conversion(dict)
def convert_dict_spec(
    spec,
    dtype=None,
    device=None,
    categorical_action_encoding=None,
    remap_state_to_observation=None,
    batch_size=None,
):
    spec_out = {}
    for k in spec.keys():
        key = k
        if (
            remap_state_to_observation
            and k == "state"
            and "observation" not in spec.keys()
        ):
            # we rename "state" in "observation" as "observation" is the conventional name
            # for single observation in torchrl.
            # naming it 'state' will result in envs that have a different name for the state vector
            # when queried with and without pixels
            key = "observation"
        spec_out[key] = _gym_to_torchrl_spec_transform(
            spec[k],
            device=device,
            categorical_action_encoding=categorical_action_encoding,
            remap_state_to_observation=remap_state_to_observation,
            batch_size=batch_size,
        )
    # the batch-size must be set later
    return Composite(spec_out, device=device)


@register_gym_spec_conversion("Text")
def convert_text_soec(
    spec,
    dtype=None,
    device=None,
    categorical_action_encoding=None,
    remap_state_to_observation=None,
    batch_size=None,
):
    return NonTensor((), device=device, example_data="a string")


@register_gym_spec_conversion("dict.Dict")
def convert_dict_spec2(
    spec,
    dtype=None,
    device=None,
    categorical_action_encoding=None,
    remap_state_to_observation=None,
    batch_size=None,
):
    return _gym_to_torchrl_spec_transform(
        spec.spaces,
        device=device,
        categorical_action_encoding=categorical_action_encoding,
        remap_state_to_observation=remap_state_to_observation,
        batch_size=batch_size,
    )


@implement_for("gym", None, "0.18")
def _box_convert(spec, gym_spaces, shape):
    low = spec.low.detach().unique().cpu().item()
    high = spec.high.detach().unique().cpu().item()
    return gym_spaces.Box(low=low, high=high, shape=shape)


@implement_for("gym", "0.18")
def _box_convert(spec, gym_spaces, shape):  # noqa: F811
    low = spec.low.detach().cpu().numpy()
    high = spec.high.detach().cpu().numpy()
    return gym_spaces.Box(low=low, high=high, shape=shape)


@implement_for("gymnasium", None, "1.0.0")
def _box_convert(spec, gym_spaces, shape):  # noqa: F811
    low = spec.low.detach().cpu().numpy()
    high = spec.high.detach().cpu().numpy()
    return gym_spaces.Box(low=low, high=high, shape=shape)


@implement_for("gymnasium", "1.0.0", "1.1.0")
def _box_convert(spec, gym_spaces, shape):  # noqa: F811
    raise ImportError(GYMNASIUM_1_ERROR)


@implement_for("gymnasium", "1.1.0")
def _box_convert(spec, gym_spaces, shape):  # noqa: F811
    low = spec.low.detach().cpu().numpy()
    high = spec.high.detach().cpu().numpy()
    return gym_spaces.Box(low=low, high=high, shape=shape)


@implement_for("gym", "0.21", None)
def _multidiscrete_convert(gym_spaces, spec):
    return gym_spaces.multi_discrete.MultiDiscrete(
        spec.nvec, dtype=torch_to_numpy_dtype_dict[spec.dtype]
    )


@implement_for("gymnasium", None, "1.0.0")
def _multidiscrete_convert(gym_spaces, spec):  # noqa: F811
    return gym_spaces.multi_discrete.MultiDiscrete(
        spec.nvec, dtype=torch_to_numpy_dtype_dict[spec.dtype]
    )


@implement_for("gymnasium", "1.0.0", "1.1.0")
def _multidiscrete_convert(gym_spaces, spec):  # noqa: F811
    raise ImportError(GYMNASIUM_1_ERROR)


@implement_for("gymnasium", "1.1.0")
def _multidiscrete_convert(gym_spaces, spec):  # noqa: F811
    return gym_spaces.multi_discrete.MultiDiscrete(
        spec.nvec, dtype=torch_to_numpy_dtype_dict[spec.dtype]
    )


@implement_for("gym", None, "0.21")
def _multidiscrete_convert(gym_spaces, spec):  # noqa: F811
    return gym_spaces.multi_discrete.MultiDiscrete(spec.nvec)


def _torchrl_to_gym_spec_transform(
    spec,
    categorical_action_encoding=False,
) -> TensorSpec:
    """Maps TorchRL specs to gym spaces.

    Args:
        spec: the torchrl spec to transform.
        categorical_action_encoding: whether discrete spaces should be mapped to categorical or one-hot.
            Defaults to one-hot.

    """
    gym_spaces = gym_backend("spaces")
    shape = spec.shape
    if any(s == -1 for s in spec.shape):
        if spec.shape[0] == -1:
            spec = spec.clone()
            spec = spec[0]
            return gym_spaces.Sequence(_torchrl_to_gym_spec_transform(spec), stack=True)
        else:
            return gym_spaces.Tuple(
                tuple(_torchrl_to_gym_spec_transform(spec) for spec in spec.unbind(0))
            )
    if isinstance(spec, MultiCategorical):
        return _multidiscrete_convert(gym_spaces, spec)
    if isinstance(spec, MultiOneHot):
        return gym_spaces.multi_discrete.MultiDiscrete(spec.nvec)
    if isinstance(spec, Binary):
        return gym_spaces.multi_binary.MultiBinary(spec.shape[-1])
    if isinstance(spec, Categorical):
        return gym_spaces.discrete.Discrete(
            spec.n
        )  # dtype=torch_to_numpy_dtype_dict[spec.dtype])
    if isinstance(spec, OneHot):
        return gym_spaces.discrete.Discrete(spec.n)
    if isinstance(spec, Unbounded):
        minval, maxval = _minmax_dtype(spec.dtype)
        return gym_spaces.Box(
            low=minval,
            high=maxval,
            shape=shape,
            dtype=torch_to_numpy_dtype_dict[spec.dtype],
        )
    if isinstance(spec, Unbounded):
        minval, maxval = _minmax_dtype(spec.dtype)
        return gym_spaces.Box(
            low=minval,
            high=maxval,
            shape=shape,
            dtype=torch_to_numpy_dtype_dict[spec.dtype],
        )
    if isinstance(spec, Bounded):
        return _box_convert(spec, gym_spaces, shape)
    if isinstance(spec, Composite):
        # remove batch size
        while spec.shape:
            spec = spec[0]
        return gym_spaces.Dict(
            **{
                key: _torchrl_to_gym_spec_transform(
                    val,
                    categorical_action_encoding=categorical_action_encoding,
                )
                for key, val in spec.items()
            }
        )
    else:
        raise NotImplementedError(
            f"spec of type {type(spec).__name__} is currently unaccounted for"
        )


def _get_envs(to_dict=False) -> list:
    if not _has_gym:
        raise ImportError("Gym(nasium) could not be found in your virtual environment.")
    envs = _get_gym_envs()
    envs = list(envs)
    envs = sorted(envs)
    return envs


@implement_for("gym", None, "0.26.0")
def _get_gym_envs():  # noqa: F811
    gym = gym_backend()
    return gym.envs.registration.registry.env_specs.keys()


@implement_for("gym", "0.26.0", None)
def _get_gym_envs():  # noqa: F811
    gym = gym_backend()
    return gym.envs.registration.registry.keys()


@implement_for("gymnasium", None, "1.0.0")
def _get_gym_envs():  # noqa: F811
    gym = gym_backend()
    return gym.envs.registration.registry.keys()


@implement_for("gymnasium", "1.0.0", "1.1.0")
def _get_gym_envs():  # noqa: F811
    raise ImportError(GYMNASIUM_1_ERROR)


@implement_for("gymnasium", "1.1.0")
def _get_gym_envs():  # noqa: F811
    gym = gym_backend()
    return gym.envs.registration.registry.keys()


def _is_from_pixels(env):
    observation_spec = env.observation_space
    try:
        PixelObservationWrapper = gym_backend(
            "wrappers.pixel_observation"
        ).PixelObservationWrapper
    except ModuleNotFoundError:

        class PixelObservationWrapper:
            pass

    from torchrl.envs.libs.utils import (
        GymPixelObservationWrapper as LegacyPixelObservationWrapper,
    )

    gDict = gym_backend("spaces").dict.Dict
    Box = gym_backend("spaces").Box

    if isinstance(observation_spec, (dict,)):
        if "pixels" in set(observation_spec.keys()):
            return True
    if isinstance(observation_spec, (gDict,)):
        if "pixels" in set(observation_spec.spaces.keys()):
            return True
    elif (
        isinstance(observation_spec, Box)
        and (observation_spec.low == 0).all()
        and (observation_spec.high == 255).all()
        and observation_spec.low.shape[-1] == 3
        and observation_spec.low.ndim == 3
    ):
        return True
    else:
        while True:
            if isinstance(
                env, (LegacyPixelObservationWrapper, PixelObservationWrapper)
            ):
                return True
            if hasattr(env, "env"):
                env = env.env
            else:
                break
    return False


class _GymAsyncMeta(_EnvPostInit):
    def __call__(cls, *args, **kwargs):
        missing_obs_value = kwargs.pop("missing_obs_value", None)
        num_workers = kwargs.pop("num_workers", 1)

        if cls.__name__ == "GymEnv" and num_workers > 1:
            from torchrl.envs import EnvCreator, ParallelEnv

            env_name = args[0] if args else kwargs.get("env_name")
            env_kwargs = kwargs.copy()
            env_kwargs.pop("env_name", None)
            make_env = partial(cls, env_name, **env_kwargs)
            return ParallelEnv(num_workers, EnvCreator(make_env))

        instance: GymWrapper = super().__call__(*args, **kwargs)

        # before gym 0.22, there was no final_observation
        if instance._is_batched:
            gym_backend = instance.get_library_name(instance._env)
            from torchrl.envs.transforms.transforms import (
                TransformedEnv,
                VecGymEnvTransform,
            )

            if _has_isaaclab:
                from isaaclab.envs import ManagerBasedRLEnv

                kwargs = {}
                if missing_obs_value is not None:
                    kwargs["missing_obs_value"] = missing_obs_value
                if isinstance(instance._env.unwrapped, ManagerBasedRLEnv):
                    return TransformedEnv(instance, VecGymEnvTransform(**kwargs))

            if _has_sb3:
                from stable_baselines3.common.vec_env.base_vec_env import VecEnv

                if isinstance(instance._env, VecEnv):
                    backend = "sb3"
                else:
                    backend = gym_backend
            else:
                backend = gym_backend

            # we need 3 checks: the backend is not sb3 (if so, gymnasium is used),
            # it is gym and not gymnasium and the version is before 0.22.0
            add_info_dict = True
            if backend == "gym" and gym_backend == "gym":  # check gym against gymnasium
                import gym

                if version.parse(gym.__version__) < version.parse("0.22.0"):
                    warn(
                        "A batched gym environment is being wrapped in a GymWrapper with gym version < 0.22. "
                        "This implies that the next-observation is wrongly tracked (as the batched environment auto-resets "
                        "and discards the true next observation to return the result of the step). "
                        "This isn't compatible with TorchRL API and should be used with caution.",
                        category=UserWarning,
                    )
                    add_info_dict = False
            if gym_backend == "gymnasium":
                import gymnasium

                if version.parse(gymnasium.__version__) >= version.parse("1.1.0"):
                    add_info_dict = (
                        instance._env.autoreset_mode
                        != gymnasium.vector.AutoresetMode.DISABLED
                    )
                    if not add_info_dict:
                        return instance
            if add_info_dict:
                # register terminal_obs_reader
                instance.auto_register_info_dict(
                    info_dict_reader=terminal_obs_reader(
                        instance.observation_spec, backend=backend
                    )
                )
            kwargs = {}
            if missing_obs_value is not None:
                kwargs["missing_obs_value"] = missing_obs_value
            return TransformedEnv(instance, VecGymEnvTransform(**kwargs))
        return instance


class GymWrapper(GymLikeEnv, metaclass=_GymAsyncMeta):
    """OpenAI Gym environment wrapper.

    Works across `gymnasium <https://gymnasium.farama.org/>`_ and `OpenAI/gym <https://github.com/openai/gym>`_.

    Args:
        env (gym.Env): the environment to wrap. Batched environments (:class:`~stable_baselines3.common.vec_env.base_vec_env.VecEnv`
            or :class:`gym.VectorEnv`) are supported and the environment batch-size
            will reflect the number of environments executed in parallel.
        categorical_action_encoding (bool, optional): if ``True``, categorical
            specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
            otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
            Defaults to ``False``.

    Keyword Args:
        from_pixels (bool, optional): if ``True``, an attempt to return the pixel
            observations from the env will be performed. By default, these observations
            will be written under the ``"pixels"`` entry.
            The method being used varies
            depending on the gym version and may involve a ``wrappers.pixel_observation.PixelObservationWrapper``.
            Defaults to ``False``.
        pixels_only (bool, optional): if ``True``, only the pixel observations will
            be returned (by default under the ``"pixels"`` entry in the output tensordict).
            If ``False``, observations (eg, states) and pixels will be returned
            whenever ``from_pixels=True``. Defaults to ``True``.
        frame_skip (int, optional): if provided, indicates for how many steps the
            same action is to be repeated. The observation returned will be the
            last observation of the sequence, whereas the reward will be the sum
            of rewards across steps.
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``torch.device("cpu")``.
        batch_size (torch.Size, optional): the batch size of the environment.
            Should match the leading dimensions of all observations, done states,
            rewards, actions and infos.
            Defaults to ``torch.Size([])``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``False``.
        convert_actions_to_numpy (bool, optional): if ``True``, actions will be
            converted from tensors to numpy arrays and moved to CPU before being passed to the
            env step function. Set this to ``False`` if the environment is evaluated
            on GPU, such as IsaacLab.
            Defaults to ``True``.
        missing_obs_value (Any, optional): default value to use as placeholder for missing observations, when
            the environment is auto-resetting and missing observations cannot be found in the info dictionary
            (e.g., with IsaacLab). This argument is passed to :class:`~torchrl.envs.VecGymEnvTransform` by
            the metaclass.

    Attributes:
        available_envs (List[str]): a list of environments to build.

    .. note::
        If an attribute cannot be found, this class will attempt to retrieve it from
        the nested env:

            >>> from torchrl.envs import GymWrapper
            >>> import gymnasium as gym
            >>> env = GymWrapper(gym.make("Pendulum-v1"))
            >>> print(env.spec.max_episode_steps)
            200

    Examples:
        >>> import gymnasium as gym
        >>> from torchrl.envs import GymWrapper
        >>> base_env = gym.make("Pendulum-v1")
        >>> env = GymWrapper(base_env)
        >>> td = env.rand_step()
        >>> print(td)
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=cpu,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)
        >>> print(env.available_envs)
        ['ALE/Adventure-ram-v5', 'ALE/Adventure-v5', 'ALE/AirRaid-ram-v5', 'ALE/AirRaid-v5', 'ALE/Alien-ram-v5', 'ALE/Alien-v5',

    .. note::
        info dictionaries will be read using :class:`~torchrl.envs.gym_like.default_info_dict_reader`
        if no other reader is provided. To provide another reader, refer to
        :meth:`set_info_dict_reader`. To automatically register the info_dict
        content, refer to :meth:`torchrl.envs.GymLikeEnv.auto_register_info_dict`.
        For parallel (Vectorized) environments, the info dictionary reader is automatically set and should
        not be set manually.

    .. note:: Gym spaces are not completely covered.
        The following spaces are accounted for provided that they can be represented by a torch.Tensor, a nested tensor
        and/or within a tensordict:

        - spaces.Box
        - spaces.Sequence
        - spaces.Tuple
        - spaces.Discrete
        - spaces.MultiBinary
        - spaces.MultiDiscrete
        - spaces.Dict

        Some considerations should be made when working with gym spaces. For instance, a tuple of spaces
        can only be supported if the spaces are semantically identical (same dtype and same number of dimensions).
        Ragged dimension can be supported through :func:`~torch.nested.nested_tensor`, but then there should be only
        one level of tuple and data should be stacked along the first dimension (as nested_tensors can only be
|
|
1011
|
+
stacked along the first dimension).
|
|
1012
|
+
|
|
1013
|
+
Check the example in examples/envs/gym_conversion_examples.py to know more!
|
|
1014
|
+
|
|
1015
|
+
"""

    git_url = "https://github.com/openai/gym"
    libname = "gym"

    @_classproperty
    def available_envs(cls):
        if not _has_gym:
            return []
        return list(_get_envs())

    @staticmethod
    def get_library_name(env) -> str:
        """Given a gym environment, returns the backend name (either gym or gymnasium).

        This can be used to set the appropriate backend when needed:

        Examples:
            >>> env = gymnasium.make("Pendulum-v1")
            >>> with set_gym_backend(env):
            ...     env = GymWrapper(env)

        :class:`~GymWrapper` and similar classes use this method to set their
        methods to the right backend during instantiation.

        """
        try:
            import gym

            if isinstance(env.action_space, gym.spaces.space.Space):
                return "gym"
        except ImportError:
            pass
        try:
            import gymnasium

            if isinstance(env.action_space, gymnasium.spaces.space.Space):
                return "gymnasium"
        except ImportError:
            pass
        raise ImportError(
            f"Could not find the library of env {env}. Please file an issue on torchrl github repo."
        )

    def __init__(self, env=None, categorical_action_encoding=False, **kwargs):
        self._seed_calls_reset = None
        self._categorical_action_encoding = categorical_action_encoding
        if env is not None:
            try:
                env_str = str(env)
            except TypeError:
                # MiniGrid has a bug where the __str__ method fails
                pass
            else:
                if (
                    "EnvCompatibility" in env_str
                ):  # a hacky way of knowing if EnvCompatibility is part of the wrappers of env
                    raise ValueError(
                        "GymWrapper does not support the gym.wrapper.compatibility.EnvCompatibility wrapper. "
                        "If this feature is needed, detail your use case in an issue of "
                        "https://github.com/pytorch/rl/issues."
                    )
            libname = self.get_library_name(env)
            self._validate_env(env)
            with set_gym_backend(libname):
                kwargs["env"] = env
                super().__init__(**kwargs)
        else:
            super().__init__(**kwargs)
        self._post_init()

    @implement_for("gymnasium", "1.1.0")
    def _validate_env(self, env):
        autoreset_mode = getattr(env, "autoreset_mode", None)
        if autoreset_mode is not None:
            from gymnasium.vector import AutoresetMode

            if autoreset_mode not in (AutoresetMode.DISABLED, AutoresetMode.SAME_STEP):
                raise RuntimeError(
                    "The auto-reset mode must be one of SAME_STEP or DISABLED (which is preferred). Got "
                    f"autoreset_mode={autoreset_mode}."
                )

    @implement_for("gym", None, "1.1.0")
    def _validate_env(self, env):  # noqa
        pass

    @implement_for("gymnasium", None, "1.1.0")
    def _validate_env(self, env):  # noqa
        pass

    def _post_init(self):
        # writes the functions that are gym-version specific to the instance
        # once and for all. This is aimed at avoiding the need of decorating code
        # with set_gym_backend + allowing for parallel execution (which would
        # be troublesome when both an old version of gym and recent gymnasium
        # are present within the same virtual env).
        #
        # These calls seemingly do nothing but they actually get rid of the @implement_for decorator.
        # We execute them within the set_gym_backend context manager to make sure we get
        # the right implementation.
        #
        # This method is executed by the metaclass of GymWrapper.
        with set_gym_backend(self.get_library_name(self._env)):
            self._reset_output_transform = self._reset_output_transform
            self._output_transform = self._output_transform

@property
|
|
1123
|
+
def _is_batched(self):
|
|
1124
|
+
tuple_of_classes = ()
|
|
1125
|
+
if _has_sb3:
|
|
1126
|
+
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
|
|
1127
|
+
|
|
1128
|
+
tuple_of_classes = tuple_of_classes + (VecEnv,)
|
|
1129
|
+
if _has_isaaclab:
|
|
1130
|
+
from isaaclab.envs import ManagerBasedRLEnv
|
|
1131
|
+
|
|
1132
|
+
tuple_of_classes = tuple_of_classes + (ManagerBasedRLEnv,)
|
|
1133
|
+
return isinstance(
|
|
1134
|
+
self._env.unwrapped, tuple_of_classes + (gym_backend("vector").VectorEnv,)
|
|
1135
|
+
)
|
|
1136
|
+
|
|
1137
|
+
@implement_for("gym")
|
|
1138
|
+
def _get_batch_size(self, env):
|
|
1139
|
+
if hasattr(env, "num_envs"):
|
|
1140
|
+
batch_size = torch.Size([env.num_envs, *self.batch_size])
|
|
1141
|
+
else:
|
|
1142
|
+
batch_size = self.batch_size
|
|
1143
|
+
return batch_size
|
|
1144
|
+
|
|
1145
|
+
@implement_for("gymnasium", None, "1.0.0") # gymnasium wants the unwrapped env
|
|
1146
|
+
def _get_batch_size(self, env): # noqa: F811
|
|
1147
|
+
env_unwrapped = env.unwrapped
|
|
1148
|
+
if hasattr(env_unwrapped, "num_envs"):
|
|
1149
|
+
batch_size = torch.Size([env_unwrapped.num_envs, *self.batch_size])
|
|
1150
|
+
else:
|
|
1151
|
+
batch_size = self.batch_size
|
|
1152
|
+
return batch_size
|
|
1153
|
+
|
|
1154
|
+
@implement_for("gymnasium", "1.0.0", "1.1.0")
|
|
1155
|
+
def _get_batch_size(self, env): # noqa: F811
|
|
1156
|
+
raise ImportError(GYMNASIUM_1_ERROR)
|
|
1157
|
+
|
|
1158
|
+
@implement_for("gymnasium", "1.1.0") # gymnasium wants the unwrapped env
|
|
1159
|
+
def _get_batch_size(self, env): # noqa: F811
|
|
1160
|
+
env_unwrapped = env.unwrapped
|
|
1161
|
+
if hasattr(env_unwrapped, "num_envs"):
|
|
1162
|
+
batch_size = torch.Size([env_unwrapped.num_envs, *self.batch_size])
|
|
1163
|
+
else:
|
|
1164
|
+
batch_size = self.batch_size
|
|
1165
|
+
return batch_size
|
|
1166
|
+
|
|
1167
|
+
def _check_kwargs(self, kwargs: dict):
|
|
1168
|
+
if "env" not in kwargs:
|
|
1169
|
+
raise TypeError("Could not find environment key 'env' in kwargs.")
|
|
1170
|
+
env = kwargs["env"]
|
|
1171
|
+
if not (hasattr(env, "action_space") and hasattr(env, "observation_space")):
|
|
1172
|
+
raise TypeError("env is not of type 'gym.Env'.")
|
|
1173
|
+
|
|
1174
|
+
def _build_env(
|
|
1175
|
+
self,
|
|
1176
|
+
env,
|
|
1177
|
+
from_pixels: bool = False,
|
|
1178
|
+
pixels_only: bool = False,
|
|
1179
|
+
) -> gym.core.Env: # noqa: F821
|
|
1180
|
+
self.batch_size = self._get_batch_size(env)
|
|
1181
|
+
|
|
1182
|
+
env_from_pixels = _is_from_pixels(env)
|
|
1183
|
+
from_pixels = from_pixels or env_from_pixels
|
|
1184
|
+
self.from_pixels = from_pixels
|
|
1185
|
+
self.pixels_only = pixels_only
|
|
1186
|
+
if from_pixels and not env_from_pixels:
|
|
1187
|
+
try:
|
|
1188
|
+
PixelObservationWrapper = gym_backend(
|
|
1189
|
+
"wrappers.pixel_observation.PixelObservationWrapper"
|
|
1190
|
+
)
|
|
1191
|
+
if isinstance(env, PixelObservationWrapper):
|
|
1192
|
+
raise TypeError(
|
|
1193
|
+
"PixelObservationWrapper cannot be used to wrap an environment "
|
|
1194
|
+
"that is already a PixelObservationWrapper instance."
|
|
1195
|
+
)
|
|
1196
|
+
except ModuleNotFoundError:
|
|
1197
|
+
pass
|
|
1198
|
+
env = self._build_gym_env(env, pixels_only)
|
|
1199
|
+
return env
|
|
1200
|
+
|
|
1201
|
+
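    # Illustrative sketch of requesting pixel observations through the wrapper,
    # assuming gymnasium is installed and that the env was created with a render
    # mode, as the class docstring describes.
    #
    #     import gymnasium as gym
    #     from torchrl.envs import GymWrapper
    #
    #     base = gym.make("Pendulum-v1", render_mode="rgb_array")
    #     env = GymWrapper(base, from_pixels=True, pixels_only=False)
    #     td = env.reset()   # contains both "pixels" (rendered frame) and "observation"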
def read_action(self, action):
|
|
1202
|
+
action = super().read_action(action)
|
|
1203
|
+
if isinstance(self.action_spec, (OneHot, Categorical)) and action.size == 1:
|
|
1204
|
+
# some envs require an integer for indexing
|
|
1205
|
+
action = int(action)
|
|
1206
|
+
return action
|
|
1207
|
+
|
|
1208
|
+
@implement_for("gym", None, "0.19.0")
|
|
1209
|
+
def _build_gym_env(self, env, pixels_only): # noqa: F811
|
|
1210
|
+
from .utils import GymPixelObservationWrapper as PixelObservationWrapper
|
|
1211
|
+
|
|
1212
|
+
return PixelObservationWrapper(env, pixels_only=pixels_only)
|
|
1213
|
+
|
|
1214
|
+
@implement_for("gym", "0.19.0", "0.26.0")
|
|
1215
|
+
def _build_gym_env(self, env, pixels_only): # noqa: F811
|
|
1216
|
+
pixel_observation = gym_backend("wrappers.pixel_observation")
|
|
1217
|
+
return pixel_observation.PixelObservationWrapper(env, pixels_only=pixels_only)
|
|
1218
|
+
|
|
1219
|
+
@implement_for("gym", "0.26.0", None)
|
|
1220
|
+
def _build_gym_env(self, env, pixels_only): # noqa: F811
|
|
1221
|
+
compatibility = gym_backend("wrappers.compatibility")
|
|
1222
|
+
pixel_observation = gym_backend("wrappers.pixel_observation")
|
|
1223
|
+
|
|
1224
|
+
if env.render_mode:
|
|
1225
|
+
return pixel_observation.PixelObservationWrapper(
|
|
1226
|
+
env, pixels_only=pixels_only
|
|
1227
|
+
)
|
|
1228
|
+
|
|
1229
|
+
warnings.warn(
|
|
1230
|
+
"Environments provided to GymWrapper that need to be wrapped in PixelObservationWrapper "
|
|
1231
|
+
"should be created with `gym.make(env_name, render_mode=mode)` where possible,"
|
|
1232
|
+
'where mode is either "rgb_array" or any other supported mode.'
|
|
1233
|
+
)
|
|
1234
|
+
# resetting as 0.26 comes with a very 'nice' OrderEnforcing wrapper
|
|
1235
|
+
env = compatibility.EnvCompatibility(env)
|
|
1236
|
+
env.reset()
|
|
1237
|
+
from torchrl.envs.libs.utils import (
|
|
1238
|
+
GymPixelObservationWrapper as LegacyPixelObservationWrapper,
|
|
1239
|
+
)
|
|
1240
|
+
|
|
1241
|
+
return LegacyPixelObservationWrapper(env, pixels_only=pixels_only)
|
|
1242
|
+
|
|
1243
|
+
@implement_for("gymnasium", "1.0.0", "1.1.0")
|
|
1244
|
+
def _build_gym_env(self, env, pixels_only): # noqa: F811
|
|
1245
|
+
raise ImportError(GYMNASIUM_1_ERROR)
|
|
1246
|
+
|
|
1247
|
+
@implement_for("gymnasium", None, "1.0.0")
|
|
1248
|
+
def _build_gym_env(self, env, pixels_only): # noqa: F811
|
|
1249
|
+
compatibility = gym_backend("wrappers.compatibility")
|
|
1250
|
+
pixel_observation = gym_backend("wrappers.pixel_observation")
|
|
1251
|
+
|
|
1252
|
+
if env.render_mode:
|
|
1253
|
+
return pixel_observation.PixelObservationWrapper(
|
|
1254
|
+
env, pixels_only=pixels_only
|
|
1255
|
+
)
|
|
1256
|
+
|
|
1257
|
+
warnings.warn(
|
|
1258
|
+
"Environments provided to GymWrapper that need to be wrapped in PixelObservationWrapper "
|
|
1259
|
+
"should be created with `gym.make(env_name, render_mode=mode)` where possible,"
|
|
1260
|
+
'where mode is either "rgb_array" or any other supported mode.'
|
|
1261
|
+
)
|
|
1262
|
+
# resetting as 0.26 comes with a very 'nice' OrderEnforcing wrapper
|
|
1263
|
+
env = compatibility.EnvCompatibility(env)
|
|
1264
|
+
env.reset()
|
|
1265
|
+
from torchrl.envs.libs.utils import (
|
|
1266
|
+
GymPixelObservationWrapper as LegacyPixelObservationWrapper,
|
|
1267
|
+
)
|
|
1268
|
+
|
|
1269
|
+
return LegacyPixelObservationWrapper(env, pixels_only=pixels_only)
|
|
1270
|
+
|
|
1271
|
+
@implement_for("gymnasium", "1.1.0")
|
|
1272
|
+
def _build_gym_env(self, env, pixels_only): # noqa: F811
|
|
1273
|
+
wrappers = gym_backend("wrappers")
|
|
1274
|
+
|
|
1275
|
+
if env.render_mode:
|
|
1276
|
+
return wrappers.AddRenderObservation(env, render_only=pixels_only)
|
|
1277
|
+
|
|
1278
|
+
warnings.warn(
|
|
1279
|
+
"Environments provided to GymWrapper that need to be wrapped in PixelObservationWrapper "
|
|
1280
|
+
"should be created with `gym.make(env_name, render_mode=mode)` where possible,"
|
|
1281
|
+
'where mode is either "rgb_array" or any other supported mode.'
|
|
1282
|
+
)
|
|
1283
|
+
env.reset()
|
|
1284
|
+
from torchrl.envs.libs.utils import (
|
|
1285
|
+
GymPixelObservationWrapper as LegacyPixelObservationWrapper,
|
|
1286
|
+
)
|
|
1287
|
+
|
|
1288
|
+
return LegacyPixelObservationWrapper(env, pixels_only=pixels_only)
|
|
1289
|
+
|
|
1290
|
+
@property
|
|
1291
|
+
def lib(self) -> ModuleType:
|
|
1292
|
+
gym = gym_backend()
|
|
1293
|
+
if gym is None:
|
|
1294
|
+
raise RuntimeError(
|
|
1295
|
+
"Gym backend is not available. Please install gym or gymnasium."
|
|
1296
|
+
)
|
|
1297
|
+
return gym
|
|
1298
|
+
|
|
1299
|
+
def _set_seed(self, seed: int | None) -> None: # noqa: F811
|
|
1300
|
+
if self._seed_calls_reset is None:
|
|
1301
|
+
# Determine basing on gym version whether `reset` is called when setting seed.
|
|
1302
|
+
self._set_seed_initial(seed)
|
|
1303
|
+
elif self._seed_calls_reset:
|
|
1304
|
+
self.reset(seed=seed)
|
|
1305
|
+
else:
|
|
1306
|
+
self._env.seed(seed=seed)
|
|
1307
|
+
|
|
1308
|
+
@implement_for("gym", None, "0.15.0")
|
|
1309
|
+
def _set_seed_initial(self, seed: int) -> None: # noqa: F811
|
|
1310
|
+
self._seed_calls_reset = False
|
|
1311
|
+
self._env.seed(seed)
|
|
1312
|
+
|
|
1313
|
+
@implement_for("gym", "0.15.0", "0.19.0")
|
|
1314
|
+
def _set_seed_initial(self, seed: int) -> None: # noqa: F811
|
|
1315
|
+
self._seed_calls_reset = False
|
|
1316
|
+
self._env.seed(seed=seed)
|
|
1317
|
+
|
|
1318
|
+
@implement_for("gym", "0.19.0", "0.21.0")
|
|
1319
|
+
def _set_seed_initial(self, seed: int) -> None: # noqa: F811
|
|
1320
|
+
# In gym 0.19-0.21, reset() doesn't accept seed kwarg yet,
|
|
1321
|
+
# and VectorEnv.seed uses seeds= (plural) instead of seed=
|
|
1322
|
+
self._seed_calls_reset = False
|
|
1323
|
+
if hasattr(self._env, "num_envs"):
|
|
1324
|
+
# Vector environment uses seeds= (plural)
|
|
1325
|
+
self._env.seed(seeds=seed)
|
|
1326
|
+
else:
|
|
1327
|
+
self._env.seed(seed=seed)
|
|
1328
|
+
|
|
1329
|
+
@implement_for("gym", "0.21.0", None)
|
|
1330
|
+
def _set_seed_initial(self, seed: int) -> None: # noqa: F811
|
|
1331
|
+
try:
|
|
1332
|
+
self.reset(seed=seed)
|
|
1333
|
+
self._seed_calls_reset = True
|
|
1334
|
+
except TypeError as err:
|
|
1335
|
+
warnings.warn(
|
|
1336
|
+
f"reset with seed kwarg returned an exception: {err}.\n"
|
|
1337
|
+
f"Calling env.seed from now on."
|
|
1338
|
+
)
|
|
1339
|
+
self._seed_calls_reset = False
|
|
1340
|
+
try:
|
|
1341
|
+
self._env.seed(seed=seed)
|
|
1342
|
+
except AttributeError as err2:
|
|
1343
|
+
raise err from err2
|
|
1344
|
+
|
|
1345
|
+
@implement_for("gymnasium", "1.0.0", "1.1.0")
|
|
1346
|
+
def _set_seed_initial(self, seed: int) -> None: # noqa: F811
|
|
1347
|
+
raise ImportError(GYMNASIUM_1_ERROR)
|
|
1348
|
+
|
|
1349
|
+
@implement_for("gymnasium", None, "1.0.0")
|
|
1350
|
+
def _set_seed_initial(self, seed: int) -> None: # noqa: F811
|
|
1351
|
+
try:
|
|
1352
|
+
self.reset(seed=seed)
|
|
1353
|
+
self._seed_calls_reset = True
|
|
1354
|
+
except TypeError as err:
|
|
1355
|
+
warnings.warn(
|
|
1356
|
+
f"reset with seed kwarg returned an exception: {err}.\n"
|
|
1357
|
+
f"Calling env.seed from now on."
|
|
1358
|
+
)
|
|
1359
|
+
self._seed_calls_reset = False
|
|
1360
|
+
self._env.seed(seed=seed)
|
|
1361
|
+
|
|
1362
|
+
@implement_for("gymnasium", "1.1.0")
|
|
1363
|
+
def _set_seed_initial(self, seed: int) -> None: # noqa: F811
|
|
1364
|
+
try:
|
|
1365
|
+
self.reset(seed=seed)
|
|
1366
|
+
self._seed_calls_reset = True
|
|
1367
|
+
except TypeError as err:
|
|
1368
|
+
warnings.warn(
|
|
1369
|
+
f"reset with seed kwarg returned an exception: {err}.\n"
|
|
1370
|
+
f"Calling env.seed from now on."
|
|
1371
|
+
)
|
|
1372
|
+
self._seed_calls_reset = False
|
|
1373
|
+
self._env.seed(seed=seed)
|
|
1374
|
+
|
|
1375
|
+
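    # Illustrative sketch: seeding goes through EnvBase.set_seed, which
    # dispatches to _set_seed above; depending on the backend version this calls
    # either env.reset(seed=...) or env.seed(...). Assumes gym/gymnasium and
    # "Pendulum-v1" are available.
    #
    #     from torchrl.envs import GymEnv
    #
    #     env = GymEnv("Pendulum-v1")
    #     env.set_seed(0)
    #     td0 = env.reset()
    #     env.set_seed(0)
    #     td1 = env.reset()   # same initial observation as td0 is expected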
@implement_for("gym")
|
|
1376
|
+
def _reward_space(self, env):
|
|
1377
|
+
if hasattr(env, "reward_space") and env.reward_space is not None:
|
|
1378
|
+
return env.reward_space
|
|
1379
|
+
|
|
1380
|
+
@implement_for("gymnasium", "1.0.0", "1.1.0")
|
|
1381
|
+
def _reward_space(self, env): # noqa: F811
|
|
1382
|
+
raise ImportError(GYMNASIUM_1_ERROR)
|
|
1383
|
+
|
|
1384
|
+
@implement_for("gymnasium", None, "1.0.0")
|
|
1385
|
+
def _reward_space(self, env): # noqa: F811
|
|
1386
|
+
env = env.unwrapped
|
|
1387
|
+
if hasattr(env, "reward_space") and env.reward_space is not None:
|
|
1388
|
+
rs = env.reward_space
|
|
1389
|
+
return rs
|
|
1390
|
+
|
|
1391
|
+
@implement_for("gymnasium", "1.1.0")
|
|
1392
|
+
def _reward_space(self, env): # noqa: F811
|
|
1393
|
+
env = env.unwrapped
|
|
1394
|
+
if hasattr(env, "reward_space") and env.reward_space is not None:
|
|
1395
|
+
rs = env.reward_space
|
|
1396
|
+
return rs
|
|
1397
|
+
|
|
1398
|
+
def _make_specs(self, env: gym.Env, batch_size=None) -> None: # noqa: F821
|
|
1399
|
+
# If batch_size is provided, we set it to tell what batch size must be used
|
|
1400
|
+
# instead of self.batch_size
|
|
1401
|
+
cur_batch_size = self.batch_size if batch_size is None else torch.Size([])
|
|
1402
|
+
observation_spec = _gym_to_torchrl_spec_transform(
|
|
1403
|
+
env.observation_space,
|
|
1404
|
+
device=self.device,
|
|
1405
|
+
categorical_action_encoding=self._categorical_action_encoding,
|
|
1406
|
+
)
|
|
1407
|
+
action_spec = _gym_to_torchrl_spec_transform(
|
|
1408
|
+
env.action_space,
|
|
1409
|
+
device=self.device,
|
|
1410
|
+
categorical_action_encoding=self._categorical_action_encoding,
|
|
1411
|
+
)
|
|
1412
|
+
# When the action space is MultiDiscrete and an action_mask is present in the
|
|
1413
|
+
# observation with shape matching nvec, we convert to a flattened Categorical/OneHot
|
|
1414
|
+
# so that the mask can be applied directly to all possible action combinations.
|
|
1415
|
+
# This is useful for grid-based games where the mask indicates valid (row, col) positions.
|
|
1416
|
+
gym_spaces = gym_backend("spaces")
|
|
1417
|
+
MultiDiscrete = getattr(gym_spaces, "MultiDiscrete", None)
|
|
1418
|
+
if MultiDiscrete is None:
|
|
1419
|
+
# Fallback for gym versions where MultiDiscrete is in a submodule
|
|
1420
|
+
multi_discrete_module = getattr(gym_spaces, "multi_discrete", None)
|
|
1421
|
+
if multi_discrete_module is not None:
|
|
1422
|
+
MultiDiscrete = getattr(multi_discrete_module, "MultiDiscrete", None)
|
|
1423
|
+
if MultiDiscrete is not None and isinstance(env.action_space, MultiDiscrete):
|
|
1424
|
+
nvec = np.asarray(env.action_space.nvec)
|
|
1425
|
+
if (
|
|
1426
|
+
nvec.ndim == 1
|
|
1427
|
+
and isinstance(observation_spec, Composite)
|
|
1428
|
+
and "action_mask" in observation_spec
|
|
1429
|
+
):
|
|
1430
|
+
mask_spec = observation_spec["action_mask"]
|
|
1431
|
+
if tuple(mask_spec.shape) == tuple(nvec):
|
|
1432
|
+
prod_n = int(np.prod(nvec))
|
|
1433
|
+
dtype = (
|
|
1434
|
+
numpy_to_torch_dtype_dict[env.action_space.dtype]
|
|
1435
|
+
if self._categorical_action_encoding
|
|
1436
|
+
else torch.long
|
|
1437
|
+
)
|
|
1438
|
+
# Flattened action: single choice from prod(nvec) options.
|
|
1439
|
+
# The mask (which has shape matching nvec) will be reshaped
|
|
1440
|
+
# by Categorical/OneHot.update_mask when applied.
|
|
1441
|
+
if self._categorical_action_encoding:
|
|
1442
|
+
action_spec = Categorical(
|
|
1443
|
+
prod_n,
|
|
1444
|
+
shape=(),
|
|
1445
|
+
device=self.device,
|
|
1446
|
+
dtype=dtype,
|
|
1447
|
+
)
|
|
1448
|
+
else:
|
|
1449
|
+
action_spec = OneHot(
|
|
1450
|
+
prod_n,
|
|
1451
|
+
shape=(prod_n,),
|
|
1452
|
+
device=self.device,
|
|
1453
|
+
dtype=torch.bool,
|
|
1454
|
+
)
|
|
1455
|
+
if not isinstance(observation_spec, Composite):
|
|
1456
|
+
if self.from_pixels:
|
|
1457
|
+
observation_spec = Composite(
|
|
1458
|
+
pixels=observation_spec, shape=cur_batch_size
|
|
1459
|
+
)
|
|
1460
|
+
else:
|
|
1461
|
+
observation_spec = Composite(
|
|
1462
|
+
observation=observation_spec, shape=cur_batch_size
|
|
1463
|
+
)
|
|
1464
|
+
elif observation_spec.shape[: len(cur_batch_size)] != cur_batch_size:
|
|
1465
|
+
observation_spec.shape = cur_batch_size
|
|
1466
|
+
|
|
1467
|
+
reward_space = self._reward_space(env)
|
|
1468
|
+
if reward_space is not None:
|
|
1469
|
+
reward_spec = _gym_to_torchrl_spec_transform(
|
|
1470
|
+
reward_space,
|
|
1471
|
+
device=self.device,
|
|
1472
|
+
categorical_action_encoding=self._categorical_action_encoding,
|
|
1473
|
+
)
|
|
1474
|
+
else:
|
|
1475
|
+
reward_spec = Unbounded(
|
|
1476
|
+
shape=[1],
|
|
1477
|
+
device=self.device,
|
|
1478
|
+
)
|
|
1479
|
+
if batch_size is not None:
|
|
1480
|
+
action_spec = action_spec.expand(*batch_size, *action_spec.shape)
|
|
1481
|
+
reward_spec = reward_spec.expand(*batch_size, *reward_spec.shape)
|
|
1482
|
+
observation_spec = observation_spec.expand(
|
|
1483
|
+
*batch_size, *observation_spec.shape
|
|
1484
|
+
)
|
|
1485
|
+
|
|
1486
|
+
self.done_spec = self._make_done_spec()
|
|
1487
|
+
self.action_spec = action_spec
|
|
1488
|
+
if reward_spec.shape[: len(cur_batch_size)] != cur_batch_size:
|
|
1489
|
+
self.reward_spec = reward_spec.expand(*cur_batch_size, *reward_spec.shape)
|
|
1490
|
+
else:
|
|
1491
|
+
self.reward_spec = reward_spec
|
|
1492
|
+
self.observation_spec = observation_spec
|
|
1493
|
+
|
|
1494
|
+
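    # Illustrative sketch of the MultiDiscrete + "action_mask" conversion
    # performed in _make_specs above, with hypothetical numbers.
    #
    #     For a grid game with action_space = MultiDiscrete([3, 4]) and an
    #     "action_mask" observation of shape (3, 4), the action spec is
    #     flattened to prod(nvec) = 3 * 4 = 12 choices:
    #       - categorical_action_encoding=True  -> Categorical(12): an integer in [0, 12)
    #       - categorical_action_encoding=False -> OneHot(12): a boolean vector of size 12
    #     The (3, 4) mask is reshaped to those 12 entries when applied through
    #     update_mask, one per (row, col) combination.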
@implement_for("gym", None, "0.26")
|
|
1495
|
+
def _make_done_spec(self): # noqa: F811
|
|
1496
|
+
return Composite(
|
|
1497
|
+
{
|
|
1498
|
+
"done": Categorical(
|
|
1499
|
+
2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
|
|
1500
|
+
),
|
|
1501
|
+
"terminated": Categorical(
|
|
1502
|
+
2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
|
|
1503
|
+
),
|
|
1504
|
+
"truncated": Categorical(
|
|
1505
|
+
2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
|
|
1506
|
+
),
|
|
1507
|
+
},
|
|
1508
|
+
shape=self.batch_size,
|
|
1509
|
+
)
|
|
1510
|
+
|
|
1511
|
+
@implement_for("gym", "0.26", None)
|
|
1512
|
+
def _make_done_spec(self): # noqa: F811
|
|
1513
|
+
return Composite(
|
|
1514
|
+
{
|
|
1515
|
+
"done": Categorical(
|
|
1516
|
+
2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
|
|
1517
|
+
),
|
|
1518
|
+
"terminated": Categorical(
|
|
1519
|
+
2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
|
|
1520
|
+
),
|
|
1521
|
+
"truncated": Categorical(
|
|
1522
|
+
2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
|
|
1523
|
+
),
|
|
1524
|
+
},
|
|
1525
|
+
shape=self.batch_size,
|
|
1526
|
+
)
|
|
1527
|
+
|
|
1528
|
+
@implement_for("gymnasium", "0.27", None)
|
|
1529
|
+
def _make_done_spec(self): # noqa: F811
|
|
1530
|
+
return Composite(
|
|
1531
|
+
{
|
|
1532
|
+
"done": Categorical(
|
|
1533
|
+
2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
|
|
1534
|
+
),
|
|
1535
|
+
"terminated": Categorical(
|
|
1536
|
+
2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
|
|
1537
|
+
),
|
|
1538
|
+
"truncated": Categorical(
|
|
1539
|
+
2, dtype=torch.bool, device=self.device, shape=(*self.batch_size, 1)
|
|
1540
|
+
),
|
|
1541
|
+
},
|
|
1542
|
+
shape=self.batch_size,
|
|
1543
|
+
)
|
|
1544
|
+
|
|
1545
|
+
@implement_for("gym", None, "0.26")
|
|
1546
|
+
def _reset_output_transform(self, reset_data): # noqa: F811
|
|
1547
|
+
if (
|
|
1548
|
+
isinstance(reset_data, tuple)
|
|
1549
|
+
and len(reset_data) == 2
|
|
1550
|
+
and isinstance(reset_data[1], dict)
|
|
1551
|
+
):
|
|
1552
|
+
return reset_data
|
|
1553
|
+
return reset_data, None
|
|
1554
|
+
|
|
1555
|
+
@implement_for("gym", "0.26", None)
|
|
1556
|
+
def _reset_output_transform(self, reset_data): # noqa: F811
|
|
1557
|
+
return reset_data
|
|
1558
|
+
|
|
1559
|
+
@implement_for("gymnasium", "0.27", None)
|
|
1560
|
+
def _reset_output_transform(self, reset_data): # noqa: F811
|
|
1561
|
+
return reset_data
|
|
1562
|
+
|
|
1563
|
+
@implement_for("gym", None, "0.24")
|
|
1564
|
+
def _output_transform(self, step_outputs_tuple): # noqa: F811
|
|
1565
|
+
observations, reward, done, info = step_outputs_tuple
|
|
1566
|
+
if self._is_batched:
|
|
1567
|
+
# info needs to be flipped
|
|
1568
|
+
info = _flip_info_tuple(info)
|
|
1569
|
+
# The variable naming follows torchrl's convention here.
|
|
1570
|
+
# A done is interpreted the union of terminated and truncated.
|
|
1571
|
+
# (as in earlier versions of gym).
|
|
1572
|
+
truncated = info.pop("TimeLimit.truncated", False)
|
|
1573
|
+
if not isinstance(done, bool) and isinstance(truncated, bool):
|
|
1574
|
+
# if bool is an array, make truncated an array
|
|
1575
|
+
truncated = [truncated] * len(done)
|
|
1576
|
+
truncated = np.array(truncated)
|
|
1577
|
+
elif not isinstance(truncated, bool):
|
|
1578
|
+
# make sure it's a boolean np.array
|
|
1579
|
+
truncated = np.array(truncated, dtype=np.dtype("bool"))
|
|
1580
|
+
terminated = done & ~truncated
|
|
1581
|
+
if not isinstance(terminated, np.ndarray):
|
|
1582
|
+
# if it's not a ndarray, we must return bool
|
|
1583
|
+
# since it's not a bool, we make it so
|
|
1584
|
+
terminated = bool(terminated)
|
|
1585
|
+
|
|
1586
|
+
if isinstance(observations, list) and len(observations) == 1:
|
|
1587
|
+
# Until gym 0.25.2 we had rendered frames returned in lists of length 1
|
|
1588
|
+
observations = observations[0]
|
|
1589
|
+
|
|
1590
|
+
return (observations, reward, terminated, truncated, done, info)
|
|
1591
|
+
|
|
1592
|
+
@implement_for("gym", "0.24", "0.26")
|
|
1593
|
+
def _output_transform(self, step_outputs_tuple): # noqa: F811
|
|
1594
|
+
observations, reward, done, info = step_outputs_tuple
|
|
1595
|
+
# The variable naming follows torchrl's convention here.
|
|
1596
|
+
# A done is interpreted the union of terminated and truncated.
|
|
1597
|
+
# (as in earlier versions of gym).
|
|
1598
|
+
truncated = info.pop("TimeLimit.truncated", False)
|
|
1599
|
+
if not isinstance(done, bool) and isinstance(truncated, bool):
|
|
1600
|
+
# if bool is an array, make truncated an array
|
|
1601
|
+
truncated = [truncated] * len(done)
|
|
1602
|
+
truncated = np.array(truncated)
|
|
1603
|
+
elif not isinstance(truncated, bool):
|
|
1604
|
+
# make sure it's a boolean np.array
|
|
1605
|
+
truncated = np.array(truncated, dtype=np.dtype("bool"))
|
|
1606
|
+
terminated = done & ~truncated
|
|
1607
|
+
if not isinstance(terminated, np.ndarray):
|
|
1608
|
+
# if it's not a ndarray, we must return bool
|
|
1609
|
+
# since it's not a bool, we make it so
|
|
1610
|
+
terminated = bool(terminated)
|
|
1611
|
+
|
|
1612
|
+
if isinstance(observations, list) and len(observations) == 1:
|
|
1613
|
+
# Until gym 0.25.2 we had rendered frames returned in lists of length 1
|
|
1614
|
+
observations = observations[0]
|
|
1615
|
+
|
|
1616
|
+
return (observations, reward, terminated, truncated, done, info)
|
|
1617
|
+
|
|
1618
|
+
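    # Illustrative sketch: whatever the backend, _output_transform normalizes
    # step() outputs to the 6-tuple
    # (observations, reward, terminated, truncated, done, info). With the legacy
    # 4-tuple API handled above, a time-limit truncation looks like:
    #
    #     obs, rew, done, info = obs_t, 0.0, True, {"TimeLimit.truncated": True}
    #     # -> truncated = True, terminated = done & ~truncated = False, done = True
    #
    # while a genuine terminal transition (no "TimeLimit.truncated" entry) gives
    # terminated=True, truncated=False, done=True.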
@implement_for("gym", "0.26", None)
|
|
1619
|
+
def _output_transform(self, step_outputs_tuple): # noqa: F811
|
|
1620
|
+
# The variable naming follows torchrl's convention here.
|
|
1621
|
+
observations, reward, terminated, truncated, info = step_outputs_tuple
|
|
1622
|
+
return (
|
|
1623
|
+
observations,
|
|
1624
|
+
reward,
|
|
1625
|
+
terminated,
|
|
1626
|
+
truncated,
|
|
1627
|
+
terminated | truncated,
|
|
1628
|
+
info,
|
|
1629
|
+
)
|
|
1630
|
+
|
|
1631
|
+
@implement_for("gymnasium", "0.27", None)
|
|
1632
|
+
def _output_transform(self, step_outputs_tuple): # noqa: F811
|
|
1633
|
+
# The variable naming follows torchrl's convention here.
|
|
1634
|
+
observations, reward, terminated, truncated, info = step_outputs_tuple
|
|
1635
|
+
return (
|
|
1636
|
+
observations,
|
|
1637
|
+
reward,
|
|
1638
|
+
terminated,
|
|
1639
|
+
truncated,
|
|
1640
|
+
terminated | truncated,
|
|
1641
|
+
info,
|
|
1642
|
+
)
|
|
1643
|
+
|
|
1644
|
+
def _init_env(self):
|
|
1645
|
+
pass
|
|
1646
|
+
# init_reset = self.init_reset
|
|
1647
|
+
# if init_reset is None:
|
|
1648
|
+
# warnings.warn(f"init_env is None in the {type(self).__name__} constructor. The current "
|
|
1649
|
+
# f"default behavior is to reset the gym env as soon as it's wrapped in the "
|
|
1650
|
+
# f"class (init_reset=True), but from v0.9 this will be changed to False. "
|
|
1651
|
+
# f"To adapt for these changes, pass init_reset to your constructor.", category=FutureWarning)
|
|
1652
|
+
# init_reset = True
|
|
1653
|
+
# if init_reset:
|
|
1654
|
+
# self._env.reset()
|
|
1655
|
+
|
|
1656
|
+
def __repr__(self) -> str:
|
|
1657
|
+
return (
|
|
1658
|
+
f"{self.__class__.__name__}(env={self._env}, batch_size={self.batch_size})"
|
|
1659
|
+
)
|
|
1660
|
+
|
|
1661
|
+
def rebuild_with_kwargs(self, **new_kwargs):
|
|
1662
|
+
self._constructor_kwargs.update(new_kwargs)
|
|
1663
|
+
self._env = self._build_env(**self._constructor_kwargs)
|
|
1664
|
+
self._make_specs(self._env)
|
|
1665
|
+
|
|
1666
|
+
@implement_for("gym")
|
|
1667
|
+
def _replace_reset(self, reset, kwargs):
|
|
1668
|
+
return kwargs
|
|
1669
|
+
|
|
1670
|
+
@implement_for("gymnasium", None, "1.1.0")
|
|
1671
|
+
def _replace_reset(self, reset, kwargs): # noqa
|
|
1672
|
+
return kwargs
|
|
1673
|
+
|
|
1674
|
+
# From gymnasium 1.1.0, AutoresetMode.DISABLED is like resets in torchrl
|
|
1675
|
+
@implement_for("gymnasium", "1.1.0")
|
|
1676
|
+
def _replace_reset(self, reset, kwargs): # noqa
|
|
1677
|
+
import gymnasium as gym
|
|
1678
|
+
|
|
1679
|
+
if (
|
|
1680
|
+
getattr(self._env, "autoreset_mode", None)
|
|
1681
|
+
== gym.vector.AutoresetMode.DISABLED
|
|
1682
|
+
):
|
|
1683
|
+
options = {"reset_mask": reset.view(self.batch_size).numpy()}
|
|
1684
|
+
kwargs.setdefault("options", {}).update(options)
|
|
1685
|
+
return kwargs
|
|
1686
|
+
|
|
1687
|
+
def _reset(
|
|
1688
|
+
self, tensordict: TensorDictBase | None = None, **kwargs
|
|
1689
|
+
) -> TensorDictBase:
|
|
1690
|
+
if self._is_batched:
|
|
1691
|
+
# batched (aka 'vectorized') env reset is a bit special: envs are
|
|
1692
|
+
# automatically reset. What we do here is just to check if _reset
|
|
1693
|
+
# is present. If it is not, we just reset. Otherwise, we just skip.
|
|
1694
|
+
if tensordict is None:
|
|
1695
|
+
return super()._reset(tensordict, **kwargs)
|
|
1696
|
+
reset = tensordict.get("_reset", None)
|
|
1697
|
+
kwargs = self._replace_reset(reset, kwargs)
|
|
1698
|
+
if reset is not None:
|
|
1699
|
+
# we must copy the tensordict because the transform
|
|
1700
|
+
# expects a tuple (tensordict, tensordict_reset) where the
|
|
1701
|
+
# first still carries a _reset
|
|
1702
|
+
tensordict = tensordict.exclude("_reset")
|
|
1703
|
+
if reset is None or reset.all() or "options" in kwargs:
|
|
1704
|
+
result = super()._reset(tensordict, **kwargs)
|
|
1705
|
+
return result
|
|
1706
|
+
else:
|
|
1707
|
+
return tensordict
|
|
1708
|
+
return super()._reset(tensordict, **kwargs)
|
|
1709
|
+
|
|
1710
|
+
|
|
1711
|
+
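# Illustrative sketch, assuming gymnasium >= 1.1 with AutoresetMode.DISABLED and
# a batched GymWrapper built as in the class docstring: a partial reset forwards
# a "reset_mask" option to the vectorized env (see _replace_reset above), so
# only the flagged sub-envs are reset. step_mdp comes from torchrl.envs.utils.
#
#     from torchrl.envs.utils import step_mdp
#
#     td = env.rand_step(env.reset())
#     td = step_mdp(td)
#     td["_reset"] = td["done"]   # only reset the sub-envs that are done
#     td = env.reset(td)          # the other sub-envs keep their current state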
ACCEPTED_TYPE_ERRORS = {
|
|
1712
|
+
"render_mode": "__init__() got an unexpected keyword argument 'render_mode'",
|
|
1713
|
+
"frame_skip": "unexpected keyword argument 'frameskip'",
|
|
1714
|
+
}
|
|
1715
|
+
|
|
1716
|
+
|
|
1717
|
+
class GymEnv(GymWrapper):
|
|
1718
|
+
"""OpenAI Gym environment wrapper constructed by environment ID directly.
|
|
1719
|
+
|
|
1720
|
+
Works across `gymnasium <https://gymnasium.farama.org/>`_ and `OpenAI/gym <https://github.com/openai/gym>`_.
|
|
1721
|
+
|
|
1722
|
+
Args:
|
|
1723
|
+
env_name (str): the environment id registered in `gym.registry`.
|
|
1724
|
+
categorical_action_encoding (bool, optional): if ``True``, categorical
|
|
1725
|
+
specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
|
|
1726
|
+
otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
|
|
1727
|
+
Defaults to ``False``.
|
|
1728
|
+
|
|
1729
|
+
Keyword Args:
|
|
1730
|
+
num_envs (int, optional): the number of envs to run in parallel. Defaults to
|
|
1731
|
+
``None`` (a single env is to be run). :class:`~gym.vector.AsyncVectorEnv`
|
|
1732
|
+
will be used by default.
|
|
1733
|
+
num_workers (int, optional): number of top-level worker subprocesses used to create/run
|
|
1734
|
+
multiple :class:`GymEnv` instances in parallel (handled by the metaclass
|
|
1735
|
+
:class:`_GymAsyncMeta`). When ``num_workers > 1``, a lazy
|
|
1736
|
+
:class:`~torchrl.envs.ParallelEnv` is returned whose factory preserves the original
|
|
1737
|
+
`GymEnv` kwargs. You can modify the ParallelEnv construction/configuration before
|
|
1738
|
+
it starts by calling :meth:`~torchrl.envs.batched_envs.BatchedEnvBase.configure_parallel`
|
|
1739
|
+
on the returned object (for example: ``env.configure_parallel(use_buffers=True, num_threads=2)``).
|
|
1740
|
+
When both ``num_workers`` and ``num_envs`` are greater than 1, the total number of
|
|
1741
|
+
environments executed in parallel is ``num_workers * num_envs``. Defaults to ``1``.
|
|
1742
|
+
disable_env_checker (bool, optional): for gym > 0.24 only. If ``True`` (default
|
|
1743
|
+
for these versions), the environment checker won't be run.
|
|
1744
|
+
from_pixels (bool, optional): if ``True``, an attempt to return the pixel
|
|
1745
|
+
observations from the env will be performed. By default, these observations
|
|
1746
|
+
will be written under the ``"pixels"`` entry.
|
|
1747
|
+
The method being used varies
|
|
1748
|
+
depending on the gym version and may involve a ``wrappers.pixel_observation.PixelObservationWrapper``.
|
|
1749
|
+
Defaults to ``False``.
|
|
1750
|
+
pixels_only (bool, optional): if ``True``, only the pixel observations will
|
|
1751
|
+
be returned (by default under the ``"pixels"`` entry in the output tensordict).
|
|
1752
|
+
If ``False``, observations (eg, states) and pixels will be returned
|
|
1753
|
+
whenever ``from_pixels=True``. Defaults to ``False``.
|
|
1754
|
+
frame_skip (int, optional): if provided, indicates for how many steps the
|
|
1755
|
+
same action is to be repeated. The observation returned will be the
|
|
1756
|
+
last observation of the sequence, whereas the reward will be the sum
|
|
1757
|
+
of rewards across steps.
|
|
1758
|
+
device (torch.device, optional): if provided, the device on which the data
|
|
1759
|
+
is to be cast. Defaults to ``torch.device("cpu")``.
|
|
1760
|
+
batch_size (torch.Size, optional): the batch size of the environment.
|
|
1761
|
+
Should match the leading dimensions of all observations, done states,
|
|
1762
|
+
rewards, actions and infos.
|
|
1763
|
+
Defaults to ``torch.Size([])``.
|
|
1764
|
+
allow_done_after_reset (bool, optional): if ``True``, it is tolerated
|
|
1765
|
+
for envs to be ``done`` just after :meth:`reset` is called.
|
|
1766
|
+
Defaults to ``False``.
|
|
1767
|
+
|
|
1768
|
+
Attributes:
|
|
1769
|
+
available_envs (List[str]): the list of envs that can be built.
|
|
1770
|
+
|
|
1771
|
+
.. note::
|
|
1772
|
+
If an attribute cannot be found, this class will attempt to retrieve it from
|
|
1773
|
+
the nested env:
|
|
1774
|
+
|
|
1775
|
+
>>> from torchrl.envs import GymEnv
|
|
1776
|
+
>>> env = GymEnv("Pendulum-v1")
|
|
1777
|
+
>>> print(env.spec.max_episode_steps)
|
|
1778
|
+
200
|
|
1779
|
+
|
|
1780
|
+
|
|
1781
|
+
If a use-case is not covered by TorchRL, please submit an issue on GitHub.
|
|
1782
|
+
|
|
1783
|
+
Examples:
|
|
1784
|
+
>>> from torchrl.envs import GymEnv
|
|
1785
|
+
>>> env = GymEnv("Pendulum-v1")
|
|
1786
|
+
>>> td = env.rand_step()
|
|
1787
|
+
>>> print(td)
|
|
1788
|
+
TensorDict(
|
|
1789
|
+
fields={
|
|
1790
|
+
action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
1791
|
+
next: TensorDict(
|
|
1792
|
+
fields={
|
|
1793
|
+
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
1794
|
+
observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
1795
|
+
reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
1796
|
+
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
1797
|
+
truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
1798
|
+
batch_size=torch.Size([]),
|
|
1799
|
+
device=cpu,
|
|
1800
|
+
is_shared=False)},
|
|
1801
|
+
batch_size=torch.Size([]),
|
|
1802
|
+
device=cpu,
|
|
1803
|
+
is_shared=False)
|
|
1804
|
+
>>> print(env.available_envs)
|
|
1805
|
+
['ALE/Adventure-ram-v5', 'ALE/Adventure-v5', 'ALE/AirRaid-ram-v5', 'ALE/AirRaid-v5', 'ALE/Alien-ram-v5', 'ALE/Alien-v5',
|
|
1806
|
+
|
|
1807
|
+
To run multiple environments in parallel:
|
|
1808
|
+
>>> from torchrl.envs import GymEnv
|
|
1809
|
+
>>> env = GymEnv("Pendulum-v1", num_workers=4)
|
|
1810
|
+
>>> td_reset = env.reset()
|
|
1811
|
+
>>> td = env.rand_step(td_reset)
|
|
1812
|
+
>>> print(td)
|
|
1813
|
+
TensorDict(
|
|
1814
|
+
fields={
|
|
1815
|
+
action: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
1816
|
+
done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
1817
|
+
next: TensorDict(
|
|
1818
|
+
fields={
|
|
1819
|
+
done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
1820
|
+
observation: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
1821
|
+
reward: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
1822
|
+
terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
1823
|
+
truncated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
1824
|
+
batch_size=torch.Size([4]),
|
|
1825
|
+
device=None,
|
|
1826
|
+
is_shared=False),
|
|
1827
|
+
observation: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
1828
|
+
terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
1829
|
+
truncated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
1830
|
+
batch_size=torch.Size([4]),
|
|
1831
|
+
device=None,
|
|
1832
|
+
is_shared=False)
|
|
1833
|
+
|
|
1834
|
+
.. note::
|
|
1835
|
+
If both `OpenAI/gym` and `gymnasium` are present in the virtual environment,
|
|
1836
|
+
one can swap backend using :func:`~torchrl.envs.libs.gym.set_gym_backend`:
|
|
1837
|
+
|
|
1838
|
+
>>> from torchrl.envs import set_gym_backend, GymEnv
|
|
1839
|
+
>>> with set_gym_backend("gym"):
|
|
1840
|
+
... env = GymEnv("Pendulum-v1")
|
|
1841
|
+
... print(env._env)
|
|
1842
|
+
<class 'gym.wrappers.time_limit.TimeLimit'>
|
|
1843
|
+
>>> with set_gym_backend("gymnasium"):
|
|
1844
|
+
... env = GymEnv("Pendulum-v1")
|
|
1845
|
+
... print(env._env)
|
|
1846
|
+
<class 'gymnasium.wrappers.time_limit.TimeLimit'>
|
|
1847
|
+
|
|
1848
|
+
.. note::
|
|
1849
|
+
info dictionaries will be read using :class:`~torchrl.envs.gym_like.default_info_dict_reader`
|
|
1850
|
+
if no other reader is provided. To provide another reader, refer to
|
|
1851
|
+
:meth:`set_info_dict_reader`. To automatically register the info_dict
|
|
1852
|
+
content, refer to :meth:`torchrl.envs.GymLikeEnv.auto_register_info_dict`.
|
|
1853
|
+
|
|
1854
|
+
.. note:: Gym spaces are not completely covered.
|
|
1855
|
+
The following spaces are accounted for provided that they can be represented by a torch.Tensor, a nested tensor
|
|
1856
|
+
and/or within a tensordict:
|
|
1857
|
+
|
|
1858
|
+
- spaces.Box
|
|
1859
|
+
- spaces.Sequence
|
|
1860
|
+
- spaces.Tuple
|
|
1861
|
+
- spaces.Discrete
|
|
1862
|
+
- spaces.MultiBinary
|
|
1863
|
+
- spaces.MultiDiscrete
|
|
1864
|
+
- spaces.Dict
|
|
1865
|
+
|
|
1866
|
+
Some considerations should be made when working with gym spaces. For instance, a tuple of spaces
|
|
1867
|
+
can only be supported if the spaces are semantically identical (same dtype and same number of dimensions).
|
|
1868
|
+
Ragged dimension can be supported through :func:`~torch.nested.nested_tensor`, but then there should be only
|
|
1869
|
+
one level of tuple and data should be stacked along the first dimension (as nested_tensors can only be
|
|
1870
|
+
stacked along the first dimension).
|
|
1871
|
+
|
|
1872
|
+
Check the example in examples/envs/gym_conversion_examples.py to know more!
|
|
1873
|
+
|
|
1874
|
+
"""
|
|
1875
|
+
|
|
1876
|
+
def __init__(self, env_name, **kwargs):
|
|
1877
|
+
backend = kwargs.pop("backend", None)
|
|
1878
|
+
with set_gym_backend(backend) if backend is not None else nullcontext():
|
|
1879
|
+
kwargs["env_name"] = env_name
|
|
1880
|
+
self._set_gym_args(kwargs)
|
|
1881
|
+
super().__init__(**kwargs)
|
|
1882
|
+
|
|
1883
|
+
@implement_for("gym", None, "0.24.0")
|
|
1884
|
+
def _set_gym_args(self, kwargs) -> None: # noqa: F811
|
|
1885
|
+
disable_env_checker = kwargs.pop("disable_env_checker", None)
|
|
1886
|
+
if disable_env_checker is not None:
|
|
1887
|
+
raise RuntimeError(
|
|
1888
|
+
"disable_env_checker should only be set if gym version is > 0.24"
|
|
1889
|
+
)
|
|
1890
|
+
|
|
1891
|
+
@implement_for("gym", "0.24.0", None)
|
|
1892
|
+
def _set_gym_args( # noqa: F811
|
|
1893
|
+
self,
|
|
1894
|
+
kwargs,
|
|
1895
|
+
) -> None:
|
|
1896
|
+
kwargs.setdefault("disable_env_checker", True)
|
|
1897
|
+
|
|
1898
|
+
@implement_for("gymnasium", "1.0.0", "1.1.0")
|
|
1899
|
+
def _set_gym_args( # noqa: F811
|
|
1900
|
+
self,
|
|
1901
|
+
kwargs,
|
|
1902
|
+
) -> None:
|
|
1903
|
+
raise ImportError(GYMNASIUM_1_ERROR)
|
|
1904
|
+
|
|
1905
|
+
@implement_for("gymnasium", None, "1.0.0")
|
|
1906
|
+
def _set_gym_args( # noqa: F811
|
|
1907
|
+
self,
|
|
1908
|
+
kwargs,
|
|
1909
|
+
) -> None:
|
|
1910
|
+
kwargs.setdefault("disable_env_checker", True)
|
|
1911
|
+
|
|
1912
|
+
@implement_for("gymnasium", "1.1.0")
|
|
1913
|
+
def _set_gym_args( # noqa: F811
|
|
1914
|
+
self,
|
|
1915
|
+
kwargs,
|
|
1916
|
+
) -> None:
|
|
1917
|
+
kwargs.setdefault("disable_env_checker", True)
|
|
1918
|
+
|
|
1919
|
+
def _async_env(self, *args, **kwargs):
|
|
1920
|
+
return gym_backend("vector").AsyncVectorEnv(*args, **kwargs)
|
|
1921
|
+
|
|
1922
|
+
def _build_env(
|
|
1923
|
+
self,
|
|
1924
|
+
env_name: str,
|
|
1925
|
+
**kwargs,
|
|
1926
|
+
) -> gym.core.Env: # noqa: F821
|
|
1927
|
+
if not _has_gym:
|
|
1928
|
+
raise RuntimeError(
|
|
1929
|
+
f"gym not found, unable to create {env_name}. "
|
|
1930
|
+
f"Consider downloading and installing gym from"
|
|
1931
|
+
f" {self.git_url}"
|
|
1932
|
+
)
|
|
1933
|
+
from_pixels = kwargs.pop("from_pixels", False)
|
|
1934
|
+
self._set_gym_default(kwargs, from_pixels)
|
|
1935
|
+
pixels_only = kwargs.pop("pixels_only", True)
|
|
1936
|
+
num_envs = kwargs.pop("num_envs", 0)
|
|
1937
|
+
made_env = False
|
|
1938
|
+
kwargs["frameskip"] = self.frame_skip
|
|
1939
|
+
self.wrapper_frame_skip = 1
|
|
1940
|
+
while not made_env:
|
|
1941
|
+
# env.__init__ may not be compatible with all the kwargs that
|
|
1942
|
+
# have been preset. We iterate through the various solutions
|
|
1943
|
+
# to find the config that works.
|
|
1944
|
+
try:
|
|
1945
|
+
with warnings.catch_warnings(record=True) as w:
|
|
1946
|
+
if env_name.startswith("ALE/"):
|
|
1947
|
+
try:
|
|
1948
|
+
import ale_py # noqa: F401
|
|
1949
|
+
except ImportError as err:
|
|
1950
|
+
torchrl_logger.warning(
|
|
1951
|
+
f"ale_py not found, this may cause issues with ALE environments: {err}"
|
|
1952
|
+
)
|
|
1953
|
+
# we catch warnings as they may cause silent bugs
|
|
1954
|
+
env = self.lib.make(env_name, **kwargs)
|
|
1955
|
+
if len(w) and "frameskip" in str(w[-1].message):
|
|
1956
|
+
raise TypeError("unexpected keyword argument 'frameskip'")
|
|
1957
|
+
made_env = True
|
|
1958
|
+
except TypeError as err:
|
|
1959
|
+
if ACCEPTED_TYPE_ERRORS["frame_skip"] in str(err):
|
|
1960
|
+
# we can disable this, not strictly indispensable to know
|
|
1961
|
+
# warn(
|
|
1962
|
+
# "Discarding frameskip arg. This will be taken care of by TorchRL env wrapper."
|
|
1963
|
+
# )
|
|
1964
|
+
self.wrapper_frame_skip = kwargs.pop("frameskip")
|
|
1965
|
+
elif ACCEPTED_TYPE_ERRORS["render_mode"] in str(err):
|
|
1966
|
+
warn("Discarding render_mode from the env constructor.")
|
|
1967
|
+
kwargs.pop("render_mode")
|
|
1968
|
+
else:
|
|
1969
|
+
raise err
|
|
1970
|
+
env = super()._build_env(env, pixels_only=pixels_only, from_pixels=from_pixels)
|
|
1971
|
+
if num_envs > 0:
|
|
1972
|
+
make_fn = partial(self.lib.make, env_name, **kwargs)
|
|
1973
|
+
env = self._async_env([make_fn] * num_envs)
|
|
1974
|
+
self.batch_size = torch.Size([num_envs, *self.batch_size])
|
|
1975
|
+
return env
|
|
1976
|
+
|
|
1977
|
+
@implement_for("gym", None, "0.25.0")
|
|
1978
|
+
def _set_gym_default(self, kwargs, from_pixels: bool) -> None: # noqa: F811
|
|
1979
|
+
# Do nothing for older gym versions (render_mode was introduced in 0.25.0).
|
|
1980
|
+
pass
|
|
1981
|
+
|
|
1982
|
+
@implement_for("gym", "0.25.0", None)
|
|
1983
|
+
def _set_gym_default(self, kwargs, from_pixels: bool) -> None: # noqa: F811
|
|
1984
|
+
if from_pixels:
|
|
1985
|
+
kwargs.setdefault("render_mode", "rgb_array")
|
|
1986
|
+
|
|
1987
|
+
@implement_for("gymnasium", None, "0.27.0")
|
|
1988
|
+
def _set_gym_default(self, kwargs, from_pixels: bool) -> None: # noqa: F811
|
|
1989
|
+
# gymnasium < 0.27.0 also supports render_mode (forked from gym 0.26+)
|
|
1990
|
+
if from_pixels:
|
|
1991
|
+
kwargs.setdefault("render_mode", "rgb_array")
|
|
1992
|
+
|
|
1993
|
+
@implement_for("gymnasium", "0.27.0", None)
|
|
1994
|
+
def _set_gym_default(self, kwargs, from_pixels: bool) -> None: # noqa: F811
|
|
1995
|
+
if from_pixels:
|
|
1996
|
+
kwargs.setdefault("render_mode", "rgb_array")
|
|
1997
|
+
|
|
1998
|
+
@property
|
|
1999
|
+
def env_name(self):
|
|
2000
|
+
return self._constructor_kwargs["env_name"]
|
|
2001
|
+
|
|
2002
|
+
def _check_kwargs(self, kwargs: dict):
|
|
2003
|
+
if "env_name" not in kwargs:
|
|
2004
|
+
raise TypeError("Expected 'env_name' to be part of kwargs")
|
|
2005
|
+
|
|
2006
|
+
def __repr__(self) -> str:
|
|
2007
|
+
return f"{self.__class__.__name__}(env={self.env_name}, batch_size={self.batch_size}, device={self.device})"
|
|
2008
|
+
|
|
2009
|
+
|
|
2010
|
+
class MOGymWrapper(GymWrapper):
|
|
2011
|
+
"""FARAMA MO-Gymnasium environment wrapper.
|
|
2012
|
+
|
|
2013
|
+
Examples:
|
|
2014
|
+
>>> import mo_gymnasium as mo_gym
|
|
2015
|
+
>>> env = MOGymWrapper(mo_gym.make('minecart-v0'), frame_skip=4)
|
|
2016
|
+
>>> td = env.rand_step()
|
|
2017
|
+
>>> print(td)
|
|
2018
|
+
|
|
2019
|
+
"""
|
|
2020
|
+
|
|
2021
|
+
git_url = "https://github.com/Farama-Foundation/MO-Gymnasium"
|
|
2022
|
+
libname = "mo-gymnasium"
|
|
2023
|
+
|
|
2024
|
+
_make_specs = set_gym_backend("gymnasium")(GymEnv._make_specs)
|
|
2025
|
+
|
|
2026
|
+
@_classproperty
|
|
2027
|
+
def available_envs(cls):
|
|
2028
|
+
if not _has_mo:
|
|
2029
|
+
return []
|
|
2030
|
+
return [
|
|
2031
|
+
"deep-sea-treasure-v0",
|
|
2032
|
+
"deep-sea-treasure-concave-v0",
|
|
2033
|
+
"resource-gathering-v0",
|
|
2034
|
+
"fishwood-v0",
|
|
2035
|
+
"breakable-bottles-v0",
|
|
2036
|
+
"fruit-tree-v0",
|
|
2037
|
+
"water-reservoir-v0",
|
|
2038
|
+
"four-room-v0",
|
|
2039
|
+
"mo-mountaincar-v0",
|
|
2040
|
+
"mo-mountaincarcontinuous-v0",
|
|
2041
|
+
"mo-lunar-lander-v2",
|
|
2042
|
+
"minecart-v0",
|
|
2043
|
+
"mo-highway-v0",
|
|
2044
|
+
"mo-highway-fast-v0",
|
|
2045
|
+
"mo-supermario-v0",
|
|
2046
|
+
"mo-reacher-v4",
|
|
2047
|
+
"mo-hopper-v4",
|
|
2048
|
+
"mo-halfcheetah-v4",
|
|
2049
|
+
]
|
|
2050
|
+
|
|
2051
|
+
|
|
2052
|
+
class MOGymEnv(GymEnv):
|
|
2053
|
+
"""FARAMA MO-Gymnasium environment wrapper.
|
|
2054
|
+
|
|
2055
|
+
Examples:
|
|
2056
|
+
>>> env = MOGymEnv(env_name="minecart-v0", frame_skip=4)
|
|
2057
|
+
>>> td = env.rand_step()
|
|
2058
|
+
>>> print(td)
|
|
2059
|
+
>>> print(env.available_envs)
|
|
2060
|
+
|
|
2061
|
+
"""
|
|
2062
|
+
|
|
2063
|
+
git_url = "https://github.com/Farama-Foundation/MO-Gymnasium"
|
|
2064
|
+
libname = "mo-gymnasium"
|
|
2065
|
+
|
|
2066
|
+
available_envs = MOGymWrapper.available_envs
|
|
2067
|
+
|
|
2068
|
+
@property
|
|
2069
|
+
def lib(self) -> ModuleType:
|
|
2070
|
+
if _has_mo:
|
|
2071
|
+
import mo_gymnasium as mo_gym
|
|
2072
|
+
|
|
2073
|
+
return mo_gym
|
|
2074
|
+
else:
|
|
2075
|
+
try:
|
|
2076
|
+
import mo_gymnasium # noqa: F401
|
|
2077
|
+
except ImportError as err:
|
|
2078
|
+
raise ImportError("MO-gymnasium not found, check installation") from err
|
|
2079
|
+
|
|
2080
|
+
_make_specs = set_gym_backend("gymnasium")(GymEnv._make_specs)
|
|
2081
|
+
|
|
2082
|
+
|
|
2083
|
+
class terminal_obs_reader(default_info_dict_reader):
    """Terminal observation reader for 'vectorized' gym environments.

    When running envs in parallel, Gym(nasium) writes the result of the true call
    to `step` in a `"final_observation"` entry within the `info` dictionary.

    This breaks the natural flow and makes single-processed and multiprocessed envs
    incompatible.

    This class reads the info dict, removes the `"final_observation"` entry from
    it and writes its content in the data.

    Next, a :class:`torchrl.envs.VecGymEnvTransform` transform will reorganise the
    data by caching the result of the (implicit) reset and swap the true next
    observation with the reset one. At reset time, the true reset data will be
    replaced.

    Args:
        observation_spec (Composite): The observation spec of the gym env.
        backend (str, optional): the backend of the env. One of `"sb3"` for
            stable-baselines3 or `"gym"` for gym/gymnasium.

    .. note:: In general, this class should not be handled directly. It is
        created whenever a vectorized environment is placed within a :class:`GymWrapper`.

    """

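    # Illustrative sketch of the kind of payload this reader consumes. When a
    # sub-env of a SAME_STEP auto-resetting vectorized env terminates, the
    # backend stashes the true last step in the info dict (exact keys depend on
    # the backend, see backend_key below), e.g.:
    #
    #     info = {
    #         "final_observation": np.array([obs_of_env0, None, ...], dtype=object),
    #         "final_info": np.array([{...}, None, ...], dtype=object),
    #     }
    #
    # The reader copies those entries under the ("final", ...) keys of the output
    # tensordict so that VecGymEnvTransform can later swap them back in place.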
backend_key = {
|
|
2111
|
+
"sb3": "terminal_observation",
|
|
2112
|
+
"gym": "final_observation",
|
|
2113
|
+
"gymnasium": "final_obs",
|
|
2114
|
+
}
|
|
2115
|
+
backend_info_key = {
|
|
2116
|
+
"sb3": "terminal_info",
|
|
2117
|
+
"gym": "final_info",
|
|
2118
|
+
"gymnasium": "final_info",
|
|
2119
|
+
}
|
|
2120
|
+
|
|
2121
|
+
def __init__(self, observation_spec: Composite, backend, name="final"):
|
|
2122
|
+
super().__init__()
|
|
2123
|
+
self.name = name
|
|
2124
|
+
self._obs_spec = observation_spec.clone()
|
|
2125
|
+
self.backend = backend
|
|
2126
|
+
self._final_validated = False
|
|
2127
|
+
|
|
2128
|
+
@property
|
|
2129
|
+
def info_spec(self):
|
|
2130
|
+
return self._info_spec
|
|
2131
|
+
|
|
2132
|
+
def _read_obs(self, obs, key, tensor, index):
|
|
2133
|
+
if obs is None:
|
|
2134
|
+
return
|
|
2135
|
+
if isinstance(obs, np.ndarray):
|
|
2136
|
+
# Simplest case: there is one observation,
|
|
2137
|
+
# presented as a np.ndarray. The key should be pixels or observation.
|
|
2138
|
+
# We just write that value at its location in the tensor
|
|
2139
|
+
tensor[index] = torch.as_tensor(obs, device=tensor.device)
|
|
2140
|
+
        elif isinstance(obs, torch.Tensor):
            # Same simple case, but the observation is already a torch.Tensor.
            # We just write that value at its location in the tensor
tensor[index] = obs.to(device=tensor.device)
|
|
2145
|
+
elif isinstance(obs, dict):
|
|
2146
|
+
if key not in obs:
|
|
2147
|
+
raise KeyError(
|
|
2148
|
+
f"The observation {key} could not be found in the final observation dict."
|
|
2149
|
+
)
|
|
2150
|
+
subobs = obs[key]
|
|
2151
|
+
if subobs is not None:
|
|
2152
|
+
# if the obs is a dict, we expect that the key points also to
|
|
2153
|
+
# a value in the obs. We retrieve this value and write it in the
|
|
2154
|
+
# tensor
|
|
2155
|
+
tensor[index] = torch.as_tensor(subobs, device=tensor.device)
|
|
2156
|
+
|
|
2157
|
+
elif isinstance(obs, (list, tuple)):
|
|
2158
|
+
# tuples are stacked along the first dimension when passing gym spaces
|
|
2159
|
+
# to torchrl specs. As such, we can simply stack the tuple and set it
|
|
2160
|
+
# at the relevant index (assuming stacking can be achieved)
|
|
2161
|
+
tensor[index] = torch.as_tensor(obs, device=tensor.device)
|
|
2162
|
+
else:
|
|
2163
|
+
raise NotImplementedError(
|
|
2164
|
+
f"Observations of type {type(obs)} are not supported yet."
|
|
2165
|
+
)
|
|
2166
|
+
|
|
2167
|
+
    def __call__(self, info_dict, tensordict):
        # TODO: This is a tad slow, we iterate over each sub-env and call spec.zero() at each step.
        # In theory we could spare that whole thing but we need to run it once at the beginning if specs
        # of the info reader are not passed, as we need to observe the data to infer the spec.
        # We should find a way to avoid this call altogether if no env is resetting.
def replace_none(nparray):
|
|
2173
|
+
if not isinstance(nparray, np.ndarray) or nparray.dtype != np.dtype("O"):
|
|
2174
|
+
return nparray
|
|
2175
|
+
is_none = np.array([info is None for info in nparray])
|
|
2176
|
+
if is_none.any():
|
|
2177
|
+
# Then it is a final observation and we delegate the registration to the appropriate reader
|
|
2178
|
+
nz = (~is_none).nonzero()[0][0]
|
|
2179
|
+
zero_like = tree_map(lambda x: np.zeros_like(x), nparray[nz])
|
|
2180
|
+
for idx in is_none.nonzero()[0]:
|
|
2181
|
+
nparray[idx] = zero_like
|
|
2182
|
+
# tree_map with multiple trees was added in PyTorch 2.2
|
|
2183
|
+
if TORCH_VERSION >= version.parse("2.2"):
|
|
2184
|
+
return tree_map(lambda *x: np.stack(x), *nparray)
|
|
2185
|
+
else:
|
|
2186
|
+
# For older PyTorch versions, manually flatten/unflatten
|
|
2187
|
+
flat_lists_specs = [tree_flatten(tree) for tree in nparray]
|
|
2188
|
+
flat_lists = [fl for fl, _ in flat_lists_specs]
|
|
2189
|
+
spec = flat_lists_specs[0][1]
|
|
2190
|
+
stacked = [np.stack(elems) for elems in zip(*flat_lists)]
|
|
2191
|
+
return tree_unflatten(stacked, spec)
|
|
2192
|
+
|
|
2193
|
+
info_dict = tree_map(replace_none, info_dict)
|
|
2194
|
+
# convert info_dict to a tensordict
|
|
2195
|
+
info_dict = TensorDict(info_dict)
|
|
2196
|
+
# get the terminal observation
|
|
2197
|
+
terminal_obs = info_dict.pop(self.backend_key[self.backend], None)
|
|
2198
|
+
# get the terminal info dict
|
|
2199
|
+
terminal_info = info_dict.pop(self.backend_info_key[self.backend], None)
|
|
2200
|
+
|
|
2201
|
+
if terminal_info is None:
|
|
2202
|
+
terminal_info = {}
|
|
2203
|
+
|
|
2204
|
+
super().__call__(info_dict, tensordict)
|
|
2205
|
+
if not self._final_validated:
|
|
2206
|
+
self.info_spec[self.name] = self._obs_spec.update(self.info_spec)
|
|
2207
|
+
self._final_validated = True
|
|
2208
|
+
|
|
2209
|
+
final_info = terminal_info.copy()
|
|
2210
|
+
if terminal_obs is not None:
|
|
2211
|
+
final_info["observation"] = terminal_obs
|
|
2212
|
+
|
|
2213
|
+
for key in self.info_spec[self.name].keys():
|
|
2214
|
+
spec = self.info_spec[self.name, key]
|
|
2215
|
+
|
|
2216
|
+
final_obs_buffer = spec.zero()
|
|
2217
|
+
terminal_obs = final_info.get(key, None)
|
|
2218
|
+
if terminal_obs is not None:
|
|
2219
|
+
for i, obs in enumerate(terminal_obs):
|
|
2220
|
+
# writes final_obs inplace with terminal_obs content
|
|
2221
|
+
self._read_obs(obs, key, final_obs_buffer, index=i)
|
|
2222
|
+
tensordict.set((self.name, key), final_obs_buffer)
|
|
2223
|
+
return tensordict
|
|
2224
|
+
|
|
2225
|
+
def reset(self):
|
|
2226
|
+
super().reset()
|
|
2227
|
+
self._final_validated = False


def _flip_info_tuple(info: tuple[dict]) -> dict[str, tuple]:
    # In Gym < 0.24, batched envs returned tuples of dicts, and not a dict of tuples.
    # We patch this by flipping the tuple -> dict order.
    info_example = set(info[0])
    for item in info[1:]:
        info_example = info_example.union(item)
    result = {}
    for key in info_example:
        result[key] = tuple(_info.get(key, None) for _info in info)
    return result