torchrl-0.11.0-cp314-cp314-macosx_11_0_arm64.whl
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/common.py
ADDED
|
@@ -0,0 +1,4241 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import abc
|
|
9
|
+
import re
|
|
10
|
+
import warnings
|
|
11
|
+
import weakref
|
|
12
|
+
from collections.abc import Callable, Iterator, Sequence
|
|
13
|
+
from copy import deepcopy
|
|
14
|
+
from functools import partial, wraps
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import torch
|
|
19
|
+
import torch.nn as nn
|
|
20
|
+
from tensordict import (
|
|
21
|
+
is_tensor_collection,
|
|
22
|
+
LazyStackedTensorDict,
|
|
23
|
+
TensorDictBase,
|
|
24
|
+
unravel_key,
|
|
25
|
+
)
|
|
26
|
+
from tensordict.base import _is_leaf_nontensor, NO_DEFAULT
|
|
27
|
+
from tensordict.utils import is_non_tensor, NestedKey
|
|
28
|
+
from torchrl._utils import (
|
|
29
|
+
_ends_with,
|
|
30
|
+
_make_ordinal_device,
|
|
31
|
+
_replace_last,
|
|
32
|
+
implement_for,
|
|
33
|
+
prod,
|
|
34
|
+
seed_generator,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
from torchrl.data.tensor_specs import (
|
|
38
|
+
Categorical,
|
|
39
|
+
Composite,
|
|
40
|
+
NonTensor,
|
|
41
|
+
TensorSpec,
|
|
42
|
+
Unbounded,
|
|
43
|
+
)
|
|
44
|
+
from torchrl.data.utils import DEVICE_TYPING
|
|
45
|
+
from torchrl.envs.utils import (
|
|
46
|
+
_make_compatible_policy,
|
|
47
|
+
_repr_by_depth,
|
|
48
|
+
_StepMDP,
|
|
49
|
+
_terminated_or_truncated,
|
|
50
|
+
_update_during_reset,
|
|
51
|
+
check_env_specs as check_env_specs_func,
|
|
52
|
+
get_available_libraries,
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
LIBRARIES = get_available_libraries()
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _tensor_to_np(t):
|
|
59
|
+
return t.detach().cpu().numpy()
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
dtype_map = {
|
|
63
|
+
torch.float: np.float32,
|
|
64
|
+
torch.double: np.float64,
|
|
65
|
+
torch.bool: bool,
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _maybe_unlock(func):
|
|
70
|
+
@wraps(func)
|
|
71
|
+
def wrapper(self, *args, **kwargs):
|
|
72
|
+
is_locked = self.is_spec_locked
|
|
73
|
+
try:
|
|
74
|
+
if is_locked:
|
|
75
|
+
self.set_spec_lock_(False)
|
|
76
|
+
result = func(self, *args, **kwargs)
|
|
77
|
+
finally:
|
|
78
|
+
if is_locked:
|
|
79
|
+
if not hasattr(self, "_cache"):
|
|
80
|
+
self._cache = {}
|
|
81
|
+
self._cache.clear()
|
|
82
|
+
self.set_spec_lock_(True)
|
|
83
|
+
return result
|
|
84
|
+
|
|
85
|
+
return wrapper
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _cache_value(func):
|
|
89
|
+
"""Caches the result of the decorated function in env._cache dictionary."""
|
|
90
|
+
func_name = func.__name__
|
|
91
|
+
|
|
92
|
+
@wraps(func)
|
|
93
|
+
def wrapper(self, *args, **kwargs):
|
|
94
|
+
if not self.is_spec_locked:
|
|
95
|
+
return func(self, *args, **kwargs)
|
|
96
|
+
result = self.__dict__.setdefault("_cache", {}).get(func_name, NO_DEFAULT)
|
|
97
|
+
if result is NO_DEFAULT:
|
|
98
|
+
result = func(self, *args, **kwargs)
|
|
99
|
+
self.__dict__.setdefault("_cache", {})[func_name] = result
|
|
100
|
+
return result
|
|
101
|
+
|
|
102
|
+
return wrapper
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _clear_cache_when_set(func):
|
|
106
|
+
"""A decorator for EnvBase methods that should clear the caches when called."""
|
|
107
|
+
|
|
108
|
+
@wraps(func)
|
|
109
|
+
def wrapper(self, *args, **kwargs):
|
|
110
|
+
# if there's no cache we'll just recompute the value
|
|
111
|
+
if "_cache" not in self.__dict__:
|
|
112
|
+
self._cache = {}
|
|
113
|
+
else:
|
|
114
|
+
self._cache.clear()
|
|
115
|
+
result = func(self, *args, **kwargs)
|
|
116
|
+
self._cache.clear()
|
|
117
|
+
return result
|
|
118
|
+
|
|
119
|
+
return wrapper
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class EnvMetaData:
|
|
123
|
+
"""A class for environment meta-data storage and passing in multiprocessed settings."""
|
|
124
|
+
|
|
125
|
+
def __init__(
|
|
126
|
+
self,
|
|
127
|
+
*,
|
|
128
|
+
tensordict: TensorDictBase,
|
|
129
|
+
specs: Composite,
|
|
130
|
+
batch_size: torch.Size,
|
|
131
|
+
env_str: str,
|
|
132
|
+
device: torch.device,
|
|
133
|
+
batch_locked: bool,
|
|
134
|
+
device_map: dict,
|
|
135
|
+
):
|
|
136
|
+
self.device = device
|
|
137
|
+
self.tensordict = tensordict
|
|
138
|
+
self.specs = specs
|
|
139
|
+
self.batch_size = batch_size
|
|
140
|
+
self.env_str = env_str
|
|
141
|
+
self.batch_locked = batch_locked
|
|
142
|
+
self.device_map = device_map
|
|
143
|
+
self.has_dynamic_specs = _has_dynamic_specs(specs)
|
|
144
|
+
|
|
145
|
+
@property
|
|
146
|
+
def tensordict(self) -> TensorDictBase:
|
|
147
|
+
td = self._tensordict.copy()
|
|
148
|
+
if td.device != self.device:
|
|
149
|
+
if self.device is None:
|
|
150
|
+
return td.clear_device_()
|
|
151
|
+
else:
|
|
152
|
+
return td.to(self.device)
|
|
153
|
+
return td
|
|
154
|
+
|
|
155
|
+
@property
|
|
156
|
+
def specs(self):
|
|
157
|
+
return self._specs.to(self.device)
|
|
158
|
+
|
|
159
|
+
@tensordict.setter
|
|
160
|
+
def tensordict(self, value: TensorDictBase):
|
|
161
|
+
self._tensordict = value.to("cpu")
|
|
162
|
+
|
|
163
|
+
@specs.setter
|
|
164
|
+
def specs(self, value: Composite):
|
|
165
|
+
self._specs = value.to("cpu")
|
|
166
|
+
|
|
167
|
+
@staticmethod
|
|
168
|
+
def metadata_from_env(env) -> EnvMetaData:
|
|
169
|
+
tensordict = env.fake_tensordict().clone()
|
|
170
|
+
|
|
171
|
+
for done_key in env.done_keys:
|
|
172
|
+
tensordict.set(
|
|
173
|
+
_replace_last(done_key, "_reset"),
|
|
174
|
+
torch.zeros_like(tensordict.get(("next", done_key))),
|
|
175
|
+
)
|
|
176
|
+
|
|
177
|
+
specs = env.specs.to("cpu")
|
|
178
|
+
|
|
179
|
+
batch_size = env.batch_size
|
|
180
|
+
try:
|
|
181
|
+
env_str = str(env)
|
|
182
|
+
except Exception:
|
|
183
|
+
env_str = f"{env.__class__.__name__}()"
|
|
184
|
+
device = env.device
|
|
185
|
+
specs = specs.to("cpu")
|
|
186
|
+
batch_locked = env.batch_locked
|
|
187
|
+
# we need to save the device map, as the tensordict will be placed on cpu
|
|
188
|
+
device_map = {}
|
|
189
|
+
|
|
190
|
+
def fill_device_map(name, val, device_map=device_map):
|
|
191
|
+
device_map[name] = val.device
|
|
192
|
+
|
|
193
|
+
tensordict.named_apply(fill_device_map, nested_keys=True, filter_empty=True)
|
|
194
|
+
return EnvMetaData(
|
|
195
|
+
tensordict=tensordict,
|
|
196
|
+
specs=specs,
|
|
197
|
+
batch_size=batch_size,
|
|
198
|
+
env_str=env_str,
|
|
199
|
+
device=device,
|
|
200
|
+
batch_locked=batch_locked,
|
|
201
|
+
device_map=device_map,
|
|
202
|
+
)
|
|
203
|
+
|
|
204
|
+
def expand(self, *size: int) -> EnvMetaData:
|
|
205
|
+
tensordict = self.tensordict.expand(*size).clone()
|
|
206
|
+
batch_size = torch.Size(list(size))
|
|
207
|
+
return EnvMetaData(
|
|
208
|
+
tensordict=tensordict,
|
|
209
|
+
specs=self.specs.expand(*size),
|
|
210
|
+
batch_size=batch_size,
|
|
211
|
+
env_str=self.env_str,
|
|
212
|
+
device=self.device,
|
|
213
|
+
batch_locked=self.batch_locked,
|
|
214
|
+
device_map=self.device_map,
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
def clone(self):
|
|
218
|
+
return EnvMetaData(
|
|
219
|
+
tensordict=self.tensordict.clone(),
|
|
220
|
+
specs=self.specs.clone(),
|
|
221
|
+
batch_size=torch.Size([*self.batch_size]),
|
|
222
|
+
env_str=deepcopy(self.env_str),
|
|
223
|
+
device=self.device,
|
|
224
|
+
batch_locked=self.batch_locked,
|
|
225
|
+
device_map=self.device_map,
|
|
226
|
+
)
|
|
227
|
+
|
|
228
|
+
def to(self, device: DEVICE_TYPING) -> EnvMetaData:
|
|
229
|
+
if device is not None:
|
|
230
|
+
device = _make_ordinal_device(torch.device(device))
|
|
231
|
+
device_map = {key: device for key in self.device_map}
|
|
232
|
+
tensordict = self.tensordict.contiguous().to(device)
|
|
233
|
+
specs = self.specs.to(device)
|
|
234
|
+
return EnvMetaData(
|
|
235
|
+
tensordict=tensordict,
|
|
236
|
+
specs=specs,
|
|
237
|
+
batch_size=self.batch_size,
|
|
238
|
+
env_str=self.env_str,
|
|
239
|
+
device=device,
|
|
240
|
+
batch_locked=self.batch_locked,
|
|
241
|
+
device_map=device_map,
|
|
242
|
+
)
|
|
243
|
+
|
|
244
|
+
def __getitem__(self, item):
|
|
245
|
+
from tensordict.utils import _getitem_batch_size
|
|
246
|
+
|
|
247
|
+
return EnvMetaData(
|
|
248
|
+
tensordict=self.tensordict[item],
|
|
249
|
+
specs=self.specs[item],
|
|
250
|
+
batch_size=_getitem_batch_size(self.batch_size, item),
|
|
251
|
+
env_str=self.env_str,
|
|
252
|
+
device=self.device,
|
|
253
|
+
batch_locked=self.batch_locked,
|
|
254
|
+
device_map=self.device_map,
|
|
255
|
+
)
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
class _EnvPostInit(abc.ABCMeta):
|
|
259
|
+
def __call__(cls, *args, **kwargs):
|
|
260
|
+
spec_locked = kwargs.pop("spec_locked", True)
|
|
261
|
+
auto_reset = kwargs.pop("auto_reset", False)
|
|
262
|
+
auto_reset_replace = kwargs.pop("auto_reset_replace", True)
|
|
263
|
+
instance: EnvBase = super().__call__(*args, **kwargs)
|
|
264
|
+
if "_cache" not in instance.__dict__:
|
|
265
|
+
instance._cache = {}
|
|
266
|
+
|
|
267
|
+
if spec_locked:
|
|
268
|
+
instance.input_spec.lock_(recurse=True)
|
|
269
|
+
instance.output_spec.lock_(recurse=True)
|
|
270
|
+
instance._is_spec_locked = spec_locked
|
|
271
|
+
|
|
272
|
+
# we create the done spec by adding a done/terminated entry if one is missing
|
|
273
|
+
instance._create_done_specs()
|
|
274
|
+
# we access lazy attributed to make sure they're built properly.
|
|
275
|
+
# This isn't done in `__init__` because we don't know if super().__init__
|
|
276
|
+
# will be called before or after the specs, batch size etc are set.
|
|
277
|
+
_ = instance.done_spec
|
|
278
|
+
_ = instance.reward_keys
|
|
279
|
+
# _ = instance.action_keys
|
|
280
|
+
_ = instance.state_spec
|
|
281
|
+
if auto_reset:
|
|
282
|
+
from torchrl.envs.transforms.transforms import (
|
|
283
|
+
AutoResetEnv,
|
|
284
|
+
AutoResetTransform,
|
|
285
|
+
)
|
|
286
|
+
|
|
287
|
+
return AutoResetEnv(
|
|
288
|
+
instance, AutoResetTransform(replace=auto_reset_replace)
|
|
289
|
+
)
|
|
290
|
+
|
|
291
|
+
done_keys = set(instance.full_done_spec.keys(True, True))
|
|
292
|
+
obs_keys = set(instance.full_observation_spec.keys(True, True))
|
|
293
|
+
reward_keys = set(instance.full_reward_spec.keys(True, True))
|
|
294
|
+
# state_keys can match obs_keys so we don't test that
|
|
295
|
+
action_keys = set(instance.full_action_spec.keys(True, True))
|
|
296
|
+
state_keys = set(instance.full_state_spec.keys(True, True))
|
|
297
|
+
total_set = set()
|
|
298
|
+
for keyset in (done_keys, obs_keys, reward_keys):
|
|
299
|
+
if total_set.intersection(keyset):
|
|
300
|
+
raise RuntimeError(
|
|
301
|
+
f"The set of keys of one spec collides (culprit: {total_set.intersection(keyset)}) with another."
|
|
302
|
+
)
|
|
303
|
+
total_set = total_set.union(keyset)
|
|
304
|
+
total_set = set()
|
|
305
|
+
for keyset in (state_keys, action_keys):
|
|
306
|
+
if total_set.intersection(keyset):
|
|
307
|
+
raise RuntimeError(
|
|
308
|
+
f"The set of keys of one spec collides (culprit: {total_set.intersection(keyset)}) with another."
|
|
309
|
+
)
|
|
310
|
+
total_set = total_set.union(keyset)
|
|
311
|
+
return instance
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
class EnvBase(nn.Module, metaclass=_EnvPostInit):
|
|
315
|
+
"""Abstract environment parent class.
|
|
316
|
+
|
|
317
|
+
Keyword Args:
|
|
318
|
+
device (torch.device): The device of the environment. Deviceless environments
|
|
319
|
+
are allowed (device=None). If not ``None``, all specs will be cast
|
|
320
|
+
on that device and it is expected that all inputs and outputs will
|
|
321
|
+
live on that device.
|
|
322
|
+
Defaults to ``None``.
|
|
323
|
+
batch_size (torch.Size or equivalent, optional): batch-size of the environment.
|
|
324
|
+
Corresponds to the leading dimension of all the input and output
|
|
325
|
+
tensordicts the environment reads and writes. Defaults to an empty batch-size.
|
|
326
|
+
run_type_checks (bool, optional): If ``True``, type-checks will occur
|
|
327
|
+
at every reset and every step. Defaults to ``False``.
|
|
328
|
+
allow_done_after_reset (bool, optional): if ``True``, an environment can
|
|
329
|
+
be done after a call to :meth:`reset` is made. Defaults to ``False``.
|
|
330
|
+
spec_locked (bool, optional): if ``True``, the specs are locked and can only be
|
|
331
|
+
modified if :meth:`~torchrl.envs.EnvBase.set_spec_lock_` is called.
|
|
332
|
+
|
|
333
|
+
.. note:: The locking is achieved by the `EnvBase` metaclass. It does not appear in the
|
|
334
|
+
`__init__` method and is included in the keyword arguments strictly for type-hinting purpose.
|
|
335
|
+
|
|
336
|
+
.. seealso:: :ref:`Locking environment specs <Environment-lock>`.
|
|
337
|
+
|
|
338
|
+
Defaults to ``True``.
|
|
339
|
+
auto_reset (bool, optional): if ``True``, the env is assumed to reset automatically
|
|
340
|
+
when done. Defaults to ``False``.
|
|
341
|
+
|
|
342
|
+
.. note:: The auto-resetting is achieved by the `EnvBase` metaclass. It does not appear in the
|
|
343
|
+
`__init__` method and is included in the keyword arguments strictly for type-hinting purpose.
|
|
344
|
+
|
|
345
|
+
.. seealso:: The :ref:`auto-resetting environments API <autoresetting_envs>` section in the API
|
|
346
|
+
documentation.
|
|
347
|
+
|
|
348
|
+
Attributes:
|
|
349
|
+
done_spec (Composite): equivalent to ``full_done_spec`` as all
|
|
350
|
+
``done_specs`` contain at least a ``"done"`` and a ``"terminated"`` entry
|
|
351
|
+
action_spec (TensorSpec): the spec of the action. Links to the spec of the leaf
|
|
352
|
+
action if only one action tensor is to be expected. Otherwise links to
|
|
353
|
+
``full_action_spec``.
|
|
354
|
+
observation_spec (Composite): equivalent to ``full_observation_spec``.
|
|
355
|
+
reward_spec (TensorSpec): the spec of the reward. Links to the spec of the leaf
|
|
356
|
+
reward if only one reward tensor is to be expected. Otherwise links to
|
|
357
|
+
``full_reward_spec``.
|
|
358
|
+
state_spec (Composite): equivalent to ``full_state_spec``.
|
|
359
|
+
full_done_spec (Composite): a composite spec such that ``full_done_spec.zero()``
|
|
360
|
+
returns a tensordict containing only the leaves encoding the done status of the
|
|
361
|
+
environment.
|
|
362
|
+
full_action_spec (Composite): a composite spec such that ``full_action_spec.zero()``
|
|
363
|
+
returns a tensordict containing only the leaves encoding the action of the
|
|
364
|
+
environment.
|
|
365
|
+
full_observation_spec (Composite): a composite spec such that ``full_observation_spec.zero()``
|
|
366
|
+
returns a tensordict containing only the leaves encoding the observation of the
|
|
367
|
+
environment.
|
|
368
|
+
full_reward_spec (Composite): a composite spec such that ``full_reward_spec.zero()``
|
|
369
|
+
returns a tensordict containing only the leaves encoding the reward of the
|
|
370
|
+
environment.
|
|
371
|
+
full_state_spec (Composite): a composite spec such that ``full_state_spec.zero()``
|
|
372
|
+
returns a tensordict containing only the leaves encoding the inputs (actions
|
|
373
|
+
excluded) of the environment.
|
|
374
|
+
batch_size (torch.Size): The batch-size of the environment.
|
|
375
|
+
device (torch.device): the device where the input/outputs of the environment
|
|
376
|
+
are to be expected. Can be ``None``.
|
|
377
|
+
is_spec_locked (bool): returns ``True`` if the specs are locked. See the :attr:`spec_locked`
|
|
378
|
+
argument above.
|
|
379
|
+
|
|
380
|
+
Methods:
|
|
381
|
+
step (TensorDictBase -> TensorDictBase): step in the environment
|
|
382
|
+
reset (TensorDictBase, optional -> TensorDictBase): reset the environment
|
|
383
|
+
set_seed (int -> int): sets the seed of the environment
|
|
384
|
+
rand_step (TensorDictBase, optional -> TensorDictBase): random step given the action spec
|
|
385
|
+
rollout (Callable, ... -> TensorDictBase): executes a rollout in the environment with the given policy (or random
|
|
386
|
+
steps if no policy is provided)
|
|
387
|
+
|
|
388
|
+
Examples:
|
|
389
|
+
>>> from torchrl.envs import EnvBase
|
|
390
|
+
>>> class CounterEnv(EnvBase):
|
|
391
|
+
... def __init__(self, batch_size=(), device=None, **kwargs):
|
|
392
|
+
... self.observation_spec = Composite(
|
|
393
|
+
... count=Unbounded(batch_size, device=device, dtype=torch.int64))
|
|
394
|
+
... self.action_spec = Unbounded(batch_size, device=device, dtype=torch.int8)
|
|
395
|
+
... # done spec and reward spec are set automatically
|
|
396
|
+
... def _step(self, tensordict):
|
|
397
|
+
...
|
|
398
|
+
>>> from torchrl.envs.libs.gym import GymEnv
|
|
399
|
+
>>> env = GymEnv("Pendulum-v1")
|
|
400
|
+
>>> env.batch_size # how many envs are run at once
|
|
401
|
+
torch.Size([])
|
|
402
|
+
>>> env.input_spec
|
|
403
|
+
Composite(
|
|
404
|
+
full_state_spec: None,
|
|
405
|
+
full_action_spec: Composite(
|
|
406
|
+
action: BoundedContinuous(
|
|
407
|
+
shape=torch.Size([1]),
|
|
408
|
+
space=ContinuousBox(
|
|
409
|
+
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
|
|
410
|
+
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
|
|
411
|
+
device=cpu,
|
|
412
|
+
dtype=torch.float32,
|
|
413
|
+
domain=continuous), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))
|
|
414
|
+
>>> env.action_spec
|
|
415
|
+
BoundedContinuous(
|
|
416
|
+
shape=torch.Size([1]),
|
|
417
|
+
space=ContinuousBox(
|
|
418
|
+
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
|
|
419
|
+
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
|
|
420
|
+
device=cpu,
|
|
421
|
+
dtype=torch.float32,
|
|
422
|
+
domain=continuous)
|
|
423
|
+
>>> env.observation_spec
|
|
424
|
+
Composite(
|
|
425
|
+
observation: BoundedContinuous(
|
|
426
|
+
shape=torch.Size([3]),
|
|
427
|
+
space=ContinuousBox(
|
|
428
|
+
low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True),
|
|
429
|
+
high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)),
|
|
430
|
+
device=cpu,
|
|
431
|
+
dtype=torch.float32,
|
|
432
|
+
domain=continuous), device=cpu, shape=torch.Size([]))
|
|
433
|
+
>>> env.reward_spec
|
|
434
|
+
UnboundedContinuous(
|
|
435
|
+
shape=torch.Size([1]),
|
|
436
|
+
space=None,
|
|
437
|
+
device=cpu,
|
|
438
|
+
dtype=torch.float32,
|
|
439
|
+
domain=continuous)
|
|
440
|
+
>>> env.done_spec
|
|
441
|
+
Categorical(
|
|
442
|
+
shape=torch.Size([1]),
|
|
443
|
+
space=DiscreteBox(n=2),
|
|
444
|
+
device=cpu,
|
|
445
|
+
dtype=torch.bool,
|
|
446
|
+
domain=discrete)
|
|
447
|
+
>>> # the output_spec contains all the expected outputs
|
|
448
|
+
>>> env.output_spec
|
|
449
|
+
Composite(
|
|
450
|
+
full_reward_spec: Composite(
|
|
451
|
+
reward: UnboundedContinuous(
|
|
452
|
+
shape=torch.Size([1]),
|
|
453
|
+
space=None,
|
|
454
|
+
device=cpu,
|
|
455
|
+
dtype=torch.float32,
|
|
456
|
+
domain=continuous), device=cpu, shape=torch.Size([])),
|
|
457
|
+
full_observation_spec: Composite(
|
|
458
|
+
observation: BoundedContinuous(
|
|
459
|
+
shape=torch.Size([3]),
|
|
460
|
+
space=ContinuousBox(
|
|
461
|
+
low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True),
|
|
462
|
+
high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)),
|
|
463
|
+
device=cpu,
|
|
464
|
+
dtype=torch.float32,
|
|
465
|
+
domain=continuous), device=cpu, shape=torch.Size([])),
|
|
466
|
+
full_done_spec: Composite(
|
|
467
|
+
done: Categorical(
|
|
468
|
+
shape=torch.Size([1]),
|
|
469
|
+
space=DiscreteBox(n=2),
|
|
470
|
+
device=cpu,
|
|
471
|
+
dtype=torch.bool,
|
|
472
|
+
domain=discrete), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))
|
|
473
|
+
|
|
474
|
+
.. note:: Learn more about dynamic specs and environments :ref:`here <dynamic_envs>`.
|
|
475
|
+
"""
|
|
476
|
+
|
|
477
|
+
_batch_size: torch.Size | None
|
|
478
|
+
_device: torch.device | None
|
|
479
|
+
_is_spec_locked: bool = False
|
|
480
|
+
|
|
481
|
+
def __init__(
|
|
482
|
+
self,
|
|
483
|
+
*,
|
|
484
|
+
device: DEVICE_TYPING | None = None,
|
|
485
|
+
batch_size: tuple | torch.Size | None = None,
|
|
486
|
+
run_type_checks: bool = False,
|
|
487
|
+
allow_done_after_reset: bool = False,
|
|
488
|
+
spec_locked: bool = True,
|
|
489
|
+
auto_reset: bool = False,
|
|
490
|
+
):
|
|
491
|
+
if "_cache" not in self.__dict__:
|
|
492
|
+
self._cache = {}
|
|
493
|
+
super().__init__()
|
|
494
|
+
|
|
495
|
+
self.__dict__.setdefault("_batch_size", None)
|
|
496
|
+
self.__dict__.setdefault("_device", None)
|
|
497
|
+
|
|
498
|
+
if batch_size is not None:
|
|
499
|
+
# we want an error to be raised if we pass batch_size but
|
|
500
|
+
# it's already been set
|
|
501
|
+
batch_size = self.batch_size = torch.Size(batch_size)
|
|
502
|
+
else:
|
|
503
|
+
batch_size = torch.Size(())
|
|
504
|
+
|
|
505
|
+
if device is not None:
|
|
506
|
+
device = self.__dict__["_device"] = _make_ordinal_device(
|
|
507
|
+
torch.device(device)
|
|
508
|
+
)
|
|
509
|
+
|
|
510
|
+
output_spec = self.__dict__.get("_output_spec")
|
|
511
|
+
if output_spec is None:
|
|
512
|
+
output_spec = self.__dict__["_output_spec"] = Composite(
|
|
513
|
+
shape=batch_size, device=device
|
|
514
|
+
)
|
|
515
|
+
elif self._output_spec.device != device and device is not None:
|
|
516
|
+
self.__dict__["_output_spec"] = self.__dict__["_output_spec"].to(
|
|
517
|
+
self.device
|
|
518
|
+
)
|
|
519
|
+
input_spec = self.__dict__.get("_input_spec")
|
|
520
|
+
if input_spec is None:
|
|
521
|
+
input_spec = self.__dict__["_input_spec"] = Composite(
|
|
522
|
+
shape=batch_size, device=device
|
|
523
|
+
)
|
|
524
|
+
elif self._input_spec.device != device and device is not None:
|
|
525
|
+
self.__dict__["_input_spec"] = self.__dict__["_input_spec"].to(self.device)
|
|
526
|
+
|
|
527
|
+
output_spec.unlock_(recurse=True)
|
|
528
|
+
input_spec.unlock_(recurse=True)
|
|
529
|
+
if "full_observation_spec" not in output_spec:
|
|
530
|
+
output_spec["full_observation_spec"] = Composite(batch_size=batch_size)
|
|
531
|
+
if "full_done_spec" not in output_spec:
|
|
532
|
+
output_spec["full_done_spec"] = Composite(batch_size=batch_size)
|
|
533
|
+
if "full_reward_spec" not in output_spec:
|
|
534
|
+
output_spec["full_reward_spec"] = Composite(batch_size=batch_size)
|
|
535
|
+
if "full_state_spec" not in input_spec:
|
|
536
|
+
input_spec["full_state_spec"] = Composite(batch_size=batch_size)
|
|
537
|
+
if "full_action_spec" not in input_spec:
|
|
538
|
+
input_spec["full_action_spec"] = Composite(batch_size=batch_size)
|
|
539
|
+
|
|
540
|
+
if "is_closed" not in self.__dir__():
|
|
541
|
+
self.is_closed = True
|
|
542
|
+
self._run_type_checks = run_type_checks
|
|
543
|
+
self._allow_done_after_reset = allow_done_after_reset
|
|
544
|
+
|
|
545
|
+
_collector: weakref.ReferenceType[
|
|
546
|
+
LLMCollector # noqa: F821 # type: ignore
|
|
547
|
+
] | None = None
|
|
548
|
+
|
|
549
|
+
def register_collector(self, collector: BaseCollector): # noqa: F821 # type: ignore
|
|
550
|
+
"""Registers a collector with the environment.
|
|
551
|
+
|
|
552
|
+
Args:
|
|
553
|
+
collector (BaseCollector): The collector to register.
|
|
554
|
+
"""
|
|
555
|
+
self._collector = weakref.ref(collector)
|
|
556
|
+
|
|
557
|
+
@property
|
|
558
|
+
def collector(self) -> BaseCollector | None: # noqa: F821 # type: ignore
|
|
559
|
+
"""Returns the collector associated with the container, if it exists."""
|
|
560
|
+
return self._collector() if self._collector is not None else None
|
|
561
|
+
|
|
562
|
+
def set_spec_lock_(self, mode: bool = True) -> EnvBase:
|
|
563
|
+
"""Locks or unlocks the environment's specs.
|
|
564
|
+
|
|
565
|
+
Args:
|
|
566
|
+
mode (bool): Whether to lock (`True`) or unlock (`False`) the specs. Defaults to `True`.
|
|
567
|
+
|
|
568
|
+
Returns:
|
|
569
|
+
EnvBase: The environment instance itself.
|
|
570
|
+
|
|
571
|
+
.. seealso:: :ref:`Locking environment specs <Environment-lock>`.
|
|
572
|
+
|
|
573
|
+
"""
|
|
574
|
+
output_spec = self.__dict__.get("_output_spec")
|
|
575
|
+
input_spec = self.__dict__.get("_input_spec")
|
|
576
|
+
if mode:
|
|
577
|
+
if output_spec is not None:
|
|
578
|
+
output_spec.lock_(recurse=True)
|
|
579
|
+
if input_spec is not None:
|
|
580
|
+
input_spec.lock_(recurse=True)
|
|
581
|
+
else:
|
|
582
|
+
self._cache.clear()
|
|
583
|
+
if output_spec is not None:
|
|
584
|
+
output_spec.unlock_(recurse=True)
|
|
585
|
+
if input_spec is not None:
|
|
586
|
+
input_spec.unlock_(recurse=True)
|
|
587
|
+
self.__dict__["_is_spec_locked"] = mode
|
|
588
|
+
return self
|
|
589
|
+
|
|
590
|
+
@property
|
|
591
|
+
def is_spec_locked(self):
|
|
592
|
+
"""Gets whether the environment's specs are locked.
|
|
593
|
+
|
|
594
|
+
This property can be modified directly.
|
|
595
|
+
|
|
596
|
+
Returns:
|
|
597
|
+
bool: True if the specs are locked, False otherwise.
|
|
598
|
+
|
|
599
|
+
.. seealso:: :ref:`Locking environment specs <Environment-lock>`.
|
|
600
|
+
|
|
601
|
+
"""
|
|
602
|
+
return self.__dict__.get("_is_spec_locked", False)
|
|
603
|
+
|
|
604
|
+
@is_spec_locked.setter
|
|
605
|
+
def is_spec_locked(self, value: bool):
|
|
606
|
+
self.set_spec_lock_(value)
|
|
607
|
+
|
|
608
|
+
def auto_specs_(
|
|
609
|
+
self,
|
|
610
|
+
policy: Callable[[TensorDictBase], TensorDictBase],
|
|
611
|
+
*,
|
|
612
|
+
tensordict: TensorDictBase | None = None,
|
|
613
|
+
action_key: NestedKey | list[NestedKey] = "action",
|
|
614
|
+
done_key: NestedKey | list[NestedKey] | None = None,
|
|
615
|
+
observation_key: NestedKey | list[NestedKey] = "observation",
|
|
616
|
+
reward_key: NestedKey | list[NestedKey] = "reward",
|
|
617
|
+
):
|
|
618
|
+
"""Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy.
|
|
619
|
+
|
|
620
|
+
This method performs a rollout using the provided policy to infer the input and output specifications of the environment.
|
|
621
|
+
It updates the environment's specs for actions, observations, rewards, and done signals based on the data collected
|
|
622
|
+
during the rollout.
|
|
623
|
+
|
|
624
|
+
Args:
|
|
625
|
+
policy (Callable[[TensorDictBase], TensorDictBase]):
|
|
626
|
+
A callable policy that takes a `TensorDictBase` as input and returns a `TensorDictBase` as output.
|
|
627
|
+
This policy is used to perform the rollout and determine the specs.
|
|
628
|
+
|
|
629
|
+
Keyword Args:
|
|
630
|
+
tensordict (TensorDictBase, optional):
|
|
631
|
+
An optional `TensorDictBase` instance to be used as the initial state for the rollout.
|
|
632
|
+
If not provided, the environment's `reset` method will be called to obtain the initial state.
|
|
633
|
+
action_key (NestedKey or List[NestedKey], optional):
|
|
634
|
+
The key(s) used to identify actions in the `TensorDictBase`. Defaults to "action".
|
|
635
|
+
done_key (NestedKey or List[NestedKey], optional):
|
|
636
|
+
The key(s) used to identify done signals in the `TensorDictBase`. Defaults to ``None``, which will
|
|
637
|
+
attempt to use ["done", "terminated", "truncated"] as potential keys.
|
|
638
|
+
observation_key (NestedKey or List[NestedKey], optional):
|
|
639
|
+
The key(s) used to identify observations in the `TensorDictBase`. Defaults to "observation".
|
|
640
|
+
reward_key (NestedKey or List[NestedKey], optional):
|
|
641
|
+
The key(s) used to identify rewards in the `TensorDictBase`. Defaults to "reward".
|
|
642
|
+
|
|
643
|
+
Returns:
|
|
644
|
+
EnvBase: The environment instance with updated specs.
|
|
645
|
+
|
|
646
|
+
Raises:
|
|
647
|
+
RuntimeError: If there are keys in the output specs that are not accounted for in the provided keys.
|
|
648
|
+
"""
|
|
649
|
+
if self.batch_locked or tensordict is None:
|
|
650
|
+
batch_size = self.batch_size
|
|
651
|
+
else:
|
|
652
|
+
batch_size = tensordict.batch_size
|
|
653
|
+
if tensordict is None:
|
|
654
|
+
tensordict = self.reset()
|
|
655
|
+
|
|
656
|
+
# Input specs
|
|
657
|
+
tensordict.update(policy(tensordict))
|
|
658
|
+
step_0 = self.step(tensordict.copy())
|
|
659
|
+
tensordict2 = step_0.get("next").copy()
|
|
660
|
+
step_1 = self.step(policy(tensordict2).copy())
|
|
661
|
+
nexts_0: TensorDictBase = step_0.pop("next")
|
|
662
|
+
nexts_1: TensorDictBase = step_1.pop("next")
|
|
663
|
+
|
|
664
|
+
input_spec_stack = {}
|
|
665
|
+
tensordict.apply(
|
|
666
|
+
partial(_tensor_to_spec, stack=input_spec_stack),
|
|
667
|
+
tensordict2,
|
|
668
|
+
named=True,
|
|
669
|
+
nested_keys=True,
|
|
670
|
+
is_leaf=_is_leaf_nontensor,
|
|
671
|
+
)
|
|
672
|
+
input_spec = Composite(input_spec_stack, batch_size=batch_size)
|
|
673
|
+
if not self.batch_locked and batch_size != self.batch_size:
|
|
674
|
+
while input_spec.shape:
|
|
675
|
+
input_spec = input_spec[0]
|
|
676
|
+
if isinstance(action_key, NestedKey):
|
|
677
|
+
action_key = [action_key]
|
|
678
|
+
full_action_spec = input_spec.separates(*action_key, default=None)
|
|
679
|
+
|
|
680
|
+
# Output specs
|
|
681
|
+
|
|
682
|
+
output_spec_stack = {}
|
|
683
|
+
nexts_0.apply(
|
|
684
|
+
partial(_tensor_to_spec, stack=output_spec_stack),
|
|
685
|
+
nexts_1,
|
|
686
|
+
named=True,
|
|
687
|
+
nested_keys=True,
|
|
688
|
+
is_leaf=_is_leaf_nontensor,
|
|
689
|
+
)
|
|
690
|
+
|
|
691
|
+
output_spec = Composite(output_spec_stack, batch_size=batch_size)
|
|
692
|
+
if not self.batch_locked and batch_size != self.batch_size:
|
|
693
|
+
while output_spec.shape:
|
|
694
|
+
output_spec = output_spec[0]
|
|
695
|
+
|
|
696
|
+
if done_key is None:
|
|
697
|
+
done_key = ["done", "terminated", "truncated"]
|
|
698
|
+
full_done_spec = output_spec.separates(*done_key, default=None)
|
|
699
|
+
if full_done_spec is not None:
|
|
700
|
+
self.full_done_spec = full_done_spec
|
|
701
|
+
|
|
702
|
+
if isinstance(reward_key, NestedKey):
|
|
703
|
+
reward_key = [reward_key]
|
|
704
|
+
full_reward_spec = output_spec.separates(*reward_key, default=None)
|
|
705
|
+
|
|
706
|
+
if isinstance(observation_key, NestedKey):
|
|
707
|
+
observation_key = [observation_key]
|
|
708
|
+
full_observation_spec = output_spec.separates(*observation_key, default=None)
|
|
709
|
+
if not output_spec.is_empty(recurse=True):
|
|
710
|
+
raise RuntimeError(
|
|
711
|
+
f"Keys {list(output_spec.keys(True, True))} are unaccounted for. "
|
|
712
|
+
f"Make sure you have passed all the leaf names to the auto_specs_ method."
|
|
713
|
+
)
|
|
714
|
+
|
|
715
|
+
if full_action_spec is not None:
|
|
716
|
+
self.full_action_spec = full_action_spec
|
|
717
|
+
if full_done_spec is not None:
|
|
718
|
+
self.full_done_spec = full_done_spec
|
|
719
|
+
if full_observation_spec is not None:
|
|
720
|
+
self.full_observation_spec = full_observation_spec
|
|
721
|
+
if full_reward_spec is not None:
|
|
722
|
+
self.full_reward_spec = full_reward_spec
|
|
723
|
+
full_state_spec = input_spec
|
|
724
|
+
self.full_state_spec = full_state_spec
|
|
725
|
+
|
|
726
|
+
return self
|
|
727
|
+
|
|
728
|
+
def check_env_specs(self, *args, **kwargs):
|
|
729
|
+
kwargs.setdefault("return_contiguous", not self._has_dynamic_specs)
|
|
730
|
+
return check_env_specs_func(self, *args, **kwargs)
|
|
731
|
+
|
|
732
|
+
check_env_specs.__doc__ = check_env_specs_func.__doc__
|
|
733
|
+
|
|
734
|
+
def cardinality(self, tensordict: TensorDictBase | None = None) -> int:
|
|
735
|
+
"""The cardinality of the action space.
|
|
736
|
+
|
|
737
|
+
By default, this is just a wrapper around :meth:`env.action_space.cardinality <~torchrl.data.TensorSpec.cardinality>`.
|
|
738
|
+
|
|
739
|
+
This class is useful when the action spec is variable:
|
|
740
|
+
|
|
741
|
+
- The number of actions can be undefined, e.g., ``Categorical(n=-1)``;
|
|
742
|
+
- The action cardinality may depend on the action mask;
|
|
743
|
+
- The shape can be dynamic, as in ``Unbound(shape=(-1))``.
|
|
744
|
+
|
|
745
|
+
In these cases, the :meth:`cardinality` should be overwritten,
|
|
746
|
+
|
|
747
|
+
Args:
|
|
748
|
+
tensordict (TensorDictBase, optional): a tensordict containing the data required to compute the cardinality.
|
|
749
|
+
|
|
750
|
+
"""
|
|
751
|
+
return self.full_action_spec.cardinality()
|
|
752
|
+
|
|
753
|
+
def configure_parallel(
|
|
754
|
+
self,
|
|
755
|
+
*,
|
|
756
|
+
use_buffers: bool | None = None,
|
|
757
|
+
shared_memory: bool | None = None,
|
|
758
|
+
memmap: bool | None = None,
|
|
759
|
+
mp_start_method: str | None = None,
|
|
760
|
+
num_threads: int | None = None,
|
|
761
|
+
num_sub_threads: int | None = None,
|
|
762
|
+
non_blocking: bool | None = None,
|
|
763
|
+
daemon: bool | None = None,
|
|
764
|
+
) -> EnvBase:
|
|
765
|
+
"""Configure parallel execution parameters.
|
|
766
|
+
|
|
767
|
+
This method allows configuring parameters for parallel environment
|
|
768
|
+
execution before the environment is started. It is only effective
|
|
769
|
+
on :class:`~torchrl.envs.BatchedEnvBase` and its subclasses.
|
|
770
|
+
|
|
771
|
+
Args:
|
|
772
|
+
use_buffers (bool, optional): whether communication between workers should
|
|
773
|
+
occur via circular preallocated memory buffers.
|
|
774
|
+
shared_memory (bool, optional): whether the returned tensordict will be
|
|
775
|
+
placed in shared memory.
|
|
776
|
+
memmap (bool, optional): whether the returned tensordict will be placed
|
|
777
|
+
in memory map.
|
|
778
|
+
mp_start_method (str, optional): the multiprocessing start method.
|
|
779
|
+
num_threads (int, optional): number of threads for this process.
|
|
780
|
+
num_sub_threads (int, optional): number of threads of the subprocesses.
|
|
781
|
+
non_blocking (bool, optional): if ``True``, device moves will be done using
|
|
782
|
+
the ``non_blocking=True`` option.
|
|
783
|
+
daemon (bool, optional): whether the processes should be daemonized.
|
|
784
|
+
|
|
785
|
+
Returns:
|
|
786
|
+
self: Returns self for method chaining.
|
|
787
|
+
|
|
788
|
+
Raises:
|
|
789
|
+
NotImplementedError: If called on an environment that does not support
|
|
790
|
+
parallel configuration.
|
|
791
|
+
RuntimeError: If called after the environment has already started.
|
|
792
|
+
|
|
793
|
+
Example:
|
|
794
|
+
>>> env = DMControlEnv("cheetah", "run", num_envs=4)
|
|
795
|
+
>>> env.configure_parallel(use_buffers=True, num_threads=2)
|
|
796
|
+
>>> env.reset() # Environment starts here, configure_parallel no longer effective
|
|
797
|
+
|
|
798
|
+
"""
|
|
799
|
+
raise NotImplementedError(
|
|
800
|
+
f"{type(self).__name__} does not support configure_parallel. "
|
|
801
|
+
"This method is only available on BatchedEnvBase and its subclasses."
|
|
802
|
+
)
|
|
803
|
+
|
|
    @classmethod
    def make_parallel(
        cls,
        create_env_fn,
        *,
        num_envs: int = 1,
        create_env_kwargs: dict | Sequence[dict] | None = None,
        pin_memory: bool = False,
        share_individual_td: bool | None = None,
        shared_memory: bool = True,
        memmap: bool = False,
        policy_proof: Callable | None = None,
        device: DEVICE_TYPING | None = None,
        allow_step_when_done: bool = False,
        num_threads: int | None = None,
        num_sub_threads: int = 1,
        serial_for_single: bool = False,
        non_blocking: bool = False,
        mp_start_method: str | None = None,
        use_buffers: bool | None = None,
        consolidate: bool = True,
        daemon: bool = False,
        **parallel_kwargs,
    ) -> EnvBase:
        """Factory method to create a ParallelEnv from an environment creator.

        This method provides a convenient way to create parallel environments
        with the same signature as :class:`~torchrl.envs.ParallelEnv`.

        Args:
            create_env_fn (callable): A callable that creates an environment instance.
            num_envs (int, optional): Number of parallel environments. Defaults to 1.
            create_env_kwargs (dict or list of dicts, optional): kwargs to be used
                with the environments being created.
            pin_memory (bool, optional): Whether to pin memory. Defaults to False.
            share_individual_td (bool, optional): if ``True``, a different tensordict
                is created for every process/worker and a lazy stack is returned.
            shared_memory (bool, optional): whether the returned tensordict will be
                placed in shared memory. Defaults to True.
            memmap (bool, optional): whether the returned tensordict will be placed
                in memory map. Defaults to False.
            policy_proof (callable, optional): if provided, it'll be used to get
                the list of tensors to return through step() and reset() methods.
            device (str, int, torch.device, optional): The device of the batched
                environment.
            allow_step_when_done (bool, optional): Allow stepping when done.
                Defaults to False.
            num_threads (int, optional): number of threads for this process.
            num_sub_threads (int, optional): number of threads of the subprocesses.
                Defaults to 1.
            serial_for_single (bool, optional): if ``True``, creating a parallel
                environment with a single worker will return a SerialEnv instead.
                Defaults to False.
            non_blocking (bool, optional): if ``True``, device moves will be done
                using the ``non_blocking=True`` option. Defaults to False.
            mp_start_method (str, optional): the multiprocessing start method.
            use_buffers (bool, optional): whether communication between workers
                should occur via circular preallocated memory buffers.
            consolidate (bool, optional): Whether to consolidate tensordicts.
                Defaults to True.
            daemon (bool, optional): whether the processes should be daemonized.
                Defaults to False.
            **parallel_kwargs: Additional keyword arguments passed to ParallelEnv.

        Returns:
            EnvBase: A ParallelEnv (or SerialEnv if serial_for_single=True and num_envs=1).

        """
        from torchrl.envs import ParallelEnv

        return ParallelEnv(
            num_workers=num_envs,
            create_env_fn=create_env_fn,
            create_env_kwargs=create_env_kwargs,
            pin_memory=pin_memory,
            share_individual_td=share_individual_td,
            shared_memory=shared_memory,
            memmap=memmap,
            policy_proof=policy_proof,
            device=device,
            allow_step_when_done=allow_step_when_done,
            num_threads=num_threads,
            num_sub_threads=num_sub_threads,
            serial_for_single=serial_for_single,
            non_blocking=non_blocking,
            mp_start_method=mp_start_method,
            use_buffers=use_buffers,
            consolidate=consolidate,
            daemon=daemon,
            **parallel_kwargs,
        )

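    # A minimal usage sketch for ``make_parallel`` (illustrative only; assumes a
    # gymnasium-backed ``GymEnv`` and the CartPole-v1 environment are available):
    #
    #     >>> from torchrl.envs import GymEnv
    #     >>> penv = GymEnv.make_parallel(lambda: GymEnv("CartPole-v1"), num_envs=2)
    #     >>> penv.reset().batch_size
    #     torch.Size([2])
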
    @classmethod
    def __new__(cls, *args, _inplace_update=False, _batch_locked=True, **kwargs):
        # inplace update will write tensors in-place on the provided tensordict.
        # This is risky, especially if gradients need to be passed (in-place copy
        # for tensors that are part of computational graphs will result in an error).
        # It can also lead to inconsistencies when calling rollout.
        cls._inplace_update = _inplace_update
        cls._batch_locked = _batch_locked
        cls._device = None
        # cached in_keys to be excluded from update when calling step
        cls._cache_in_keys = None

        # We may assign _input_spec to the cls, but it must be assigned to the instance:
        # we pull it off, and place it back where it belongs
        _input_spec = None
        if hasattr(cls, "_input_spec"):
            _input_spec = cls._input_spec.clone()
            delattr(cls, "_input_spec")
        _output_spec = None
        if hasattr(cls, "_output_spec"):
            _output_spec = cls._output_spec.clone()
            delattr(cls, "_output_spec")
        env = super().__new__(cls)
        if _input_spec is not None:
            env.__dict__["_input_spec"] = _input_spec
        if _output_spec is not None:
            env.__dict__["_output_spec"] = _output_spec
        return env

    def __setattr__(self, key, value):
        if key in (
            "_input_spec",
            "_observation_spec",
            "_action_spec",
            "_reward_spec",
            "_output_spec",
            "_state_spec",
            "_done_spec",
        ):
            raise AttributeError(
                "To set an environment spec, please use `env.observation_spec = obs_spec` (without the leading"
                " underscore)."
            )
        super().__setattr__(key, value)

    @property
    def batch_locked(self) -> bool:
        """Whether the environment can be used with a batch size different from the one it was initialized with.

        If True, the env needs to be used with a tensordict having the same batch size as the env.
        batch_locked is an immutable property.
        """
        return self._batch_locked

    @batch_locked.setter
    def batch_locked(self, value: bool) -> None:
        raise RuntimeError("batch_locked is a read-only property")

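    # Sketch of the ``batch_locked`` contract (illustrative; ``env`` is assumed to
    # be a single, unbatched environment for which ``batch_locked`` is True):
    #
    #     >>> env.batch_size
    #     torch.Size([])
    #     >>> env.rollout(3)           # fine: matches the env's own batch size
    #     >>> td = env.reset().expand(4)
    #     >>> env.step(td)             # raises: tensordict batch size mismatch
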
    @property
    def run_type_checks(self) -> bool:
        return self._run_type_checks

    @run_type_checks.setter
    def run_type_checks(self, run_type_checks: bool) -> None:
        self._run_type_checks = run_type_checks

    @property
    def batch_size(self) -> torch.Size:
        """Number of envs batched in this environment instance, organised in a `torch.Size()` object.

        The environments may be similar or different, but it is assumed that they have little if
        any interaction between them (e.g., multi-task or batched execution
        in parallel).

        """
        _batch_size = self.__dict__.get("_batch_size")
        if _batch_size is None:
            _batch_size = self._batch_size = torch.Size([])
        return _batch_size

    @batch_size.setter
    @_maybe_unlock
    def batch_size(self, value: torch.Size) -> None:
        self._batch_size = torch.Size(value)
        if (
            hasattr(self, "output_spec")
            and self.output_spec.shape[: len(value)] != value
        ):
            self.output_spec.shape = value
        if hasattr(self, "input_spec") and self.input_spec.shape[: len(value)] != value:
            self.input_spec.shape = value

    @property
    def shape(self):
        """Equivalent to :attr:`~.batch_size`."""
        return self.batch_size

    @property
    def device(self) -> torch.device:
        device = self.__dict__.get("_device")
        return device

    @device.setter
    def device(self, value: torch.device) -> None:
        device = self.__dict__.get("_device")
        if device is None:
            self.__dict__["_device"] = value
            return
        raise RuntimeError("device cannot be set. Call env.to(device) instead.")

    def ndimension(self):
        return len(self.batch_size)

    @property
    def ndim(self):
        return self.ndimension()

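    # Devices are changed with ``env.to`` rather than by assigning the property,
    # as the setter above enforces (sketch; assumes a CUDA device is available):
    #
    #     >>> env = env.to("cuda:0")
    #     >>> env.device
    #     device(type='cuda', index=0)
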
    def append_transform(
        self,
        transform: Transform | Callable[[TensorDictBase], TensorDictBase],  # noqa: F821
    ) -> torchrl.envs.TransformedEnv:  # noqa
        """Returns a transformed environment where the callable/transform passed is applied.

        Args:
            transform (Transform or Callable[[TensorDictBase], TensorDictBase]): the transform to apply
                to the environment.

        Examples:
            >>> from torchrl.envs import GymEnv
            >>> import torch
            >>> env = GymEnv("CartPole-v1")
            >>> loc = 0.5
            >>> scale = 1.0
            >>> transform = lambda data: data.set("observation", (data.get("observation") - loc)/scale)
            >>> env = env.append_transform(transform=transform)
            >>> print(env)
            TransformedEnv(
                env=GymEnv(env=CartPole-v1, batch_size=torch.Size([]), device=cpu),
                transform=_CallableTransform(keys=[]))

        """
        from torchrl.envs.transforms.transforms import TransformedEnv

        return TransformedEnv(self, transform)

    # Parent specs: input and output spec.
    @property
    def input_spec(self) -> TensorSpec:
        """Input spec.

        The composite spec containing all specs for data input to the environments.

        It contains:

        - "full_action_spec": the spec of the input actions
        - "full_state_spec": the spec of all other environment inputs

        This attribute is locked and should be read-only.
        Instead, to set the specs contained in it, use the respective properties.

        Examples:
            >>> from torchrl.envs.libs.gym import GymEnv
            >>> env = GymEnv("Pendulum-v1")
            >>> env.input_spec
            Composite(
                full_state_spec: None,
                full_action_spec: Composite(
                    action: BoundedContinuous(
                        shape=torch.Size([1]),
                        space=ContinuousBox(
                            low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
                            high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
                        device=cpu,
                        dtype=torch.float32,
                        domain=continuous), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))

        """
        input_spec = self.__dict__.get("_input_spec")
        if input_spec is None:
            is_locked = self.is_spec_locked
            if is_locked:
                self.set_spec_lock_(False)
            input_spec = Composite(
                full_state_spec=None,
                shape=self.batch_size,
                device=self.device,
            )
            self.__dict__["_input_spec"] = input_spec
            if is_locked:
                self.set_spec_lock_(True)
        return input_spec

    @input_spec.setter
    def input_spec(self, value: TensorSpec) -> None:
        raise RuntimeError("input_spec is protected.")

    @property
    def output_spec(self) -> TensorSpec:
        """Output spec.

        The composite spec containing all specs for data output from the environments.

        It contains:

        - "full_reward_spec": the spec of reward
        - "full_done_spec": the spec of done
        - "full_observation_spec": the spec of all other environment outputs

        This attribute is locked and should be read-only.
        Instead, to set the specs contained in it, use the respective properties.

        Examples:
            >>> from torchrl.envs.libs.gym import GymEnv
            >>> env = GymEnv("Pendulum-v1")
            >>> env.output_spec
            Composite(
                full_reward_spec: Composite(
                    reward: UnboundedContinuous(
                        shape=torch.Size([1]),
                        space=None,
                        device=cpu,
                        dtype=torch.float32,
                        domain=continuous), device=cpu, shape=torch.Size([])),
                full_observation_spec: Composite(
                    observation: BoundedContinuous(
                        shape=torch.Size([3]),
                        space=ContinuousBox(
                            low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True),
                            high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)),
                        device=cpu,
                        dtype=torch.float32,
                        domain=continuous), device=cpu, shape=torch.Size([])),
                full_done_spec: Composite(
                    done: Categorical(
                        shape=torch.Size([1]),
                        space=DiscreteBox(n=2),
                        device=cpu,
                        dtype=torch.bool,
                        domain=discrete), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))

        """
        output_spec = self.__dict__.get("_output_spec")
        if output_spec is None:
            is_locked = self.is_spec_locked
            if is_locked:
                self.set_spec_lock_(False)
            output_spec = Composite(
                shape=self.batch_size,
                device=self.device,
            )
            self.__dict__["_output_spec"] = output_spec
            if is_locked:
                self.set_spec_lock_(True)
        return output_spec

    @output_spec.setter
    def output_spec(self, value: TensorSpec) -> None:
        raise RuntimeError("output_spec is protected.")

    @property
    @_cache_value
    def action_keys(self) -> list[NestedKey]:
        """The action keys of an environment.

        By default, there will only be one key named "action".

        Keys are sorted by depth in the data tree.
        """
        keys = self.full_action_spec.keys(True, True)
        keys = sorted(keys, key=_repr_by_depth)
        return keys

    @property
    @_cache_value
    def state_keys(self) -> list[NestedKey]:
        """The state keys of an environment.

        By default, there will only be one key named "state".

        Keys are sorted by depth in the data tree.
        """
        state_keys = self.__dict__.get("_state_keys")
        if state_keys is not None:
            return state_keys
        keys = self.input_spec["full_state_spec"].keys(True, True)
        keys = sorted(keys, key=_repr_by_depth)
        self.__dict__["_state_keys"] = keys
        return keys

    @property
    def action_key(self) -> NestedKey:
        """The action key of an environment.

        By default, this will be "action".

        If there is more than one action key in the environment, this function will raise an exception.
        """
        if len(self.action_keys) > 1:
            raise KeyError(
                "action_key requested but more than one key present in the environment"
            )
        return self.action_keys[0]

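    # Sketch of the key layout: a single-agent env typically exposes a single
    # "action" entry, while a multi-agent env may use nested keys, in which case
    # ``action_key`` raises a KeyError (``multi_agent_env`` is hypothetical):
    #
    #     >>> env.action_keys
    #     ['action']
    #     >>> env.action_key
    #     'action'
    #     >>> multi_agent_env.action_keys
    #     [('agents', 'action')]
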
    # Action spec: action specs belong to input_spec
    @property
    def action_spec(self) -> TensorSpec:
        """The ``action`` spec.

        The ``action_spec`` is always stored as a composite spec.

        If the action spec is provided as a simple spec, this will be returned.

            >>> env.action_spec = Unbounded(1)
            >>> env.action_spec
            UnboundedContinuous(
                shape=torch.Size([1]),
                space=ContinuousBox(
                    low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
                    high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
                device=cpu,
                dtype=torch.float32,
                domain=continuous)

        If the action spec is provided as a composite spec and contains only one leaf,
        this function will return just the leaf.

            >>> env.action_spec = Composite({"nested": {"action": Unbounded(1)}})
            >>> env.action_spec
            UnboundedContinuous(
                shape=torch.Size([1]),
                space=ContinuousBox(
                    low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
                    high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
                device=cpu,
                dtype=torch.float32,
                domain=continuous)

        If the action spec is provided as a composite spec and has more than one leaf,
        this function will return the whole spec.

            >>> env.action_spec = Composite({"nested": {"action": Unbounded(1), "another_action": Categorical(1)}})
            >>> env.action_spec
            Composite(
                nested: Composite(
                    action: UnboundedContinuous(
                        shape=torch.Size([1]),
                        space=ContinuousBox(
                            low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
                            high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
                        device=cpu,
                        dtype=torch.float32,
                        domain=continuous),
                    another_action: Categorical(
                        shape=torch.Size([]),
                        space=DiscreteBox(n=1),
                        device=cpu,
                        dtype=torch.int64,
                        domain=discrete), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))

        To retrieve the full spec passed, use:

            >>> env.input_spec["full_action_spec"]

        This property is mutable.

        Examples:
            >>> from torchrl.envs.libs.gym import GymEnv
            >>> env = GymEnv("Pendulum-v1")
            >>> env.action_spec
            BoundedContinuous(
                shape=torch.Size([1]),
                space=ContinuousBox(
                    low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
                    high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
                device=cpu,
                dtype=torch.float32,
                domain=continuous)
        """
        try:
            action_spec = self.input_spec["full_action_spec"]
        except (KeyError, AttributeError):
            raise KeyError("Failed to find the action_spec.")

        if len(self.action_keys) > 1:
            return action_spec
        else:
            if len(self.action_keys) == 1 and self.action_keys[0] != "action":
                return action_spec
            try:
                return action_spec[self.action_key]
            except KeyError:
                # the key may have changed
                raise KeyError(
                    "The action_key attribute seems to have changed. "
                    "This occurs when an action_spec is updated without "
                    "calling `env.action_spec = new_spec`. "
                    "Make sure you rely on this type of command "
                    "to set the action and other specs."
                )

    @action_spec.setter
    @_maybe_unlock
    def action_spec(self, value: TensorSpec) -> None:
        device = self.input_spec._device
        if not hasattr(value, "shape"):
            raise TypeError(
                f"action_spec of type {type(value)} does not have a shape attribute."
            )
        if value.shape[: len(self.batch_size)] != self.batch_size:
            raise ValueError(
                f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size}). "
                "Please use `env.action_spec_unbatched = value` to set unbatched versions instead."
            )

        if not isinstance(value, Composite):
            value = Composite(
                action=value.to(device), shape=self.batch_size, device=device
            )

        self.input_spec["full_action_spec"] = value.to(device)

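    # Sketch of the setter's wrapping behavior: a leaf spec assigned here is
    # stored under an "action" entry of a Composite (shown for an unbatched env):
    #
    #     >>> env.action_spec = Unbounded(1)
    #     >>> list(env.input_spec["full_action_spec"].keys())
    #     ['action']
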
    @property
    def full_action_spec(self) -> Composite:
        """The full action spec.

        ``full_action_spec`` is a :class:`~torchrl.data.Composite` instance
        that contains all the action entries.

        Examples:
            >>> from torchrl.envs import BraxEnv
            >>> for envname in BraxEnv.available_envs:
            ...     break
            >>> env = BraxEnv(envname)
            >>> env.full_action_spec
            Composite(
                action: BoundedContinuous(
                    shape=torch.Size([8]),
                    space=ContinuousBox(
                        low=Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, contiguous=True),
                        high=Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, contiguous=True)),
                    device=cpu,
                    dtype=torch.float32,
                    domain=continuous), device=cpu, shape=torch.Size([]))

        """
        full_action_spec = self.input_spec.get("full_action_spec", None)
        if full_action_spec is None:
            is_locked = self.is_spec_locked
            if is_locked:
                self.set_spec_lock_(False)
            full_action_spec = Composite(shape=self.batch_size, device=self.device)
            self.input_spec["full_action_spec"] = full_action_spec
            if is_locked:
                self.set_spec_lock_(True)
        return full_action_spec

    @full_action_spec.setter
    def full_action_spec(self, spec: Composite) -> None:
        self.action_spec = spec

    # Reward spec
    @property
    @_cache_value
    def reward_keys(self) -> list[NestedKey]:
        """The reward keys of an environment.

        By default, there will only be one key named "reward".

        Keys are sorted by depth in the data tree.
        """
        reward_keys = sorted(self.full_reward_spec.keys(True, True), key=_repr_by_depth)
        return reward_keys

    @property
    @_cache_value
    def observation_keys(self) -> list[NestedKey]:
        """The observation keys of an environment.

        By default, there will only be one key named "observation".

        Keys are sorted by depth in the data tree.
        """
        observation_keys = sorted(
            self.full_observation_spec.keys(True, True), key=_repr_by_depth
        )
        return observation_keys

    @property
    @_cache_value
    def _observation_keys_step_mdp(self) -> list[NestedKey]:
        """The observation keys of an environment that are static under step_mdp (i.e. to be copied as-is during step_mdp)."""
        observation_keys_leaves = sorted(
            self.full_observation_spec.keys(True, True, step_mdp_static_only=True),
            key=_repr_by_depth,
        )
        return observation_keys_leaves

    @property
    @_cache_value
    def _state_keys_step_mdp(self) -> list[NestedKey]:
        """The state keys of an environment that are static under step_mdp (i.e. to be copied as-is during step_mdp)."""
        state_keys_leaves = sorted(
            self.full_state_spec.keys(True, True, step_mdp_static_only=True),
            key=_repr_by_depth,
        )
        return state_keys_leaves

    @property
    @_cache_value
    def _action_keys_step_mdp(self) -> list[NestedKey]:
        """The action keys of an environment that are static under step_mdp (i.e. to be copied as-is during step_mdp)."""
        action_keys_leaves = sorted(
            self.full_action_spec.keys(True, True, step_mdp_static_only=True),
            key=_repr_by_depth,
        )
        return action_keys_leaves

    @property
    @_cache_value
    def _done_keys_step_mdp(self) -> list[NestedKey]:
        """The done keys of an environment that are static under step_mdp (i.e. to be copied as-is during step_mdp)."""
        done_keys_leaves = sorted(
            self.full_done_spec.keys(True, True, step_mdp_static_only=True),
            key=_repr_by_depth,
        )
        return done_keys_leaves

    @property
    @_cache_value
    def _reward_keys_step_mdp(self) -> list[NestedKey]:
        """The reward keys of an environment that are static under step_mdp (i.e. to be copied as-is during step_mdp)."""
        reward_keys_leaves = sorted(
            self.full_reward_spec.keys(True, True, step_mdp_static_only=True),
            key=_repr_by_depth,
        )
        return reward_keys_leaves

    @property
    def reward_key(self):
        """The reward key of an environment.

        By default, this will be "reward".

        If there is more than one reward key in the environment, this function will raise an exception.
        """
        if len(self.reward_keys) > 1:
            raise KeyError(
                "reward_key requested but more than one key present in the environment"
            )
        return self.reward_keys[0]

    # Reward spec: reward specs belong to output_spec
    @property
    def reward_spec(self) -> TensorSpec:
        """The ``reward`` spec.

        The ``reward_spec`` is always stored as a composite spec.

        If the reward spec is provided as a simple spec, this will be returned.

            >>> env.reward_spec = Unbounded(1)
            >>> env.reward_spec
            UnboundedContinuous(
                shape=torch.Size([1]),
                space=ContinuousBox(
                    low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
                    high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
                device=cpu,
                dtype=torch.float32,
                domain=continuous)

        If the reward spec is provided as a composite spec and contains only one leaf,
        this function will return just the leaf.

            >>> env.reward_spec = Composite({"nested": {"reward": Unbounded(1)}})
            >>> env.reward_spec
            UnboundedContinuous(
                shape=torch.Size([1]),
                space=ContinuousBox(
                    low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
                    high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
                device=cpu,
                dtype=torch.float32,
                domain=continuous)

        If the reward spec is provided as a composite spec and has more than one leaf,
        this function will return the whole spec.

            >>> env.reward_spec = Composite({"nested": {"reward": Unbounded(1), "another_reward": Categorical(1)}})
            >>> env.reward_spec
            Composite(
                nested: Composite(
                    reward: UnboundedContinuous(
                        shape=torch.Size([1]),
                        space=ContinuousBox(
                            low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
                            high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
                        device=cpu,
                        dtype=torch.float32,
                        domain=continuous),
                    another_reward: Categorical(
                        shape=torch.Size([]),
                        space=DiscreteBox(n=1),
                        device=cpu,
                        dtype=torch.int64,
                        domain=discrete), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))

        To retrieve the full spec passed, use:

            >>> env.output_spec["full_reward_spec"]

        This property is mutable.

        Examples:
            >>> from torchrl.envs.libs.gym import GymEnv
            >>> env = GymEnv("Pendulum-v1")
            >>> env.reward_spec
            UnboundedContinuous(
                shape=torch.Size([1]),
                space=None,
                device=cpu,
                dtype=torch.float32,
                domain=continuous)
        """
        try:
            reward_spec = self.output_spec["full_reward_spec"]
        except (KeyError, AttributeError):
            # populate the "reward" entry
            # this will be raised if there is no full_reward_spec (unlikely) or no reward_key
            # Since output_spec is lazily populated with an empty composite spec for
            # reward_spec, the second case is much more likely to occur.
            self.reward_spec = Unbounded(
                shape=(*self.batch_size, 1),
                device=self.device,
            )
            reward_spec = self.output_spec["full_reward_spec"]

        reward_keys = self.reward_keys
        if len(reward_keys) > 1 or not len(reward_keys):
            return reward_spec
        else:
            if len(self.reward_keys) == 1 and self.reward_keys[0] != "reward":
                return reward_spec
            return reward_spec[self.reward_keys[0]]

    @reward_spec.setter
    @_maybe_unlock
    def reward_spec(self, value: TensorSpec) -> None:
        device = self.output_spec._device
        if not hasattr(value, "shape"):
            raise TypeError(
                f"reward_spec of type {type(value)} does not have a shape attribute."
            )
        if value.shape[: len(self.batch_size)] != self.batch_size:
            raise ValueError(
                f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size}). "
                "Please use `env.reward_spec_unbatched = value` to set unbatched versions instead."
            )
        if not isinstance(value, Composite):
            value = Composite(
                reward=value.to(device), shape=self.batch_size, device=device
            )
        for leaf in value.values(True, True):
            if len(leaf.shape) == 0:
                raise RuntimeError(
                    "the reward_spec's leaves shape cannot be empty (this error"
                    " usually comes from trying to set a reward_spec"
                    " with a null number of dimensions. Try using a multidimensional"
                    " spec instead, for instance with a singleton dimension at the tail)."
                )
        self.output_spec["full_reward_spec"] = value.to(device)

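    # Sketch of the trailing-singleton rule enforced above: scalar reward specs
    # are rejected, so rewards are declared with an explicit last dimension:
    #
    #     >>> env.reward_spec = Unbounded(shape=())    # raises RuntimeError
    #     >>> env.reward_spec = Unbounded(shape=(1,))  # ok
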
    @property
    def full_reward_spec(self) -> Composite:
        """The full reward spec.

        ``full_reward_spec`` is a :class:`~torchrl.data.Composite` instance
        that contains all the reward entries.

        Examples:
            >>> import gymnasium
            >>> from torchrl.envs import GymWrapper, TransformedEnv, RenameTransform
            >>> base_env = GymWrapper(gymnasium.make("Pendulum-v1"))
            >>> env = TransformedEnv(base_env, RenameTransform("reward", ("nested", "reward")))
            >>> env.full_reward_spec
            Composite(
                nested: Composite(
                    reward: UnboundedContinuous(
                        shape=torch.Size([1]),
                        space=ContinuousBox(
                            low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
                            high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
                        device=cpu,
                        dtype=torch.float32,
                        domain=continuous), device=None, shape=torch.Size([])), device=cpu, shape=torch.Size([]))

        """
        try:
            return self.output_spec["full_reward_spec"]
        except KeyError:
            # populate the "reward" entry
            # this will be raised if there is no full_reward_spec (unlikely) or no reward_key
            # Since output_spec is lazily populated with an empty composite spec for
            # reward_spec, the second case is much more likely to occur.
            self.reward_spec = Unbounded(
                shape=(*self.batch_size, 1),
                device=self.device,
            )
            return self.output_spec["full_reward_spec"]

    @full_reward_spec.setter
    @_maybe_unlock
    def full_reward_spec(self, spec: Composite) -> None:
        self.reward_spec = spec.to(self.device) if self.device is not None else spec

    # done spec
    @property
    @_cache_value
    def done_keys(self) -> list[NestedKey]:
        """The done keys of an environment.

        By default, there will only be one key named "done".

        Keys are sorted by depth in the data tree.
        """
        done_keys = sorted(self.full_done_spec.keys(True, True), key=_repr_by_depth)
        return done_keys

    @property
    def done_key(self):
        """The done key of an environment.

        By default, this will be "done".

        If there is more than one done key in the environment, this function will raise an exception.
        """
        done_keys = self.done_keys
        if len(done_keys) > 1:
            raise KeyError(
                "done_key requested but more than one key present in the environment"
            )
        return done_keys[0]

    @property
    def full_done_spec(self) -> Composite:
        """The full done spec.

        ``full_done_spec`` is a :class:`~torchrl.data.Composite` instance
        that contains all the done entries.
        It can be used to generate fake data with a structure that mimics the
        one obtained at runtime.

        Examples:
            >>> import gymnasium
            >>> from torchrl.envs import GymWrapper
            >>> env = GymWrapper(gymnasium.make("Pendulum-v1"))
            >>> env.full_done_spec
            Composite(
                done: Categorical(
                    shape=torch.Size([1]),
                    space=DiscreteBox(n=2),
                    device=cpu,
                    dtype=torch.bool,
                    domain=discrete),
                truncated: Categorical(
                    shape=torch.Size([1]),
                    space=DiscreteBox(n=2),
                    device=cpu,
                    dtype=torch.bool,
                    domain=discrete), device=cpu, shape=torch.Size([]))

        """
        return self.output_spec["full_done_spec"]

    @full_done_spec.setter
    @_maybe_unlock
    def full_done_spec(self, spec: Composite) -> None:
        self.done_spec = spec.to(self.device) if self.device is not None else spec

    # Done spec: done specs belong to output_spec
    @property
    def done_spec(self) -> TensorSpec:
        """The ``done`` spec.

        The ``done_spec`` is always stored as a composite spec.

        If the done spec is provided as a simple spec, this will be returned.

            >>> env.done_spec = Categorical(2, dtype=torch.bool)
            >>> env.done_spec
            Categorical(
                shape=torch.Size([]),
                space=DiscreteBox(n=2),
                device=cpu,
                dtype=torch.bool,
                domain=discrete)

        If the done spec is provided as a composite spec and contains only one leaf,
        this function will return just the leaf.

            >>> env.done_spec = Composite({"nested": {"done": Categorical(2, dtype=torch.bool)}})
            >>> env.done_spec
            Categorical(
                shape=torch.Size([]),
                space=DiscreteBox(n=2),
                device=cpu,
                dtype=torch.bool,
                domain=discrete)

        If the done spec is provided as a composite spec and has more than one leaf,
        this function will return the whole spec.

            >>> env.done_spec = Composite({"nested": {"done": Categorical(2, dtype=torch.bool), "another_done": Categorical(2, dtype=torch.bool)}})
            >>> env.done_spec
            Composite(
                nested: Composite(
                    done: Categorical(
                        shape=torch.Size([]),
                        space=DiscreteBox(n=2),
                        device=cpu,
                        dtype=torch.bool,
                        domain=discrete),
                    another_done: Categorical(
                        shape=torch.Size([]),
                        space=DiscreteBox(n=2),
                        device=cpu,
                        dtype=torch.bool,
                        domain=discrete), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))

        To always retrieve the full spec passed, use:

            >>> env.output_spec["full_done_spec"]

        This property is mutable.

        Examples:
            >>> from torchrl.envs.libs.gym import GymEnv
            >>> env = GymEnv("Pendulum-v1")
            >>> env.done_spec
            Categorical(
                shape=torch.Size([1]),
                space=DiscreteBox(n=2),
                device=cpu,
                dtype=torch.bool,
                domain=discrete)
        """
        done_spec = self.output_spec["full_done_spec"]
        return done_spec

    @_maybe_unlock
    def _create_done_specs(self):
        """Reads through the done specs and completes them if necessary.

        If the done specs contain only a ``"done"`` entry, a similar ``"terminated"`` entry is created.
        The same goes if only a ``"terminated"`` key is present.

        If neither ``"done"`` nor ``"terminated"`` can be found and the spec is not
        empty, nothing is changed.

        """
        try:
            full_done_spec = self.output_spec["full_done_spec"]
        except KeyError:
            full_done_spec = Composite(
                shape=self.output_spec.shape, device=self.output_spec.device
            )
            full_done_spec["done"] = Categorical(
                n=2,
                shape=(*full_done_spec.shape, 1),
                dtype=torch.bool,
                device=self.device,
            )
            full_done_spec["terminated"] = Categorical(
                n=2,
                shape=(*full_done_spec.shape, 1),
                dtype=torch.bool,
                device=self.device,
            )
            self.output_spec["full_done_spec"] = full_done_spec
            return

        def check_local_done(spec):
            shape = None
            for key, item in list(
                spec.items()
            ):  # list to avoid error due to in-loop changes
                # in the case where the spec is non-empty and there is no done and no terminated, we do nothing
                if key == "done" and "terminated" not in spec.keys():
                    spec["terminated"] = item.clone()
                elif key == "terminated" and "done" not in spec.keys():
                    spec["done"] = item.clone()
                elif isinstance(item, Composite):
                    check_local_done(item)
                else:
                    if shape is None:
                        shape = item.shape
                        continue
                    # check that all shapes match
                    if shape != item.shape:
                        raise ValueError(
                            f"All shapes should match in done_spec {spec} (shape={shape}, key={key})."
                        )

            # if the spec is empty, we need to add a done and terminated manually
            if spec.is_empty():
                spec["done"] = Categorical(
                    n=2, shape=(*spec.shape, 1), dtype=torch.bool, device=self.device
                )
                spec["terminated"] = Categorical(
                    n=2, shape=(*spec.shape, 1), dtype=torch.bool, device=self.device
                )

        if_locked = self.is_spec_locked
        if if_locked:
            self.is_spec_locked = False
        check_local_done(full_done_spec)
        self.output_spec["full_done_spec"] = full_done_spec
        if if_locked:
            self.is_spec_locked = True
        return

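    # Sketch of the completion rule implemented above: assigning only a "done"
    # entry also creates a matching "terminated" entry (and vice versa):
    #
    #     >>> env.done_spec = Categorical(2, shape=(1,), dtype=torch.bool)
    #     >>> sorted(env.full_done_spec.keys())
    #     ['done', 'terminated']
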
    @done_spec.setter
    @_maybe_unlock
    def done_spec(self, value: TensorSpec) -> None:
        device = self.output_spec.device
        if not hasattr(value, "shape"):
            raise TypeError(
                f"done_spec of type {type(value)} does not have a shape attribute."
            )
        if value.shape[: len(self.batch_size)] != self.batch_size:
            raise ValueError(
                f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
            )
        if not isinstance(value, Composite):
            value = Composite(
                done=value.to(device),
                terminated=value.to(device),
                shape=self.batch_size,
                device=device,
            )
        for leaf in value.values(True, True):
            if len(leaf.shape) == 0:
                raise RuntimeError(
                    "the done_spec's leaves shape cannot be empty (this error"
                    " usually comes from trying to set a done_spec"
                    " with a null number of dimensions. Try using a multidimensional"
                    " spec instead, for instance with a singleton dimension at the tail)."
                )
        self.output_spec["full_done_spec"] = value.to(device)
        self._create_done_specs()

    # observation spec: observation specs belong to output_spec
    @property
    def observation_spec(self) -> Composite:
        """Observation spec.

        Must be a :class:`torchrl.data.Composite` instance.
        The keys listed in the spec are directly accessible after reset and step.

        In TorchRL, even though they are not, properly speaking, "observations",
        all info, states, results of transforms, etc., output by the environment are stored in the
        ``observation_spec``.

        Therefore, ``"observation_spec"`` should be thought of as
        a generic data container for environment outputs that are not done or reward data.

        Examples:
            >>> from torchrl.envs.libs.gym import GymEnv
            >>> env = GymEnv("Pendulum-v1")
            >>> env.observation_spec
            Composite(
                observation: BoundedContinuous(
                    shape=torch.Size([3]),
                    space=ContinuousBox(
                        low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True),
                        high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)),
                    device=cpu,
                    dtype=torch.float32,
                    domain=continuous), device=cpu, shape=torch.Size([]))

        """
        observation_spec = self.output_spec.get("full_observation_spec", default=None)
        if observation_spec is None:
            is_locked = self.is_spec_locked
            if is_locked:
                self.set_spec_lock_(False)
            observation_spec = Composite(shape=self.batch_size, device=self.device)
            self.output_spec["full_observation_spec"] = observation_spec
            if is_locked:
                self.set_spec_lock_(True)

        return observation_spec

    @observation_spec.setter
    @_maybe_unlock
    def observation_spec(self, value: TensorSpec) -> None:
        if not isinstance(value, Composite):
            value = Composite(
                observation=value,
                device=self.device,
                batch_size=self.output_spec.batch_size,
            )
        elif value.shape[: len(self.batch_size)] != self.batch_size:
            raise ValueError(
                f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
            )
        if value.shape[: len(self.batch_size)] != self.batch_size:
            raise ValueError(
                f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
            )
        device = self.output_spec._device
        self.output_spec["full_observation_spec"] = (
            value.to(device) if device is not None else value
        )

    @property
    def full_observation_spec(self) -> Composite:
        return self.observation_spec

    @full_observation_spec.setter
    @_maybe_unlock
    def full_observation_spec(self, spec: Composite):
        self.observation_spec = spec

    # state spec: state specs belong to input_spec
    @property
    def state_spec(self) -> Composite:
        """State spec.

        Must be a :class:`torchrl.data.Composite` instance.
        The keys listed here should be provided as input alongside actions to the environment.

        In TorchRL, even though they are not, properly speaking, "states",
        all inputs to the environment that are not actions are stored in the
        ``state_spec``.

        Therefore, ``"state_spec"`` should be thought of as
        a generic data container for environment inputs that are not action data.

        Examples:
            >>> from torchrl.envs import BraxEnv
            >>> for envname in BraxEnv.available_envs:
            ...     break
            >>> env = BraxEnv(envname)
            >>> env.state_spec
            Composite(
                state: Composite(
                    pipeline_state: Composite(
                        q: UnboundedContinuous(
                            shape=torch.Size([15]),
                            space=None,
                            device=cpu,
                            dtype=torch.float32,
                            domain=continuous),
                    [...], device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))

        """
        state_spec = self.input_spec["full_state_spec"]
        if state_spec is None:
            is_locked = self.is_spec_locked
            if is_locked:
                self.set_spec_lock_(False)
            state_spec = Composite(shape=self.batch_size, device=self.device)
            self.input_spec["full_state_spec"] = state_spec
            if is_locked:
                self.set_spec_lock_(True)
        return state_spec

    @state_spec.setter
    @_maybe_unlock
    def state_spec(self, value: Composite) -> None:
        if value is None:
            self.input_spec["full_state_spec"] = Composite(
                device=self.device, shape=self.batch_size
            )
        else:
            device = self.input_spec.device
            if not isinstance(value, Composite):
                raise TypeError("The type of a state_spec must be Composite.")
            elif value.shape[: len(self.batch_size)] != self.batch_size:
                raise ValueError(
                    f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
                )
            self.input_spec["full_state_spec"] = (
                value.to(device) if device is not None else value
            )

    @property
    def full_state_spec(self) -> Composite:
        """The full state spec.

        ``full_state_spec`` is a :class:`~torchrl.data.Composite` instance
        that contains all the state entries (i.e., the input data that is not an action).

        Examples:
            >>> from torchrl.envs import BraxEnv
            >>> for envname in BraxEnv.available_envs:
            ...     break
            >>> env = BraxEnv(envname)
            >>> env.full_state_spec
            Composite(
                state: Composite(
                    pipeline_state: Composite(
                        q: UnboundedContinuous(
                            shape=torch.Size([15]),
                            space=None,
                            device=cpu,
                            dtype=torch.float32,
                            domain=continuous),
                    [...], device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))

        """
        return self.state_spec

    @full_state_spec.setter
    @_maybe_unlock
    def full_state_spec(self, spec: Composite) -> None:
        self.state_spec = spec

    # Single-env specs can be used to remove the batch size from the spec
    @property
    def batch_dims(self) -> int:
        """Number of batch dimensions of the env."""
        return len(self.batch_size)

    def _make_single_env_spec(self, spec: TensorSpec) -> TensorSpec:
        if not self.batch_dims:
            return spec
        idx = tuple(0 for _ in range(self.batch_dims))
        return spec[idx]

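    # Sketch of the batch-stripping trick used by ``_make_single_env_spec``:
    # indexing a batched spec with one zero per batch dimension yields its
    # single-env counterpart (shapes illustrative, for a batch of 4 envs):
    #
    #     >>> env.batch_size
    #     torch.Size([4])
    #     >>> env.action_spec.shape
    #     torch.Size([4, 6])
    #     >>> env.action_spec_unbatched.shape
    #     torch.Size([6])
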
    @property
    def full_action_spec_unbatched(self) -> Composite:
        """Returns the action spec of the env as if it had no batch dimensions."""
        return self._make_single_env_spec(self.full_action_spec)

    @full_action_spec_unbatched.setter
    @_maybe_unlock
    def full_action_spec_unbatched(self, spec: Composite):
        spec = spec.expand(self.batch_size + spec.shape)
        self.full_action_spec = spec

    @property
    def action_spec_unbatched(self) -> TensorSpec:
        """Returns the action spec of the env as if it had no batch dimensions."""
        return self._make_single_env_spec(self.action_spec)

    @action_spec_unbatched.setter
    @_maybe_unlock
    def action_spec_unbatched(self, spec: Composite):
        spec = spec.expand(self.batch_size + spec.shape)
        self.action_spec = spec

    @property
    def full_observation_spec_unbatched(self) -> Composite:
        """Returns the observation spec of the env as if it had no batch dimensions."""
        return self._make_single_env_spec(self.full_observation_spec)

    @full_observation_spec_unbatched.setter
    @_maybe_unlock
    def full_observation_spec_unbatched(self, spec: Composite):
        spec = spec.expand(self.batch_size + spec.shape)
        self.full_observation_spec = spec

    @property
    def observation_spec_unbatched(self) -> Composite:
        """Returns the observation spec of the env as if it had no batch dimensions."""
        return self._make_single_env_spec(self.observation_spec)

    @observation_spec_unbatched.setter
    @_maybe_unlock
    def observation_spec_unbatched(self, spec: Composite):
        spec = spec.expand(self.batch_size + spec.shape)
        self.observation_spec = spec

    @property
    def full_reward_spec_unbatched(self) -> Composite:
        """Returns the reward spec of the env as if it had no batch dimensions."""
        return self._make_single_env_spec(self.full_reward_spec)

    @full_reward_spec_unbatched.setter
    @_maybe_unlock
    def full_reward_spec_unbatched(self, spec: Composite):
        spec = spec.expand(self.batch_size + spec.shape)
        self.full_reward_spec = spec

    @property
    def reward_spec_unbatched(self) -> TensorSpec:
        """Returns the reward spec of the env as if it had no batch dimensions."""
        return self._make_single_env_spec(self.reward_spec)

    @reward_spec_unbatched.setter
    @_maybe_unlock
    def reward_spec_unbatched(self, spec: Composite):
        spec = spec.expand(self.batch_size + spec.shape)
        self.reward_spec = spec

    @property
    def full_done_spec_unbatched(self) -> Composite:
        """Returns the done spec of the env as if it had no batch dimensions."""
        return self._make_single_env_spec(self.full_done_spec)

    @full_done_spec_unbatched.setter
    @_maybe_unlock
    def full_done_spec_unbatched(self, spec: Composite):
        spec = spec.expand(self.batch_size + spec.shape)
        self.full_done_spec = spec

    @property
    def done_spec_unbatched(self) -> TensorSpec:
        """Returns the done spec of the env as if it had no batch dimensions."""
        return self._make_single_env_spec(self.done_spec)

    @done_spec_unbatched.setter
    @_maybe_unlock
    def done_spec_unbatched(self, spec: Composite):
        spec = spec.expand(self.batch_size + spec.shape)
        self.done_spec = spec

    @property
    def output_spec_unbatched(self) -> Composite:
        """Returns the output spec of the env as if it had no batch dimensions."""
        return self._make_single_env_spec(self.output_spec)

    @output_spec_unbatched.setter
    @_maybe_unlock
    def output_spec_unbatched(self, spec: Composite):
        spec = spec.expand(self.batch_size + spec.shape)
        self.output_spec = spec

    @property
    def input_spec_unbatched(self) -> Composite:
        """Returns the input spec of the env as if it had no batch dimensions."""
        return self._make_single_env_spec(self.input_spec)

    @input_spec_unbatched.setter
    @_maybe_unlock
    def input_spec_unbatched(self, spec: Composite):
        spec = spec.expand(self.batch_size + spec.shape)
        self.input_spec = spec

    @property
    def full_state_spec_unbatched(self) -> Composite:
        """Returns the state spec of the env as if it had no batch dimensions."""
        return self._make_single_env_spec(self.full_state_spec)

    @full_state_spec_unbatched.setter
    @_maybe_unlock
    def full_state_spec_unbatched(self, spec: Composite):
        spec = spec.expand(self.batch_size + spec.shape)
        self.full_state_spec = spec

    @property
    def state_spec_unbatched(self) -> TensorSpec:
        """Returns the state spec of the env as if it had no batch dimensions."""
        return self._make_single_env_spec(self.state_spec)

    @state_spec_unbatched.setter
    @_maybe_unlock
    def state_spec_unbatched(self, spec: Composite):
        spec = spec.expand(self.batch_size + spec.shape)
        self.state_spec = spec

    def _skip_tensordict(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Creates a "skip" tensordict, i.e. a placeholder for when a step is skipped
        next_tensordict = self.full_done_spec.zero()
        next_tensordict.update(self.full_observation_spec.zero())
        next_tensordict.update(self.full_reward_spec.zero())

        # Copy the data from tensordict in `next`
        keys = set()

        def select_and_clone(name, x, y):
            keys.add(name)
            if y is not None:
                if y.device == x.device:
                    return x.clone()
                return x.to(y.device)

        result = tensordict._fast_apply(
            select_and_clone,
            next_tensordict,
            device=self.device,
            default=None,
            filter_empty=True,
            is_leaf=_is_leaf_nontensor,
            named=True,
            nested_keys=True,
        )
        result.update(next_tensordict.exclude(*keys).filter_empty_())
        return result

    def step(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Makes a step in the environment.

        Step accepts a single argument, tensordict, which usually carries an 'action' key indicating the action
        to be taken.
        Step will call an out-of-place private method, _step, which is the method to be re-written by EnvBase subclasses.

        Args:
            tensordict (TensorDictBase): Tensordict containing the action to be taken.
                If the input tensordict contains a ``"next"`` entry, the values contained in it
                will prevail over the newly computed values. This gives a mechanism
                to override the underlying computations.

        Returns:
            the input tensordict, modified in place with the resulting observations, done state and reward
            (+ others if needed).

        """
        # sanity check
        self._assert_tensordict_shape(tensordict)
        partial_steps = tensordict.pop("_step", None)

        next_tensordict = None

        if partial_steps is not None:
            tensordict_batch_size = None
            if not self.batch_locked:
                # Batched envs have their own way of dealing with this - batched envs that are not batch-locked may fail here
                if partial_steps.all():
                    partial_steps = None
                else:
                    tensordict_batch_size = tensordict.batch_size
                    partial_steps = partial_steps.view(tensordict_batch_size)
                    tensordict = tensordict[partial_steps]
            else:
                if not partial_steps.any():
                    next_tensordict = self._skip_tensordict(tensordict)
                else:
                    # trust that the _step can handle this!
                    tensordict.set("_step", partial_steps)
            if tensordict_batch_size is None:
                tensordict_batch_size = self.batch_size

        next_preset = tensordict.get("next", None)

        if next_tensordict is None:
            next_tensordict = self._step(tensordict)
            next_tensordict = self._step_proc_data(next_tensordict)
        if next_preset is not None:
            # tensordict could already have a "next" key
            # this could be done more efficiently by not excluding but just passing
            # the necessary keys
            next_tensordict.update(
                next_preset.exclude(*next_tensordict.keys(True, True))
            )
        tensordict.set("next", next_tensordict)
        if partial_steps is not None and tensordict_batch_size != self.batch_size:
            result = tensordict.new_zeros(tensordict_batch_size)

            if tensordict_batch_size == tensordict.batch_size:

                def select_and_clone(x, y):
                    if y is not None:
                        if x.device == y.device:
                            return x.clone()
                        return x.to(y.device)

                result.update(
                    tensordict._fast_apply(
                        select_and_clone,
                        result,
                        device=result.device,
                        filter_empty=True,
                        default=None,
                        batch_size=result.batch_size,
                        is_leaf=_is_leaf_nontensor,
                    )
                )
            if partial_steps.any():
                result[partial_steps] = tensordict
            return result
        return tensordict

2280
|
+
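
    # A minimal end-to-end sketch for step() (assumes a Gym backend with "Pendulum-v1"):
    #
    #   >>> from torchrl.envs import GymEnv
    #   >>> env = GymEnv("Pendulum-v1")
    #   >>> td = env.reset()
    #   >>> td = env.rand_action(td)        # writes an "action" entry in place
    #   >>> td = env.step(td)               # fills td["next"] with obs/reward/done
    #   >>> td["next", "reward"]            # reward produced by this transition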
    @classmethod
    def _complete_done(
        cls, done_spec: Composite, data: TensorDictBase
    ) -> TensorDictBase:
        """Completes the data structure at step time to put missing done keys."""
        # by default, if a done key is missing, it is assumed that it is False
        # except in 2 cases: (1) there is a "done" but no "terminated" or (2)
        # there is a "terminated" but no "done".
        if done_spec.ndim:
            leading_dim = data.shape[: -done_spec.ndim]
        else:
            leading_dim = data.shape
        vals = {}
        i = -1
        for i, (key, item) in enumerate(done_spec.items()):  # noqa: B007
            val = data.get(key, None)
            if isinstance(item, Composite):
                if val is not None:
                    cls._complete_done(item, val)
                continue
            shape = (*leading_dim, *item.shape)
            if val is not None:
                if val.shape != shape:
                    val = val.reshape(shape)
                    data.set(key, val)
                vals[key] = val

        if len(vals) < i + 1:
            # complete missing dones: we only want to do that if we don't have enough done values
            data_keys = set(data.keys())
            done_spec_keys = set(done_spec.keys())
            for key, item in done_spec.items(False, True):
                val = vals.get(key, None)
                if (
                    key == "done"
                    and val is not None
                    and "terminated" in done_spec_keys
                    and "terminated" not in data_keys
                ):
                    if "truncated" in data_keys:
                        raise RuntimeError(
                            "Cannot infer the value of terminated when only done and truncated are present."
                        )
                    data.set("terminated", val)
                    data_keys.add("terminated")
                elif (
                    key == "terminated"
                    and val is not None
                    and "done" in done_spec_keys
                    and "done" not in data_keys
                ):
                    if "truncated" in data_keys:
                        val = val | data.get("truncated")
                    data.set("done", val)
                    data_keys.add("done")
                elif val is None and key not in data_keys:
                    # we must keep this here: we only want to fill with 0s if we're sure
                    # done should not be copied to terminated or terminated to done
                    # in this case, just fill with 0s
                    data.set(key, item.zero(leading_dim))
        return data
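
    # Behaviour sketch for _complete_done (illustrative only; the tensordict
    # contents and shapes below are assumptions): a lone "done" is mirrored into
    # "terminated" when the spec expects one, and entries absent from the data
    # are zero-filled.
    #
    #   >>> from tensordict import TensorDict
    #   >>> data = TensorDict({"done": torch.zeros(1, dtype=torch.bool)}, [])
    #   >>> EnvBase._complete_done(env.full_done_spec, data)
    #   >>> sorted(data.keys())  # e.g. ["done", "terminated", ...] depending on the spec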
    def _step_proc_data(self, next_tensordict_out):
        batch_size = self.batch_size
        dims = len(batch_size)
        leading_batch_size = (
            next_tensordict_out.batch_size[:-dims]
            if dims
            else next_tensordict_out.shape
        )
        for reward_key in self.reward_keys:
            expected_reward_shape = torch.Size(
                [
                    *leading_batch_size,
                    *self.output_spec["full_reward_spec"][reward_key].shape,
                ]
            )
            # If the reward has a variable shape, we don't want to perform this check
            if all(s > 0 for s in expected_reward_shape):
                reward = next_tensordict_out.get(reward_key)
                actual_reward_shape = reward.shape
                if actual_reward_shape != expected_reward_shape:
                    reward = reward.view(expected_reward_shape)
                    next_tensordict_out.set(reward_key, reward)

        self._complete_done(self.full_done_spec, next_tensordict_out)

        if self.run_type_checks:
            for key, spec in self.observation_spec.items():
                obs = next_tensordict_out.get(key)
                spec.type_check(obs)

            for reward_key in self.reward_keys:
                if (
                    next_tensordict_out.get(reward_key).dtype
                    is not self.output_spec[
                        unravel_key(("full_reward_spec", reward_key))
                    ].dtype
                ):
                    raise TypeError(
                        f"expected reward.dtype to be {self.output_spec[unravel_key(('full_reward_spec', reward_key))]} "
                        f"but got {next_tensordict_out.get(reward_key).dtype}"
                    )

            for done_key in self.done_keys:
                if (
                    next_tensordict_out.get(done_key).dtype
                    is not self.output_spec["full_done_spec", done_key].dtype
                ):
                    raise TypeError(
                        f"expected done.dtype to be {self.output_spec['full_done_spec', done_key].dtype} but got {next_tensordict_out.get(done_key).dtype}"
                    )
        return next_tensordict_out
    def _get_in_keys_to_exclude(self, tensordict):
        if self._cache_in_keys is None:
            self._cache_in_keys = list(
                set(self.input_spec.keys(True)).intersection(
                    tensordict.keys(True, True)
                )
            )
        return self._cache_in_keys
    @classmethod
    def register_gym(
        cls,
        id: str,
        *,
        entry_point: Callable | None = None,
        transform: Transform | None = None,  # noqa: F821
        info_keys: list[NestedKey] | None = None,
        backend: str | None = None,
        to_numpy: bool = False,
        reward_threshold: float | None = None,
        nondeterministic: bool = False,
        max_episode_steps: int | None = None,
        order_enforce: bool = True,
        autoreset: bool | None = None,
        disable_env_checker: bool = False,
        apply_api_compatibility: bool = False,
        **kwargs,
    ):
        """Registers an environment in gym(nasium).

        This method is designed with the following scopes in mind:

        - Incorporate a TorchRL-first environment in a framework that uses Gym;
        - Incorporate another environment (eg, DeepMind Control, Brax, Jumanji, ...)
          in a framework that uses Gym.

        Args:
            id (str): the name of the environment. Should follow the
                `gym naming convention <https://www.gymlibrary.dev/content/environment_creation/#registering-envs>`_.

        Keyword Args:
            entry_point (callable, optional): the entry point to build the environment.
                If none is passed, the parent class will be used as entry point.
                Typically, this is used to register an environment that does not
                necessarily inherit from the base being used:

                    >>> from torchrl.envs import DMControlEnv
                    >>> DMControlEnv.register_gym("DMC-cheetah-v0", env_name="cheetah", task="run")
                    >>> # equivalently
                    >>> EnvBase.register_gym("DMC-cheetah-v0", entry_point=DMControlEnv, env_name="cheetah", task="run")

            transform (torchrl.envs.Transform): a transform (or list of transforms
                within a :class:`torchrl.envs.Compose` instance) to be used with the env.
                This arg can be passed during a call to :func:`~gym.make` (see
                example below).
            info_keys (List[NestedKey], optional): if provided, these keys will
                be used to build the info dictionary and will be excluded from
                the observation keys.
                This arg can be passed during a call to :func:`~gym.make` (see
                example below).

                .. warning::
                  It may be the case that using ``info_keys`` makes a spec empty
                  because the content has been moved to the info dictionary.
                  Gym does not like empty ``Dict`` in the specs, so this empty
                  content should be removed with :class:`~torchrl.envs.transforms.RemoveEmptySpecs`.

            backend (str, optional): the backend. Can be either `"gym"` or `"gymnasium"`
                or any other backend compatible with :class:`~torchrl.envs.libs.gym.set_gym_backend`.
            to_numpy (bool, optional): if ``True``, the result of calls to `step` and
                `reset` will be mapped to numpy arrays. Defaults to ``False``
                (results are tensors).
                This arg can be passed during a call to :func:`~gym.make` (see
                example below).
            reward_threshold (:obj:`float`, optional): [Gym kwarg] The reward threshold
                considered to have learnt an environment.
            nondeterministic (bool, optional): [Gym kwarg] If the environment is nondeterministic
                (even with knowledge of the initial seed and all actions). Defaults to
                ``False``.
            max_episode_steps (int, optional): [Gym kwarg] The maximum number
                of episode steps before truncation. Used by the Time Limit wrapper.
            order_enforce (bool, optional): [Gym >= 0.14] Whether the order
                enforcer wrapper should be applied to ensure users run functions
                in the correct order.
                Defaults to ``True``.
            autoreset (bool, optional): [Gym >= 0.14 and <1.0.0] Whether the autoreset wrapper
                should be added such that reset does not need to be called.
                Defaults to ``False``.
            disable_env_checker: [Gym >= 0.14] Whether the environment
                checker should be disabled for the environment. Defaults to ``False``.
            apply_api_compatibility: [Gym >= 0.26 and <1.0.0] Whether to apply the `StepAPICompatibility` wrapper.
                Defaults to ``False``.
            **kwargs: arbitrary keyword arguments which are passed to the environment constructor.

        .. note::
            TorchRL's environments do not have the concept of an ``"info"`` dictionary,
            as ``TensorDict`` offers all the storage requirements deemed necessary
            in most training settings. Still, you can use the ``info_keys`` argument to
            have fine-grained control over what is considered an observation and what
            should be seen as info.

        Examples:
            >>> # Register the "cheetah" env from DMControl with the "run" task
            >>> from torchrl.envs import DMControlEnv
            >>> import torch
            >>> DMControlEnv.register_gym("DMC-cheetah-v0", to_numpy=False, backend="gym", env_name="cheetah", task_name="run")
            >>> import gym
            >>> envgym = gym.make("DMC-cheetah-v0")
            >>> envgym.seed(0)
            >>> torch.manual_seed(0)
            >>> envgym.reset()
            ({'position': tensor([-0.0855,  0.0215, -0.0881, -0.0412, -0.1101,  0.0080,  0.0254,  0.0424],
                dtype=torch.float64), 'velocity': tensor([ 1.9609e-02, -1.9776e-04, -1.6347e-03,  3.3842e-02,  2.5338e-02,
                 3.3064e-02,  1.0381e-04,  7.6656e-05,  1.0204e-02],
                dtype=torch.float64)}, {})
            >>> envgym.step(envgym.action_space.sample())
            ({'position': tensor([-0.0833,  0.0275, -0.0612, -0.0770, -0.1256,  0.0082,  0.0186,  0.0476],
                dtype=torch.float64), 'velocity': tensor([ 0.2221,  0.2256,  0.5930,  2.6937, -3.5865, -1.5479,  0.0187, -0.6825,
                 0.5224], dtype=torch.float64)}, tensor([0.0018], dtype=torch.float64), tensor([False]), tensor([False]), {})
            >>> # same environment with observation stacked
            >>> from torchrl.envs import CatTensors
            >>> envgym = gym.make("DMC-cheetah-v0", transform=CatTensors(in_keys=["position", "velocity"], out_key="observation"))
            >>> envgym.reset()
            ({'observation': tensor([-0.1005,  0.0335, -0.0268,  0.0133, -0.0627,  0.0074, -0.0488, -0.0353,
                -0.0075, -0.0069,  0.0098, -0.0058,  0.0033, -0.0157, -0.0004, -0.0381,
                -0.0452], dtype=torch.float64)}, {})
            >>> # same environment with numpy observations
            >>> envgym = gym.make("DMC-cheetah-v0", transform=CatTensors(in_keys=["position", "velocity"], out_key="observation"), to_numpy=True)
            >>> envgym.reset()
            ({'observation': array([-0.11355747,  0.04257728,  0.00408397,  0.04155852, -0.0389733 ,
                -0.01409826, -0.0978704 , -0.08808327,  0.03970837,  0.00535434,
                -0.02353762,  0.05116226,  0.02788907,  0.06848346,  0.05154399,
                 0.0371798 ,  0.05128025])}, {})
            >>> # If gymnasium is installed, we can register the environment there too.
            >>> DMControlEnv.register_gym("DMC-cheetah-v0", to_numpy=False, backend="gymnasium", env_name="cheetah", task_name="run")
            >>> import gymnasium
            >>> envgym = gymnasium.make("DMC-cheetah-v0")
            >>> envgym.seed(0)
            >>> torch.manual_seed(0)
            >>> envgym.reset()
            ({'position': tensor([-0.0855,  0.0215, -0.0881, -0.0412, -0.1101,  0.0080,  0.0254,  0.0424],
                dtype=torch.float64), 'velocity': tensor([ 1.9609e-02, -1.9776e-04, -1.6347e-03,  3.3842e-02,  2.5338e-02,
                 3.3064e-02,  1.0381e-04,  7.6656e-05,  1.0204e-02],
                dtype=torch.float64)}, {})

        .. note::
            This feature also works for stateless environments (eg, :class:`~torchrl.envs.BraxEnv`).

                >>> import gymnasium
                >>> import torch
                >>> from tensordict import TensorDict
                >>> from torchrl.envs import BraxEnv, SelectTransform
                >>>
                >>> # get action for didactic purposes
                >>> env = BraxEnv("ant", batch_size=[2])
                >>> env.set_seed(0)
                >>> torch.manual_seed(0)
                >>> td = env.rollout(10)
                >>>
                >>> actions = td.get("action")
                >>>
                >>> # register env
                >>> env.register_gym("Brax-Ant-v0", env_name="ant", batch_size=[2], info_keys=["state"])
                >>> gym_env = gymnasium.make("Brax-Ant-v0")
                >>> gym_env.seed(0)
                >>> torch.manual_seed(0)
                >>>
                >>> gym_env.reset()
                >>> obs = []
                >>> for i in range(10):
                ...     obs, reward, terminated, truncated, info = gym_env.step(td[..., i].get("action"))

        """
        from torchrl.envs.libs.gym import gym_backend, set_gym_backend

        if backend is None:
            backend = gym_backend()

        with set_gym_backend(backend):
            return cls._register_gym(
                id=id,
                entry_point=entry_point,
                transform=transform,
                info_keys=info_keys,
                to_numpy=to_numpy,
                reward_threshold=reward_threshold,
                nondeterministic=nondeterministic,
                max_episode_steps=max_episode_steps,
                order_enforce=order_enforce,
                autoreset=autoreset,
                disable_env_checker=disable_env_checker,
                apply_api_compatibility=apply_api_compatibility,
                **kwargs,
            )
    _GYM_UNRECOGNIZED_KWARG = (
        "The keyword argument {} is not compatible with gym version {}"
    )

    @implement_for("gym", "0.26", None, class_method=True)
    def _register_gym(
        cls,
        id,
        entry_point: Callable | None = None,
        transform: Transform | None = None,  # noqa: F821
        info_keys: list[NestedKey] | None = None,
        to_numpy: bool = False,
        reward_threshold: float | None = None,
        nondeterministic: bool = False,
        max_episode_steps: int | None = None,
        order_enforce: bool = True,
        autoreset: bool | None = None,
        disable_env_checker: bool = False,
        apply_api_compatibility: bool = False,
        **kwargs,
    ):
        import gym
        from torchrl.envs.libs._gym_utils import _TorchRLGymWrapper

        if entry_point is None:
            entry_point = cls
        entry_point = partial(
            _TorchRLGymWrapper,
            entry_point=entry_point,
            info_keys=info_keys,
            to_numpy=to_numpy,
            transform=transform,
            **kwargs,
        )
        return gym.register(
            id=id,
            entry_point=entry_point,
            reward_threshold=reward_threshold,
            nondeterministic=nondeterministic,
            max_episode_steps=max_episode_steps,
            order_enforce=order_enforce,
            autoreset=bool(autoreset),
            disable_env_checker=disable_env_checker,
            apply_api_compatibility=apply_api_compatibility,
        )

    @implement_for("gym", "0.25", "0.26", class_method=True)
    def _register_gym(  # noqa: F811
        cls,
        id,
        entry_point: Callable | None = None,
        transform: Transform | None = None,  # noqa: F821
        info_keys: list[NestedKey] | None = None,
        to_numpy: bool = False,
        reward_threshold: float | None = None,
        nondeterministic: bool = False,
        max_episode_steps: int | None = None,
        order_enforce: bool = True,
        autoreset: bool | None = None,
        disable_env_checker: bool = False,
        apply_api_compatibility: bool = False,
        **kwargs,
    ):
        import gym

        if apply_api_compatibility is not False:
            raise TypeError(
                cls._GYM_UNRECOGNIZED_KWARG.format(
                    "apply_api_compatibility", gym.__version__
                )
            )
        from torchrl.envs.libs._gym_utils import _TorchRLGymWrapper

        if entry_point is None:
            entry_point = cls
        entry_point = partial(
            _TorchRLGymWrapper,
            entry_point=entry_point,
            info_keys=info_keys,
            to_numpy=to_numpy,
            transform=transform,
            **kwargs,
        )
        return gym.register(
            id=id,
            entry_point=entry_point,
            reward_threshold=reward_threshold,
            nondeterministic=nondeterministic,
            max_episode_steps=max_episode_steps,
            order_enforce=order_enforce,
            autoreset=bool(autoreset),
            disable_env_checker=disable_env_checker,
        )

    @implement_for("gym", "0.24", "0.25", class_method=True)
    def _register_gym(  # noqa: F811
        cls,
        id,
        entry_point: Callable | None = None,
        transform: Transform | None = None,  # noqa: F821
        info_keys: list[NestedKey] | None = None,
        to_numpy: bool = False,
        reward_threshold: float | None = None,
        nondeterministic: bool = False,
        max_episode_steps: int | None = None,
        order_enforce: bool = True,
        autoreset: bool | None = None,
        disable_env_checker: bool = False,
        apply_api_compatibility: bool = False,
        **kwargs,
    ):
        import gym

        if apply_api_compatibility is not False:
            raise TypeError(
                cls._GYM_UNRECOGNIZED_KWARG.format(
                    "apply_api_compatibility", gym.__version__
                )
            )
        if disable_env_checker is not False:
            raise TypeError(
                cls._GYM_UNRECOGNIZED_KWARG.format(
                    "disable_env_checker", gym.__version__
                )
            )
        from torchrl.envs.libs._gym_utils import _TorchRLGymWrapper

        if entry_point is None:
            entry_point = cls
        entry_point = partial(
            _TorchRLGymWrapper,
            entry_point=entry_point,
            info_keys=info_keys,
            to_numpy=to_numpy,
            transform=transform,
            **kwargs,
        )
        return gym.register(
            id=id,
            entry_point=entry_point,
            reward_threshold=reward_threshold,
            nondeterministic=nondeterministic,
            max_episode_steps=max_episode_steps,
            order_enforce=order_enforce,
            autoreset=bool(autoreset),
        )

    @implement_for("gym", "0.21", "0.24", class_method=True)
    def _register_gym(  # noqa: F811
        cls,
        id,
        entry_point: Callable | None = None,
        transform: Transform | None = None,  # noqa: F821
        info_keys: list[NestedKey] | None = None,
        to_numpy: bool = False,
        reward_threshold: float | None = None,
        nondeterministic: bool = False,
        max_episode_steps: int | None = None,
        order_enforce: bool = True,
        autoreset: bool | None = None,
        disable_env_checker: bool = False,
        apply_api_compatibility: bool = False,
        **kwargs,
    ):
        import gym

        if apply_api_compatibility is not False:
            raise TypeError(
                cls._GYM_UNRECOGNIZED_KWARG.format(
                    "apply_api_compatibility", gym.__version__
                )
            )
        if disable_env_checker is not False:
            raise TypeError(
                cls._GYM_UNRECOGNIZED_KWARG.format(
                    "disable_env_checker", gym.__version__
                )
            )
        if autoreset is not None:
            raise TypeError(
                cls._GYM_UNRECOGNIZED_KWARG.format("autoreset", gym.__version__)
            )
        from torchrl.envs.libs._gym_utils import _TorchRLGymWrapper

        if entry_point is None:
            entry_point = cls
        entry_point = partial(
            _TorchRLGymWrapper,
            entry_point=entry_point,
            info_keys=info_keys,
            to_numpy=to_numpy,
            transform=transform,
            **kwargs,
        )
        return gym.register(
            id=id,
            entry_point=entry_point,
            reward_threshold=reward_threshold,
            nondeterministic=nondeterministic,
            max_episode_steps=max_episode_steps,
            order_enforce=order_enforce,
            autoreset=bool(autoreset),
        )

    @implement_for("gym", None, "0.21", class_method=True)
    def _register_gym(  # noqa: F811
        cls,
        id,
        entry_point: Callable | None = None,
        transform: Transform | None = None,  # noqa: F821
        info_keys: list[NestedKey] | None = None,
        to_numpy: bool = False,
        reward_threshold: float | None = None,
        nondeterministic: bool = False,
        max_episode_steps: int | None = None,
        order_enforce: bool = True,
        autoreset: bool | None = None,
        disable_env_checker: bool = False,
        apply_api_compatibility: bool = False,
        **kwargs,
    ):
        import gym
        from torchrl.envs.libs._gym_utils import _TorchRLGymWrapper

        if order_enforce is not True:
            raise TypeError(
                cls._GYM_UNRECOGNIZED_KWARG.format("order_enforce", gym.__version__)
            )
        if disable_env_checker is not False:
            raise TypeError(
                cls._GYM_UNRECOGNIZED_KWARG.format(
                    "disable_env_checker", gym.__version__
                )
            )
        if autoreset is not None:
            raise TypeError(
                cls._GYM_UNRECOGNIZED_KWARG.format("autoreset", gym.__version__)
            )
        if apply_api_compatibility is not False:
            raise TypeError(
                cls._GYM_UNRECOGNIZED_KWARG.format(
                    "apply_api_compatibility", gym.__version__
                )
            )
        if entry_point is None:
            entry_point = cls
        entry_point = partial(
            _TorchRLGymWrapper,
            entry_point=entry_point,
            info_keys=info_keys,
            to_numpy=to_numpy,
            transform=transform,
            **kwargs,
        )
        return gym.register(
            id=id,
            entry_point=entry_point,
            reward_threshold=reward_threshold,
            nondeterministic=nondeterministic,
            max_episode_steps=max_episode_steps,
        )

    @implement_for("gymnasium", None, "1.0.0", class_method=True)
    def _register_gym(  # noqa: F811
        cls,
        id,
        entry_point: Callable | None = None,
        transform: Transform | None = None,  # noqa: F821
        info_keys: list[NestedKey] | None = None,
        to_numpy: bool = False,
        reward_threshold: float | None = None,
        nondeterministic: bool = False,
        max_episode_steps: int | None = None,
        order_enforce: bool = True,
        autoreset: bool | None = None,
        disable_env_checker: bool = False,
        apply_api_compatibility: bool = False,
        **kwargs,
    ):
        import gymnasium
        from torchrl.envs.libs._gym_utils import _TorchRLGymnasiumWrapper

        if entry_point is None:
            entry_point = cls

        entry_point = partial(
            _TorchRLGymnasiumWrapper,
            entry_point=entry_point,
            info_keys=info_keys,
            to_numpy=to_numpy,
            transform=transform,
            **kwargs,
        )
        return gymnasium.register(
            id=id,
            entry_point=entry_point,
            reward_threshold=reward_threshold,
            nondeterministic=nondeterministic,
            max_episode_steps=max_episode_steps,
            order_enforce=order_enforce,
            autoreset=bool(autoreset),
            disable_env_checker=disable_env_checker,
            apply_api_compatibility=apply_api_compatibility,
        )

    @implement_for("gymnasium", "1.1.0", class_method=True)
    def _register_gym(  # noqa: F811
        cls,
        id,
        entry_point: Callable | None = None,
        transform: Transform | None = None,  # noqa: F821
        info_keys: list[NestedKey] | None = None,
        to_numpy: bool = False,
        reward_threshold: float | None = None,
        nondeterministic: bool = False,
        max_episode_steps: int | None = None,
        order_enforce: bool = True,
        autoreset: bool | None = None,
        disable_env_checker: bool = False,
        apply_api_compatibility: bool = False,
        **kwargs,
    ):
        import gymnasium
        from torchrl.envs.libs._gym_utils import _TorchRLGymnasiumWrapper

        if autoreset is not None:
            raise TypeError(
                f"the autoreset argument is deprecated in gymnasium>=1.0. Got autoreset={autoreset}"
            )
        if entry_point is None:
            entry_point = cls

        entry_point = partial(
            _TorchRLGymnasiumWrapper,
            entry_point=entry_point,
            info_keys=info_keys,
            to_numpy=to_numpy,
            transform=transform,
            **kwargs,
        )
        if apply_api_compatibility is not False:
            raise TypeError(
                cls._GYM_UNRECOGNIZED_KWARG.format(
                    "apply_api_compatibility", gymnasium.__version__
                )
            )
        return gymnasium.register(
            id=id,
            entry_point=entry_point,
            reward_threshold=reward_threshold,
            nondeterministic=nondeterministic,
            max_episode_steps=max_episode_steps,
            order_enforce=order_enforce,
            disable_env_checker=disable_env_checker,
        )
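
    # The overloads above are dispatched by torchrl's ``implement_for`` decorator:
    # each variant declares a backend and version range, and the one matching the
    # installed package is bound when the method is used. A minimal sketch of the
    # pattern (the class and method names below are illustrative only):
    #
    #   >>> from torchrl._utils import implement_for
    #   >>> class Demo:
    #   ...     @implement_for("gym", "0.26", None)
    #   ...     def probe(self):
    #   ...         return "gym>=0.26 branch"
    #   ...     @implement_for("gym", None, "0.26")  # noqa: F811
    #   ...     def probe(self):
    #   ...         return "gym<0.26 branch"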
    def forward(self, *args, **kwargs):
        raise NotImplementedError(
            "EnvBase.forward is not implemented. If you ended here during a call to `ParallelEnv(...)`, please use "
            "a constructor such as `ParallelEnv(num_env, lambda env=env: env)` instead. "
            "Batched envs require constructors because environment instances may not always be serializable."
        )

    @abc.abstractmethod
    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        raise NotImplementedError

    @abc.abstractmethod
    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        raise NotImplementedError
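
    # A skeleton sketch of the subclass contract defined by the abstract methods
    # above (illustrative only; a real env must also declare its specs):
    #
    #   >>> class EchoEnv(EnvBase):
    #   ...     def _reset(self, tensordict, **kwargs):
    #   ...         return TensorDict({"observation": torch.zeros(1)}, [])
    #   ...     def _step(self, tensordict):
    #   ...         return TensorDict({
    #   ...             "observation": tensordict["action"].clone(),
    #   ...             "reward": torch.zeros(1),
    #   ...             "done": torch.zeros(1, dtype=torch.bool),
    #   ...         }, [])
    #   ...     def _set_seed(self, seed):
    #   ...         ...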
    def reset(
        self,
        tensordict: TensorDictBase | None = None,
        **kwargs,
    ) -> TensorDictBase:
        """Resets the environment.

        As for step and _step, only the private method :obj:`_reset` should be overwritten by EnvBase subclasses.

        Args:
            tensordict (TensorDictBase, optional): tensordict to be used to contain the resulting new observation.
                In some cases, this input can also be used to pass arguments to the reset function.
            kwargs (optional): other arguments to be passed to the native
                reset function.

        Returns:
            a tensordict (or the input tensordict, if any), modified in place with the resulting observations.

        .. note:: `reset` should not be overwritten by :class:`~torchrl.envs.EnvBase` subclasses. The method to
            modify is :meth:`~torchrl.envs.EnvBase._reset`.

        """
        if tensordict is not None:
            self._assert_tensordict_shape(tensordict)

        select_reset_only = kwargs.pop("select_reset_only", False)
        if select_reset_only and tensordict is not None:
            # When making rollouts with step_and_maybe_reset, it can happen that a tensordict has
            # keys that are used by reset to optionally set the reset state (eg, the fen in chess). If that's the
            # case and we don't throw them away here, reset will just be a no-op (put the env in the state reached
            # during the previous step).
            # Therefore, maybe_reset tells reset to temporarily hide the non-reset keys.
            # To make step_and_maybe_reset handle custom reset states, some version of TensorDictPrimer should be used.
            tensordict_reset = self._reset(
                tensordict.select(*self.reset_keys, strict=False), **kwargs
            )
        else:
            tensordict_reset = self._reset(tensordict, **kwargs)
        # We assume that this is done properly
        # if reset.device != self.device:
        #     reset = reset.to(self.device, non_blocking=True)
        if tensordict_reset is tensordict:
            raise RuntimeError(
                "EnvBase._reset should return outplace changes to the input "
                "tensordict. Consider emptying the TensorDict first (e.g. tensordict.empty()) "
                "inside _reset before writing new tensors onto this new instance."
            )
        if not isinstance(tensordict_reset, TensorDictBase):
            raise RuntimeError(
                f"env._reset returned an object of type {type(tensordict_reset)} but a TensorDict was expected."
            )
        return self._reset_proc_data(tensordict, tensordict_reset)
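
    # A usage sketch for partial resets (illustrative; assumes a batched env with
    # batch_size=[2] and a done shape of [2, 1]): a boolean "_reset" entry limits
    # the reset to the selected sub-envs, while the others keep their current data.
    #
    #   >>> td = env.reset()                               # full reset
    #   >>> td["_reset"] = torch.tensor([[True], [False]])
    #   >>> td = env.reset(td)                             # resets only the first sub-env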
    def _reset_proc_data(self, tensordict, tensordict_reset):
        self._complete_done(self.full_done_spec, tensordict_reset)
        self._reset_check_done(tensordict, tensordict_reset)
        if tensordict is not None:
            return _update_during_reset(tensordict_reset, tensordict, self.reset_keys)
        return tensordict_reset

    def _reset_check_done(self, tensordict, tensordict_reset):
        """Checks the done status after reset.

        If _reset signals were passed, we check that the env is not done for these
        indices.

        We also check that the input tensordict contained ``"done"``s if the
        reset is partial and incomplete.

        """
        # we iterate over (reset_key, (done_key, truncated_key)) and check that all
        # values where reset was true now have a done set to False.
        # If no reset was present, all done and truncated must be False

        # Once we checked a root, we don't check its leaves - so keep track of the roots. Fortunately, we sort the done
        # keys in the done_keys_group from root to leaf
        prefix_complete = set()
        for reset_key, done_key_group in zip(self.reset_keys, self.done_keys_groups):
            skip = False
            if isinstance(reset_key, tuple):
                for i in range(len(reset_key) - 1):
                    if reset_key[:i] in prefix_complete:
                        skip = True
                        break
            if skip:
                continue
            reset_value = (
                tensordict.get(reset_key, default=None)
                if tensordict is not None
                else None
            )
            prefix_complete.add(() if isinstance(reset_key, str) else reset_key[:-1])
            if reset_value is not None:
                for done_key in done_key_group:
                    done_val = tensordict_reset.get(done_key)
                    if (
                        done_val.any()
                        and done_val[reset_value].any()
                        and not self._allow_done_after_reset
                    ):
                        raise RuntimeError(
                            f"Env done entry '{done_key}' was (partially) True after reset on specified '_reset' dimensions. This is not allowed."
                        )
                    if (
                        done_key not in tensordict.keys(True)
                        and done_val[~reset_value].any()
                    ):
                        warnings.warn(
                            f"A partial `'_reset'` key has been passed to `reset` ({reset_key}), "
                            f"but the corresponding done_key ({done_key}) wasn't present in the input "
                            f"tensordict. "
                            f"This is discouraged, since the input tensordict should contain "
                            f"all the data not being reset."
                        )
                        # we set the done val to tensordict, to make sure that
                        # _update_during_reset does not pad the value
                        tensordict.set(done_key, done_val)
            elif not self._allow_done_after_reset:
                for done_key in done_key_group:
                    if tensordict_reset.get(done_key).any():
                        raise RuntimeError(
                            f"The done entry '{done_key}' was (partially) True after a call to reset() in env {self}."
                        )
    def numel(self) -> int:
        return prod(self.batch_size)

    def set_seed(
        self, seed: int | None = None, static_seed: bool = False
    ) -> int | None:
        """Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present).

        Args:
            seed (int): seed to be set. The seed is set only locally in the environment. To handle the global seed,
                see :func:`~torch.manual_seed`.
            static_seed (bool, optional): if ``True``, the seed is not incremented.
                Defaults to ``False``.

        Returns:
            integer representing the "next seed": i.e. the seed that should be
            used for another environment if created concomitantly to this environment.

        """
        self._set_seed(seed)
        if seed is not None and not static_seed:
            new_seed = seed_generator(seed)
            seed = new_seed
        return seed

    @abc.abstractmethod
    def _set_seed(self, seed: int | None) -> None:
        raise NotImplementedError

    def set_state(self):
        raise NotImplementedError

    def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
        if (
            self.batch_locked or self.batch_size != ()
        ) and tensordict.batch_size != self.batch_size:
            raise RuntimeError(
                f"Expected a tensordict with shape==env.batch_size, "
                f"got {tensordict.batch_size} and {self.batch_size}"
            )
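
    # A seeding sketch: set_seed returns the *next* seed, so sibling envs can be
    # chained with distinct yet reproducible seeds (env1/env2 are placeholders):
    #
    #   >>> seed = 0
    #   >>> seed = env1.set_seed(seed)
    #   >>> seed = env2.set_seed(seed)  # receives a derived, different seed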
    def all_actions(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
        """Generates all possible actions from the action spec.

        This only works in environments with fully discrete actions.

        Args:
            tensordict (TensorDictBase, optional): If given, :meth:`~.reset`
                is called with this tensordict.

        Returns:
            a tensordict object with the "action" entry updated with a batch of
            all possible actions. The actions are stacked together in the
            leading dimension.
        """
        if tensordict is not None:
            self.reset(tensordict)

        return self.full_action_spec.enumerate(use_mask=True)

    def rand_action(self, tensordict: TensorDictBase | None = None):
        """Performs a random action given the action_spec attribute.

        Args:
            tensordict (TensorDictBase, optional): tensordict where the resulting action should be written.

        Returns:
            a tensordict object with the "action" entry updated with a random
            sample from the action-spec.

        """
        shape = torch.Size([])
        if not self.batch_locked:
            if not self.batch_size and tensordict is not None:
                # if we can't infer the batch-size from the env, take it from tensordict
                shape = tensordict.shape
            elif not self.batch_size:
                # if tensordict wasn't provided, we assume empty batch size
                shape = torch.Size([])
            elif tensordict.shape != self.batch_size:
                # if tensordict is not None and the env has a batch size, their shapes must match
                raise RuntimeError(
                    "The input tensordict and the env have a different batch size: "
                    f"env.batch_size={self.batch_size} and tensordict.batch_size={tensordict.shape}. "
                    f"Non-batch-locked environments require the env batch-size to be either empty or to"
                    f" match the tensordict one."
                )
        # We generate the action from the full_action_spec
        r = self.input_spec["full_action_spec"].rand(shape)
        if tensordict is None:
            return r
        tensordict.update(r)
        return tensordict

    def rand_step(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
        """Performs a random step in the environment given the action_spec attribute.

        Args:
            tensordict (TensorDictBase, optional): tensordict where the resulting info should be written.

        Returns:
            a tensordict object with the new observation after a random step in the environment. The action will
            be stored with the "action" key.

        """
        tensordict = self.rand_action(tensordict)
        return self.step(tensordict)
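
    # A random-interaction sketch built from rand_action/rand_step; step_mdp moves
    # td["next"] to the root so the loop can continue (illustrative only):
    #
    #   >>> from torchrl.envs.utils import step_mdp
    #   >>> td = env.reset()
    #   >>> for _ in range(5):
    #   ...     td = env.rand_step(td)
    #   ...     td = step_mdp(td)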
    @property
    def specs(self) -> Composite:
        """Returns a Composite container where all the environment specs are present.

        This feature allows one to create an environment, retrieve all of the specs in a single data container and then
        erase the environment from the workspace.

        """
        return Composite(
            output_spec=self.output_spec,
            input_spec=self.input_spec,
            shape=self.batch_size,
        )

    @property
    @_cache_value
    def _has_dynamic_specs(self) -> bool:
        return _has_dynamic_specs(self.specs)
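
    # A sketch of specs as a detachable snapshot: the returned Composite can outlive
    # the env that produced it (illustrative only):
    #
    #   >>> spec_snapshot = env.specs
    #   >>> del env
    #   >>> spec_snapshot["output_spec", "full_reward_spec"]  # still queryable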
def rollout(
|
|
3216
|
+
self,
|
|
3217
|
+
max_steps: int,
|
|
3218
|
+
policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
|
|
3219
|
+
callback: Callable[[TensorDictBase, ...], Any] | None = None,
|
|
3220
|
+
*,
|
|
3221
|
+
auto_reset: bool = True,
|
|
3222
|
+
auto_cast_to_device: bool = False,
|
|
3223
|
+
break_when_any_done: bool | None = None,
|
|
3224
|
+
break_when_all_done: bool | None = None,
|
|
3225
|
+
return_contiguous: bool | None = False,
|
|
3226
|
+
tensordict: TensorDictBase | None = None,
|
|
3227
|
+
set_truncated: bool = False,
|
|
3228
|
+
out=None,
|
|
3229
|
+
trust_policy: bool = False,
|
|
3230
|
+
storing_device: DEVICE_TYPING | None = None,
|
|
3231
|
+
) -> TensorDictBase:
|
|
3232
|
+
"""Executes a rollout in the environment.
|
|
3233
|
+
|
|
3234
|
+
The function will return as soon as any of the contained environments
|
|
3235
|
+
reaches any of the done states.
|
|
3236
|
+
|
|
3237
|
+
Args:
|
|
3238
|
+
max_steps (int): maximum number of steps to be executed. The actual number of steps can be smaller if
|
|
3239
|
+
the environment reaches a done state before max_steps have been executed.
|
|
3240
|
+
policy (callable, optional): callable to be called to compute the desired action.
|
|
3241
|
+
If no policy is provided, actions will be called using :obj:`env.rand_step()`.
|
|
3242
|
+
The policy can be any callable that reads either a tensordict or
|
|
3243
|
+
the entire sequence of observation entries __sorted as__ the ``env.observation_spec.keys()``.
|
|
3244
|
+
Defaults to `None`.
|
|
3245
|
+
callback (Callable[[TensorDict], Any], optional): function to be called at each iteration with the given
|
|
3246
|
+
TensorDict. Defaults to ``None``. The output of ``callback`` will not be collected, it is the user
|
|
3247
|
+
responsibility to save any result within the callback call if data needs to be carried over beyond
|
|
3248
|
+
the call to ``rollout``.
|
|
3249
|
+
|
|
3250
|
+
Keyword Args:
|
|
3251
|
+
auto_reset (bool, optional): if ``True``, the contained environments will be reset before starting the
|
|
3252
|
+
rollout. If ``False``, then the rollout will continue from a previous state, which requires the
|
|
3253
|
+
``tensordict`` argument to be passed with the previous rollout. Default is ``True``.
|
|
3254
|
+
auto_cast_to_device (bool, optional): if ``True``, the device of the tensordict is automatically cast to the
|
|
3255
|
+
policy device before the policy is used. Default is ``False``.
|
|
3256
|
+
break_when_any_done (bool): if ``True``, break when any of the contained environments reaches any of the
|
|
3257
|
+
done states. If ``False``, then the done environments are reset automatically. Default is ``True``.
|
|
3258
|
+
|
|
3259
|
+
.. seealso:: The :ref:`Partial resets <ref_partial_resets>` of the documentation gives more
|
|
3260
|
+
information about partial resets.
|
|
3261
|
+
|
|
3262
|
+
break_when_all_done (bool, optional): if ``True``, break if all of the contained environments reach any
|
|
3263
|
+
of the done states. If ``False``, break if at least one environment reaches any of the done states.
|
|
3264
|
+
Default is ``False``.
|
|
3265
|
+
|
|
3266
|
+
.. seealso:: The :ref:`Partial steps <ref_partial_steps>` of the documentation gives more
|
|
3267
|
+
information about partial resets.
|
|
3268
|
+
|
|
3269
|
+
return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is `True` if
|
|
3270
|
+
the env does not have dynamic specs, otherwise `False`.
|
|
3271
|
+
tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial
|
|
3272
|
+
tensordict must be provided. Rollout will check if this tensordict has done flags and reset the
|
|
3273
|
+
environment in those dimensions (if needed).
|
|
3274
|
+
This normally should not occur if ``tensordict`` is the output of a reset, but can occur
|
|
3275
|
+
if ``tensordict`` is the last step of a previous rollout.
|
|
3276
|
+
A ``tensordict`` can also be provided when ``auto_reset=True`` if metadata need to be passed
|
|
3277
|
+
to the ``reset`` method, such as a batch-size or a device for stateless environments.
|
|
3278
|
+
set_truncated (bool, optional): if ``True``, ``"truncated"`` and ``"done"`` keys will be set to
|
|
3279
|
+
``True`` after completion of the rollout. If no ``"truncated"`` is found within the
|
|
3280
|
+
``done_spec``, an exception is raised.
|
|
3281
|
+
Truncated keys can be set through ``env.add_truncated_keys``.
|
|
3282
|
+
Defaults to ``False``.
|
|
3283
|
+
trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be
|
|
3284
|
+
assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules
|
|
3285
|
+
and ``False`` otherwise.
|
|
3286
|
+
storing_device (Device, optional): if provided, the tensordict will be stored on this device.
|
|
3287
|
+
Defaults to ``None``.
|
|
3288
|
+
|
|
3289
|
+
Returns:
|
|
3290
|
+
TensorDict object containing the resulting trajectory.
|
|
3291
|
+
|
|
3292
|
+
The data returned will be marked with a "time" dimension name for the last
|
|
3293
|
+
dimension of the tensordict (at the ``env.ndim`` index).
|
|
3294
|
+
|
|
3295
|
+
``rollout`` is quite handy to display what the data structure of the
|
|
3296
|
+
environment looks like.
|
|
3297
|
+
|
|
3298
|
+
Examples:
|
|
3299
|
+
>>> # Using rollout without a policy
|
|
3300
|
+
>>> from torchrl.envs.libs.gym import GymEnv
|
|
3301
|
+
>>> from torchrl.envs.transforms import TransformedEnv, StepCounter
|
|
3302
|
+
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(max_steps=20))
|
|
3303
|
+
>>> rollout = env.rollout(max_steps=1000)
|
|
3304
|
+
>>> print(rollout)
|
|
3305
|
+
TensorDict(
|
|
3306
|
+
fields={
|
|
3307
|
+
action: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
3308
|
+
done: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
3309
|
+
next: TensorDict(
|
|
3310
|
+
fields={
|
|
3311
|
+
done: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
3312
|
+
observation: Tensor(shape=torch.Size([20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
3313
|
+
reward: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
3314
|
+
step_count: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
3315
|
+
truncated: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
3316
|
+
batch_size=torch.Size([20]),
|
|
3317
|
+
device=cpu,
|
|
3318
|
+
is_shared=False),
|
|
3319
|
+
observation: Tensor(shape=torch.Size([20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
3320
|
+
step_count: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
3321
|
+
truncated: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
3322
|
+
batch_size=torch.Size([20]),
|
|
3323
|
+
device=cpu,
|
|
3324
|
+
is_shared=False)
|
|
3325
|
+
>>> print(rollout.names)
|
|
3326
|
+
['time']
|
|
3327
|
+
>>> # with envs that contain more dimensions
|
|
3328
|
+
>>> from torchrl.envs import SerialEnv
|
|
3329
|
+
>>> env = SerialEnv(3, lambda: TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(max_steps=20)))
|
|
3330
|
+
>>> rollout = env.rollout(max_steps=1000)
|
|
3331
|
+
>>> print(rollout)
|
|
3332
|
+
TensorDict(
|
|
3333
|
+
fields={
|
|
3334
|
+
action: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
3335
|
+
done: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
3336
|
+
next: TensorDict(
|
|
3337
|
+
fields={
|
|
3338
|
+
done: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
3339
|
+
observation: Tensor(shape=torch.Size([3, 20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
3340
|
+
reward: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
3341
|
+
step_count: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
3342
|
+
truncated: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
3343
|
+
batch_size=torch.Size([3, 20]),
|
|
3344
|
+
device=cpu,
|
|
3345
|
+
is_shared=False),
|
|
3346
|
+
observation: Tensor(shape=torch.Size([3, 20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
3347
|
+
step_count: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
3348
|
+
truncated: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
3349
|
+
batch_size=torch.Size([3, 20]),
|
|
3350
|
+
device=cpu,
|
|
3351
|
+
is_shared=False)
|
|
3352
|
+
>>> print(rollout.names)
|
|
3353
|
+
[None, 'time']
|
|
3354
|
+
|
|
3355
|
+
Using a policy (a regular :class:`~torch.nn.Module` or a :class:`~tensordict.nn.TensorDictModule`)
|
|
3356
|
+
is also easy:
|
|
3357
|
+
|
|
3358
|
+
Examples:
|
|
3359
|
+
            >>> from torch import nn
            >>> env = GymEnv("CartPole-v1", categorical_action_encoding=True)
            >>> class ArgMaxModule(nn.Module):
            ...     def forward(self, values):
            ...         return values.argmax(-1)
            >>> n_obs = env.observation_spec["observation"].shape[-1]
            >>> n_act = env.action_spec.n
            >>> # A deterministic policy
            >>> policy = nn.Sequential(
            ...     nn.Linear(n_obs, n_act),
            ...     ArgMaxModule())
            >>> env.rollout(max_steps=10, policy=policy)
            TensorDict(
                fields={
                    action: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
                    done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    next: TensorDict(
                        fields={
                            done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                            observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                            reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                            terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                            truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                        batch_size=torch.Size([10]),
                        device=cpu,
                        is_shared=False),
                    observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                    terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([10]),
                device=cpu,
                is_shared=False)
            >>> # Under the hood, rollout will wrap the policy in a TensorDictModule
            >>> # To speed things up we can do that ourselves
            >>> from tensordict.nn import TensorDictModule
            >>> policy = TensorDictModule(policy, in_keys=list(env.observation_spec.keys()), out_keys=["action"])
            >>> env.rollout(max_steps=10, policy=policy)
            TensorDict(
                fields={
                    action: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
                    done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    next: TensorDict(
                        fields={
                            done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                            observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                            reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                            terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                            truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                        batch_size=torch.Size([10]),
                        device=cpu,
                        is_shared=False),
                    observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                    terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([10]),
                device=cpu,
                is_shared=False)

        In some instances, a contiguous tensordict cannot be obtained because
        the per-step tensordicts cannot be stacked. This can happen when the data returned at
        each step may have a different shape, or when different environments
        are executed together. In that case, ``return_contiguous=False``
        will cause the returned tensordict to be a lazy stack of tensordicts:

        Examples of non-contiguous rollout:
            >>> rollout = env.rollout(4, return_contiguous=False)
            >>> print(rollout)
            LazyStackedTensorDict(
                fields={
                    action: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                    done: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    next: LazyStackedTensorDict(
                        fields={
                            done: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                            observation: Tensor(shape=torch.Size([3, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                            reward: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                            step_count: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                            truncated: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                        batch_size=torch.Size([3, 4]),
                        device=cpu,
                        is_shared=False),
                    observation: Tensor(shape=torch.Size([3, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                    step_count: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                    truncated: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([3, 4]),
                device=cpu,
                is_shared=False)
            >>> print(rollout.names)
            [None, 'time']

        Rollouts can be used in a loop to emulate data collection.
        To do so, you need to pass as input the last tensordict coming from the previous rollout after calling
        :func:`~torchrl.envs.utils.step_mdp` on it.

        Examples of data collection rollouts:
            >>> from torchrl.envs import GymEnv, step_mdp
            >>> env = GymEnv("CartPole-v1")
            >>> epochs = 10
            >>> input_td = env.reset()
            >>> for i in range(epochs):
            ...     rollout_td = env.rollout(
            ...         max_steps=100,
            ...         break_when_any_done=False,
            ...         auto_reset=False,
            ...         tensordict=input_td,
            ...     )
            ...     input_td = step_mdp(
            ...         rollout_td[..., -1],
            ...     )

        """
        if break_when_any_done is None:  # True by default
            if break_when_all_done:  # all overrides
                break_when_any_done = False
            else:
                break_when_any_done = True
        if break_when_all_done is None:
            # There is no case where break_when_all_done is True by default
            break_when_all_done = False
        if break_when_all_done and break_when_any_done:
            raise TypeError(
                "Cannot have both break_when_all_done and break_when_any_done True at the same time."
            )
        if return_contiguous is None:
            return_contiguous = not self._has_dynamic_specs
        if policy is not None:
            policy = _make_compatible_policy(
                policy,
                self.observation_spec,
                env=self,
                fast_wrap=True,
                trust_policy=trust_policy,
            )
            if auto_cast_to_device:
                try:
                    policy_device = next(policy.parameters()).device
                except (StopIteration, AttributeError):
                    policy_device = None
            else:
                policy_device = None
        else:
            policy = self.rand_action
            policy_device = None

        env_device = self.device

        if auto_reset:
            tensordict = self.reset(tensordict)
        elif tensordict is None:
            raise RuntimeError("tensordict must be provided when auto_reset is False")
        else:
            tensordict = self.maybe_reset(tensordict)

        kwargs = {
            "tensordict": tensordict,
            "auto_cast_to_device": auto_cast_to_device,
            "max_steps": max_steps,
            "policy": policy,
            "policy_device": policy_device,
            "env_device": env_device,
            "storing_device": None
            if storing_device is None
            else torch.device(storing_device),
            "callback": callback,
        }
        if break_when_any_done or break_when_all_done:
            tensordicts = self._rollout_stop_early(
                break_when_all_done=break_when_all_done,
                break_when_any_done=break_when_any_done,
                **kwargs,
            )
        else:
            tensordicts = self._rollout_nonstop(**kwargs)
        batch_size = self.batch_size if tensordict is None else tensordict.batch_size
        if return_contiguous:
            try:
                out_td = torch.stack(tensordicts, len(batch_size), out=out)
            except RuntimeError as err:
                if (
                    re.match(
                        "The shapes of the tensors to stack is incompatible", str(err)
                    )
                    and self._has_dynamic_specs
                ):
                    raise RuntimeError(
                        "The environment specs are dynamic. Call rollout with return_contiguous=False."
                    )
                if re.match(
                    "The sets of keys in the tensordicts to stack are exclusive",
                    str(err),
                ):
                    for reward_key in self.reward_keys:
                        if any(reward_key in td for td in tensordicts):
                            raise RuntimeError(
                                "The reward key was present in the root tensordict of at least one of the tensordicts to stack. "
                                "The likely cause is that your environment returns a reward during a call to `reset`, which is not allowed. "
                                "To fix this, you should return the reward in the `step` method but not during `reset`. If you need a reward "
                                "to be returned during `reset`, submit an issue on github."
                            )
                raise
        else:
            out_td = LazyStackedTensorDict.maybe_dense_stack(
                tensordicts, len(batch_size), out=out
            )
        if set_truncated:
            found_truncated = False
            for key in self.done_keys:
                if _ends_with(key, "truncated"):
                    val = out_td.get(("next", key))
                    done = out_td.get(("next", _replace_last(key, "done")))
                    val[(slice(None),) * (out_td.ndim - 1) + (-1,)] = True
                    out_td.set(("next", key), val)
                    out_td.set(("next", _replace_last(key, "done")), val | done)
                    found_truncated = True
            if not found_truncated:
                raise RuntimeError(
                    "set_truncated was set to True but no truncated key could be found. "
                    "Make sure a 'truncated' entry was set in the environment "
                    "full_done_keys using `env.add_truncated_keys()`."
                )

        out_td.refine_names(..., "time")
        return out_td
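
    # NOTE (added example): a minimal sketch of driving `rollout` with
    # `break_when_all_done=True` on a batched env. The env name and batch
    # size are assumptions for illustration; the flag semantics follow the
    # body above (done sub-envs are frozen via the "_step" mask until every
    # sub-env is done).
    #
    #     >>> from torchrl.envs import GymEnv, SerialEnv
    #     >>> env = SerialEnv(2, lambda: GymEnv("CartPole-v1"))
    #     >>> td = env.rollout(
    #     ...     max_steps=200,
    #     ...     break_when_any_done=False,  # don't stop at the first done
    #     ...     break_when_all_done=True,   # stop once *all* sub-envs are done
    #     ... )
    #     >>> td.shape  # torch.Size([2, t]) with t <= 200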

    @_maybe_unlock
    def add_truncated_keys(self) -> EnvBase:
        """Adds truncated keys to the environment."""
        i = 0
        for key in self.done_keys:
            i += 1
            truncated_key = _replace_last(key, "truncated")
            self.full_done_spec[truncated_key] = self.full_done_spec[key].clone()
        if i == 0:
            raise KeyError(f"Couldn't find done keys. done_spec={self.full_done_spec}")

        return self
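
    # NOTE (added example): `add_truncated_keys` pairs naturally with
    # `rollout(..., set_truncated=True)`. A minimal sketch (the env name is an
    # assumption for illustration):
    #
    #     >>> from torchrl.envs import GymEnv
    #     >>> env = GymEnv("CartPole-v1").add_truncated_keys()
    #     >>> td = env.rollout(10, set_truncated=True)
    #     >>> bool(td["next", "truncated"][..., -1, :].all())  # last step flagged
    #     True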

    def step_mdp(self, next_tensordict: TensorDictBase) -> TensorDictBase:
        """Advances the environment state by one step using the provided `next_tensordict`.

        This method updates the environment's state by transitioning from the current
        state to the next, as defined by the `next_tensordict`. The resulting tensordict
        includes updated observations and any other relevant state information, with
        keys managed according to the environment's specifications.

        Internally, this method utilizes a precomputed :class:`~torchrl.envs.utils._StepMDP` instance to efficiently
        handle the transition of state, observation, action, reward, and done keys. The
        :class:`~torchrl.envs.utils._StepMDP` class optimizes the process by precomputing the keys to include and
        exclude, reducing runtime overhead during repeated calls. The :class:`~torchrl.envs.utils._StepMDP` instance
        is created with `exclude_action=False`, meaning that action keys are retained in
        the root tensordict.

        Args:
            next_tensordict (TensorDictBase): A tensordict containing the state of the
                environment at the next time step. This tensordict should include keys
                for observations, actions, rewards, and done flags, as defined by the
                environment's specifications.

        Returns:
            TensorDictBase: A new tensordict representing the environment state after
            advancing by one step.

        .. note:: The method ensures that the environment's key specifications are validated
            against the provided `next_tensordict`, issuing warnings if discrepancies
            are found.

        .. note:: This method is designed to work efficiently with environments that have
            consistent key specifications, leveraging the `_StepMDP` class to minimize
            overhead.

        Example:
            >>> from torchrl.envs import GymEnv
            >>> env = GymEnv("Pendulum-v1")
            >>> data = env.reset()
            >>> for i in range(10):
            ...     # compute action
            ...     env.rand_action(data)
            ...     # Perform action
            ...     next_data = env.step(data)
            ...     data = env.step_mdp(next_data)
        """
        return self._step_mdp(next_tensordict)

    @property
    @_cache_value
    def _step_mdp(self) -> Callable[[TensorDictBase], TensorDictBase]:
        return _StepMDP(self, exclude_action=False)

    def _rollout_stop_early(
        self,
        *,
        break_when_any_done,
        break_when_all_done,
        tensordict,
        auto_cast_to_device,
        max_steps,
        policy,
        policy_device,
        env_device,
        storing_device,
        callback,
    ):
        # Get the sync func
        if auto_cast_to_device:
            sync_func = _get_sync_func(policy_device, env_device)
        tensordicts = []
        partial_steps = True
        for i in range(max_steps):
            if auto_cast_to_device:
                if policy_device is not None:
                    tensordict = tensordict.to(policy_device, non_blocking=True)
                    sync_func()
                else:
                    tensordict.clear_device_()
            # In case policy(..) does not modify in-place - no-op for TensorDict and related
            tensordict.update(policy(tensordict))
            if auto_cast_to_device:
                if env_device is not None:
                    tensordict = tensordict.to(env_device, non_blocking=True)
                    sync_func()
                else:
                    tensordict.clear_device_()
            tensordict = self.step(tensordict)
            if storing_device is None or tensordict.device == storing_device:
                td_append = tensordict.copy()
            else:
                td_append = tensordict.to(storing_device)
            if break_when_all_done:
                if partial_steps is not True and not partial_steps.all():
                    # At least one step is partial
                    td_append.pop("_step", None)
                    td_append = torch.where(
                        partial_steps.view(td_append.shape), td_append, tensordicts[-1]
                    )

            tensordicts.append(td_append)

            if i == max_steps - 1:
                # we don't truncate as one could potentially continue the run
                break
            tensordict = self._step_mdp(tensordict)

            if break_when_any_done:
                # done and truncated are in done_keys
                # We read if any key is done.
                any_done = _terminated_or_truncated(
                    tensordict,
                    full_done_spec=self.output_spec["full_done_spec"],
                    key=None,
                )
                if any_done:
                    break
            else:
                # Write the '_step' entry, indicating which step is to be undertaken
                _terminated_or_truncated(
                    tensordict,
                    full_done_spec=self.output_spec["full_done_spec"],
                    key="_neg_step",
                    write_full_false=False,
                )
                # This is what differentiates _step and _reset: we need to flip _step False -> True
                partial_step_curr = tensordict.pop("_neg_step", None)
                if partial_step_curr is not None:
                    partial_step_curr = ~partial_step_curr
                    partial_steps = partial_steps & partial_step_curr
                if partial_steps is not True:
                    if not partial_steps.any():
                        break
                    # Write the final _step entry
                    tensordict.set("_step", partial_steps)

            if callback is not None:
                callback(self, tensordict)
        return tensordicts
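
    # NOTE (added example): the `partial_steps` bookkeeping above can be
    # summarized with plain tensors. A minimal, self-contained sketch of the
    # freeze-on-done masking (names are illustrative, not part of the API):
    #
    #     import torch
    #     alive = torch.tensor([True, False, True])     # ~done per sub-env
    #     new_step = torch.tensor([1.0, 2.0, 3.0])      # freshly collected data
    #     last_step = torch.tensor([10.0, 20.0, 30.0])  # previously stored data
    #     stored = torch.where(alive, new_step, last_step)
    #     # -> tensor([ 1., 20.,  3.]): done sub-envs keep their last value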

    def _rollout_nonstop(
        self,
        *,
        tensordict,
        auto_cast_to_device,
        max_steps,
        policy,
        policy_device,
        env_device,
        storing_device,
        callback,
    ):
        if auto_cast_to_device:
            sync_func = _get_sync_func(policy_device, env_device)
        tensordicts = []
        tensordict_ = tensordict
        for i in range(max_steps):
            if auto_cast_to_device:
                if policy_device is not None:
                    tensordict_ = tensordict_.to(policy_device, non_blocking=True)
                    sync_func()
                else:
                    tensordict_.clear_device_()
            # In case policy(..) does not modify in-place - no-op for TensorDict and related
            tensordict_.update(policy(tensordict_))
            if auto_cast_to_device:
                if env_device is not None:
                    tensordict_ = tensordict_.to(env_device, non_blocking=True)
                    sync_func()
                else:
                    tensordict_.clear_device_()
            if i == max_steps - 1:
                tensordict = self.step(tensordict_)
            else:
                tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
            if storing_device is None or tensordict.device == storing_device:
                tensordicts.append(tensordict)
            else:
                tensordicts.append(tensordict.to(storing_device))
            if i == max_steps - 1:
                # we don't truncate as one could potentially continue the run
                break
            if callback is not None:
                callback(self, tensordict)

        return tensordicts

    def step_and_maybe_reset(
        self, tensordict: TensorDictBase
    ) -> tuple[TensorDictBase, TensorDictBase]:
        """Runs a step in the environment and (partially) resets it if needed.

        Args:
            tensordict (TensorDictBase): an input data structure for the :meth:`step`
                method.

        This method allows one to easily code non-stopping rollout functions.

        Examples:
            >>> from torchrl.envs import ParallelEnv, GymEnv
            >>> def rollout(env, n):
            ...     data_ = env.reset()
            ...     result = []
            ...     for i in range(n):
            ...         data, data_ = env.step_and_maybe_reset(data_)
            ...         result.append(data)
            ...     return torch.stack(result)
            >>> env = ParallelEnv(2, lambda: GymEnv("CartPole-v1"))
            >>> print(rollout(env, 2))
            TensorDict(
                fields={
                    done: Tensor(shape=torch.Size([2, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    next: TensorDict(
                        fields={
                            done: Tensor(shape=torch.Size([2, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                            observation: Tensor(shape=torch.Size([2, 2, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                            reward: Tensor(shape=torch.Size([2, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                            terminated: Tensor(shape=torch.Size([2, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                            truncated: Tensor(shape=torch.Size([2, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                        batch_size=torch.Size([2, 2]),
                        device=cpu,
                        is_shared=False),
                    observation: Tensor(shape=torch.Size([2, 2, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                    terminated: Tensor(shape=torch.Size([2, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    truncated: Tensor(shape=torch.Size([2, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([2, 2]),
                device=cpu,
                is_shared=False)
        """
        if tensordict.device != self.device:
            tensordict = tensordict.to(self.device)
        tensordict = self.step(tensordict)
        # done and truncated are in done_keys
        # We read if any key is done.
        tensordict_ = self._step_mdp(tensordict)
        # if self._post_step_mdp_hooks is not None:
        #     tensordict_ = self._post_step_mdp_hooks(tensordict_)
        tensordict_ = self.maybe_reset(tensordict_)
        return tensordict, tensordict_

    # _post_step_mdp_hooks: Callable[[TensorDictBase], TensorDictBase] | None = None

    @property
    @_cache_value
    def _simple_done(self):
        key_set = set(self.full_done_spec.keys())

        _simple_done = "done" in key_set and "terminated" in key_set
        return _simple_done

    def any_done(self, tensordict: TensorDictBase) -> bool:
        """Checks if the tensordict is in a "done" state (or if an element of the batch is).

        Writes the result under the `"_reset"` entry.

        Returns: a bool indicating whether there is an element in the tensordict that is marked
            as done.

        .. note:: The tensordict passed should be a `"next"` tensordict or equivalent -- i.e., it should not
            contain a `"next"` value.

        """
        if self._simple_done:
            done = tensordict._get_str("done", default=None)
            if done is not None:
                any_done = done.any()
            else:
                any_done = False
            if any_done:
                tensordict._set_str(
                    "_reset",
                    done.clone(),
                    validated=True,
                    inplace=False,
                    non_blocking=False,
                )
        else:
            any_done = _terminated_or_truncated(
                tensordict,
                full_done_spec=self.output_spec["full_done_spec"],
                key="_reset",
            )
        return any_done

    def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Checks the done keys of the input tensordict and, if needed, resets the environment where it is done.

        Args:
            tensordict (TensorDictBase): a tensordict coming from the output of :func:`~torchrl.envs.utils.step_mdp`.

        Returns:
            A tensordict that is identical to the input where the environment was
            not reset and contains the new reset data where the environment was reset.

        """
        any_done = self.any_done(tensordict)
        if any_done:
            tensordict = self.reset(tensordict, select_reset_only=True)
        return tensordict
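
    # NOTE (added example): `any_done` + `maybe_reset` together reproduce the
    # tail of `step_and_maybe_reset`. A minimal sketch of a manual loop (the
    # env name is an assumption for illustration):
    #
    #     >>> from torchrl.envs import GymEnv
    #     >>> env = GymEnv("CartPole-v1")
    #     >>> td = env.reset()
    #     >>> for _ in range(100):
    #     ...     env.rand_action(td)
    #     ...     td = env.step(td)
    #     ...     td = env.step_mdp(td)     # root tensordict of the next step
    #     ...     td = env.maybe_reset(td)  # re-seeds only the done entries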

    def empty_cache(self):
        """Erases all the cached values.

        For regular envs, the key lists (reward, done etc) are cached, but in some cases
        they may change during the execution of the code (eg, when adding a transform).

        """
        self._cache.clear()

    @property
    @_cache_value
    def reset_keys(self) -> list[NestedKey]:
        """Returns a list of reset keys.

        Reset keys are keys that indicate partial reset, in batched, multitask or multiagent
        settings. They are structured as ``(*prefix, "_reset")`` where ``prefix`` is
        a (possibly empty) tuple of strings pointing to a tensordict location
        where a done state can be found.

        Keys are sorted by depth in the data tree.
        """
        reset_keys = sorted(
            (
                _replace_last(done_key, "_reset")
                for (done_key, *_) in self.done_keys_groups
            ),
            key=_repr_by_depth,
        )
        return reset_keys
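
    # NOTE (added example): for a single-agent env this list is typically
    # `["_reset"]`; with a nested done spec such as ("agents", "done") one
    # would get `[("agents", "_reset")]`. Illustrative values, assuming a
    # multi-agent layout where done states live under an "agents" sub-tensordict.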

    @property
    def _filtered_reset_keys(self):
        """Returns only the effective reset keys, discarding nested resets if they're not being used."""
        reset_keys = self.reset_keys
        result = []

        def _root(key):
            if isinstance(key, str):
                return ()
            return key[:-1]

        roots = []
        for reset_key in reset_keys:
            cur_root = _root(reset_key)
            for root in roots:
                if cur_root[: len(root)] == root:
                    break
            else:
                roots.append(cur_root)
                result.append(reset_key)
        return result

    @property
    @_cache_value
    def done_keys_groups(self):
        """A list of done keys, grouped as the reset keys.

        This is a list of lists. The outer list has the length of reset keys, the
        inner lists contain the done keys (eg, done and truncated) that can
        be read to determine a reset when it is absent.
        """
        # done keys, sorted as reset keys
        done_keys_group = []
        roots = set()
        fds = self.full_done_spec
        for done_key in self.done_keys:
            root_name = done_key[:-1] if isinstance(done_key, tuple) else ()
            root = fds[root_name] if root_name else fds
            n = len(roots)
            roots.add(root_name)
            if len(roots) - n:
                done_keys_group.append(
                    [
                        unravel_key(root_name + (key,))
                        for key in root.keys(include_nested=False, leaves_only=True)
                    ]
                )
        return done_keys_group
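
    # NOTE (added example): for a standard env with "done", "terminated" and
    # "truncated" at the root, this property would return something like
    # `[["done", "terminated", "truncated"]]` -- one inner list per reset key.
    # Illustrative output, assuming the default single-root done spec.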

    def _select_observation_keys(self, tensordict: TensorDictBase) -> Iterator[str]:
        for key in tensordict.keys():
            if key.rfind("observation") >= 0:
                yield key

    def close(self, *, raise_if_closed: bool = True):
        self.is_closed = True

    def __del__(self):
        # if del occurs before env has been set up, we don't want a recursion
        # error
        if "is_closed" in self.__dict__ and not self.is_closed:
            try:
                self.close()
            except Exception:
                # a TypeError will typically be raised if the env is deleted when the program ends.
                # In the future, insignificant changes to the close method may change the error type.
                # We explicitly assume that any error raised during closure in
                # __del__ will not affect the program.
                pass

    @_maybe_unlock
    def to(self, device: DEVICE_TYPING) -> EnvBase:
        device = _make_ordinal_device(torch.device(device))
        if device == self.device:
            return self
        self.__dict__["_input_spec"] = self.input_spec.to(device)
        self.__dict__["_output_spec"] = self.output_spec.to(device)
        self._device = device
        return super().to(device)

    def fake_tensordict(self) -> TensorDictBase:
        """Returns a fake tensordict with key-value pairs that match in shape, device and dtype what can be expected during an environment rollout."""
        state_spec = self.state_spec
        observation_spec = self.observation_spec
        action_spec = self.input_spec["full_action_spec"]
        # instantiates reward_spec if needed
        _ = self.full_reward_spec
        reward_spec = self.output_spec["full_reward_spec"]
        full_done_spec = self.output_spec["full_done_spec"]

        fake_obs = observation_spec.zero()
        fake_reward = reward_spec.zero()
        fake_done = full_done_spec.zero()
        fake_state = state_spec.zero()
        fake_action = action_spec.zero()

        if any(
            isinstance(val, LazyStackedTensorDict) for val in fake_action.values(True)
        ):
            fake_input = fake_action.update(fake_state)
        else:
            fake_input = fake_state.update(fake_action)

        # the input and output key may match, but the output prevails
        # Hence we generate the input, and override using the output
        fake_in_out = fake_input.update(fake_obs)

        next_output = fake_obs.clone()
        next_output.update(fake_reward)
        next_output.update(fake_done)
        fake_in_out.update(fake_done.clone())
        if "next" not in fake_in_out.keys():
            fake_in_out.set("next", next_output)
        else:
            fake_in_out.get("next").update(next_output)

        fake_in_out.batch_size = self.batch_size
        fake_in_out = fake_in_out.to(self.device)
        return fake_in_out
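
    # NOTE (added example): a common use of `fake_tensordict` is spec
    # validation, e.g. via `torchrl.envs.utils.check_env_specs`. A minimal
    # sketch (the env name is an assumption for illustration):
    #
    #     >>> from torchrl.envs import GymEnv
    #     >>> from torchrl.envs.utils import check_env_specs
    #     >>> env = GymEnv("CartPole-v1")
    #     >>> fake = env.fake_tensordict()  # zeroed, spec-shaped data
    #     >>> check_env_specs(env)  # raises if specs and real rollout data disagree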


class _EnvWrapper(EnvBase):
    """Abstract environment wrapper class.

    Unlike EnvBase, _EnvWrapper comes with a :obj:`_build_env` private method that will be called upon instantiation.
    Interfaces with other libraries should be coded using _EnvWrapper.

    It is possible to directly query attributes from the nested environment if its name does not conflict with
    an attribute of the wrapper:
        >>> env = SomeWrapper(...)
        >>> custom_attribute0 = env._env.custom_attribute
        >>> custom_attribute1 = env.custom_attribute
        >>> assert custom_attribute0 is custom_attribute1  # should return True

    """

    git_url: str = ""
    available_envs: dict[str, Any] = {}
    libname: str = ""

    def __init__(
        self,
        *args,
        device: DEVICE_TYPING = None,
        batch_size: torch.Size | None = None,
        allow_done_after_reset: bool = False,
        spec_locked: bool = True,
        **kwargs,
    ):
        super().__init__(
            device=device,
            batch_size=batch_size,
            allow_done_after_reset=allow_done_after_reset,
            spec_locked=spec_locked,
        )
        if len(args):
            raise ValueError(
                "`_EnvWrapper.__init__` received a non-empty args list of arguments. "
                "Make sure only keywords arguments are used when calling `super().__init__`."
            )

        frame_skip = kwargs.pop("frame_skip", 1)
        if not isinstance(frame_skip, int):
            raise ValueError(f"frame_skip must be an integer, got {frame_skip}")
        self.frame_skip = frame_skip
        # this value can be changed if frame_skip is passed during env construction
        self.wrapper_frame_skip = frame_skip

        self._constructor_kwargs = kwargs
        self._check_kwargs(kwargs)
        self._convert_actions_to_numpy = kwargs.pop("convert_actions_to_numpy", True)
        self._env = self._build_env(**kwargs)  # writes the self._env attribute
        self._make_specs(self._env)  # defines the spec attributes from the env
        self.is_closed = False
        self._init_env()  # runs all the steps to have a ready-to-use env

    def _sync_device(self):
        sync_func = self.__dict__.get("_sync_device_val")
        if sync_func is None:
            device = self.device
            if device.type != "cuda":
                if torch.cuda.is_available():
                    self._sync_device_val = torch.cuda.synchronize
                elif torch.backends.mps.is_available():
                    self._sync_device_val = torch.mps.synchronize
                elif device.type == "cpu":
                    self._sync_device_val = _do_nothing
            else:
                self._sync_device_val = _do_nothing
            return self._sync_device
        return sync_func

    @abc.abstractmethod
    def _check_kwargs(self, kwargs: dict):
        raise NotImplementedError

    def __getattr__(self, attr: str) -> Any:
        if attr in self.__dir__():
            return self.__getattribute__(
                attr
            )  # make sure that appropriate exceptions are raised

        elif attr.startswith("__"):
            raise AttributeError(
                "passing built-in private methods is "
                f"not permitted with type {type(self)}. "
                f"Got attribute {attr}."
            )

        elif "_env" in self.__dir__():
            env = self.__getattribute__("_env")
            return getattr(env, attr)
        super().__getattr__(attr)

        raise AttributeError(
            f"The env wasn't set in {self.__class__.__name__}, cannot access {attr}"
        )

    @abc.abstractmethod
    def _init_env(self) -> int | None:
        """Runs all the necessary steps such that the environment is ready to use.

        This step is intended to ensure that a seed is provided to the environment (if needed) and that the environment
        is reset (if needed). For instance, DMControl envs require the env to be reset before being used, but Gym envs
        don't.

        Returns:
            the resulting seed

        """
        raise NotImplementedError

    @abc.abstractmethod
    def _build_env(self, **kwargs) -> gym.Env:  # noqa: F821
        """Creates an environment from the target library and stores it with the `_env` attribute.

        When overwritten, this function should pass all the required kwargs to the env instantiation method.

        """
        raise NotImplementedError

    @abc.abstractmethod
    def _make_specs(self, env: gym.Env) -> None:  # noqa: F821
        raise NotImplementedError

    def close(self, *, raise_if_closed: bool = True) -> None:
        """Closes the contained environment if possible."""
        self.is_closed = True
        try:
            self._env.close()
        except AttributeError:
            pass


def make_tensordict(
    env: _EnvWrapper,
    policy: Callable[[TensorDictBase, ...], TensorDictBase] | None = None,
) -> TensorDictBase:
    """Returns a zeroed-tensordict with fields matching those required for a full step (action selection and environment step) in the environment.

    Args:
        env (_EnvWrapper): environment defining the observation, action and reward space;
        policy (Callable, optional): policy corresponding to the environment.

    """
    with torch.no_grad():
        tensordict = env.reset()
        if policy is not None:
            tensordict.update(policy(tensordict))
        else:
            tensordict.set("action", env.action_spec.rand(), inplace=False)
        tensordict = env.step(tensordict)
        return tensordict.zero_()
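

# NOTE (added example): a minimal sketch of `make_tensordict` usage; the env
# choice is an assumption for illustration. The result resembles
# `env.fake_tensordict()` but is obtained by actually stepping the env once:
#
#     >>> from torchrl.envs import GymEnv
#     >>> env = GymEnv("CartPole-v1")
#     >>> td = make_tensordict(env)  # zeroed step template
#     >>> sorted(td.keys())  # doctest: +SKIP
#     ['action', 'done', 'next', 'observation', 'terminated', 'truncated']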


def _get_sync_func(policy_device, env_device):
    if torch.cuda.is_available():
        # Look for a specific device
        if policy_device is not None and policy_device.type == "cuda":
            if env_device is None or env_device.type == "cuda":
                return torch.cuda.synchronize
            return partial(torch.cuda.synchronize, device=policy_device)
        if env_device is not None and env_device.type == "cuda":
            if policy_device is None:
                return torch.cuda.synchronize
            return partial(torch.cuda.synchronize, device=env_device)
        return torch.cuda.synchronize
    if torch.backends.mps.is_available():
        return torch.mps.synchronize
    return _do_nothing
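

# NOTE (added example): the sync function guards non-blocking transfers. A
# minimal sketch of the hazard it prevents (device names are illustrative and
# assume a CUDA machine):
#
#     import torch
#     x = torch.ones(3, device="cuda")
#     y = x.to("cpu", non_blocking=True)  # copy may still be in flight
#     torch.cuda.synchronize()            # ensure the copy has landed
#     assert (y == 1).all()               # now safe to read y on the host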


def _do_nothing():
    return


def _has_dynamic_specs(spec: Composite):
    from tensordict.base import _NESTED_TENSORS_AS_LISTS

    return any(
        any(s == -1 for s in spec.shape)
        for spec in spec.values(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS)
    )


def _tensor_to_spec(name, leaf, leaf_compare=None, *, stack):
    if not (isinstance(leaf, torch.Tensor) or is_tensor_collection(leaf)):
        stack[name] = NonTensor(shape=())
        return
    elif is_non_tensor(leaf):
        stack[name] = NonTensor(shape=leaf.shape)
        return
    shape = leaf.shape
    if leaf_compare is not None:
        shape_compare = leaf_compare.shape
        shape = [s0 if s0 == s1 else -1 for s0, s1 in zip(shape, shape_compare)]
    stack[name] = Unbounded(shape, device=leaf.device, dtype=leaf.dtype)
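

# NOTE (added example): `_tensor_to_spec` marks dimensions that differ across
# samples as dynamic (-1). A plain-tensor sketch of that comparison (names are
# illustrative):
#
#     import torch
#     a, b = torch.zeros(4, 3), torch.zeros(4, 5)
#     shape = [s0 if s0 == s1 else -1 for s0, s1 in zip(a.shape, b.shape)]
#     # -> [4, -1]: the second dim varies and is treated as dynamic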