torchrl-0.11.0-cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/batched_envs.py
@@ -0,0 +1,3093 @@
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import functools
|
|
9
|
+
import gc
|
|
10
|
+
import os
|
|
11
|
+
import time
|
|
12
|
+
import warnings
|
|
13
|
+
import weakref
|
|
14
|
+
from collections import OrderedDict
|
|
15
|
+
from collections.abc import Callable, Mapping, Sequence
|
|
16
|
+
from copy import deepcopy
|
|
17
|
+
from functools import wraps
|
|
18
|
+
from multiprocessing import connection
|
|
19
|
+
from multiprocessing.connection import wait as connection_wait
|
|
20
|
+
from multiprocessing.synchronize import Lock as MpLock
|
|
21
|
+
from typing import Any
|
|
22
|
+
from warnings import warn
|
|
23
|
+
|
|
24
|
+
import torch
|
|
25
|
+
from tensordict import (
|
|
26
|
+
is_tensor_collection,
|
|
27
|
+
LazyStackedTensorDict,
|
|
28
|
+
TensorDict,
|
|
29
|
+
TensorDictBase,
|
|
30
|
+
unravel_key,
|
|
31
|
+
)
|
|
32
|
+
from tensordict.base import _is_leaf_nontensor
|
|
33
|
+
from tensordict.utils import _zip_strict
|
|
34
|
+
from torch import multiprocessing as mp
|
|
35
|
+
|
|
36
|
+
from torchrl._utils import (
|
|
37
|
+
_check_for_faulty_process,
|
|
38
|
+
_get_default_mp_start_method,
|
|
39
|
+
_make_ordinal_device,
|
|
40
|
+
logger as torchrl_logger,
|
|
41
|
+
rl_warnings,
|
|
42
|
+
timeit,
|
|
43
|
+
VERBOSE,
|
|
44
|
+
)
|
|
45
|
+
from torchrl.data.tensor_specs import Composite, NonTensor
|
|
46
|
+
from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING
|
|
47
|
+
from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, EnvMetaData
|
|
48
|
+
|
|
49
|
+
from torchrl.envs.env_creator import EnvCreator, get_env_metadata
|
|
50
|
+
|
|
51
|
+
from torchrl.envs.utils import (
|
|
52
|
+
_aggregate_end_of_traj,
|
|
53
|
+
_sort_keys,
|
|
54
|
+
_update_during_reset,
|
|
55
|
+
clear_mpi_env_vars,
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
_CONSOLIDATE_ERR_CAPTURE = (
|
|
59
|
+
"TensorDict.consolidate failed. You can deactivate the tensordict consolidation via the "
|
|
60
|
+
"`consolidate` keyword argument of the ParallelEnv constructor."
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _check_start(fun):
|
|
65
|
+
def decorated_fun(self: BatchedEnvBase, *args, **kwargs):
|
|
66
|
+
if self.is_closed:
|
|
67
|
+
self._create_td()
|
|
68
|
+
self._start_workers()
|
|
69
|
+
else:
|
|
70
|
+
if isinstance(self, ParallelEnv):
|
|
71
|
+
_check_for_faulty_process(self._workers)
|
|
72
|
+
return fun(self, *args, **kwargs)
|
|
73
|
+
|
|
74
|
+
return decorated_fun
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class _dispatch_caller_parallel:
|
|
78
|
+
def __init__(self, attr, parallel_env):
|
|
79
|
+
self.attr = attr
|
|
80
|
+
self.parallel_env = parallel_env
|
|
81
|
+
|
|
82
|
+
def __call__(self, *args, **kwargs):
|
|
83
|
+
# remove self from args
|
|
84
|
+
args = [_arg if _arg is not self.parallel_env else "_self" for _arg in args]
|
|
85
|
+
for channel in self.parallel_env.parent_channels:
|
|
86
|
+
channel.send((self.attr, (args, kwargs)))
|
|
87
|
+
|
|
88
|
+
results = []
|
|
89
|
+
for channel in self.parallel_env.parent_channels:
|
|
90
|
+
msg, result = channel.recv()
|
|
91
|
+
results.append(result)
|
|
92
|
+
|
|
93
|
+
return results
|
|
94
|
+
|
|
95
|
+
def __iter__(self):
|
|
96
|
+
# if the object returned is not a callable
|
|
97
|
+
return iter(self.__call__())
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class _dispatch_caller_serial:
|
|
101
|
+
def __init__(self, list_callable: list[Callable, Any]):
|
|
102
|
+
self.list_callable = list_callable
|
|
103
|
+
|
|
104
|
+
def __call__(self, *args, **kwargs):
|
|
105
|
+
return [_callable(*args, **kwargs) for _callable in self.list_callable]
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def lazy_property(prop: property):
|
|
109
|
+
"""Converts a property in a lazy property, that will call _set_properties when queried the first time."""
|
|
110
|
+
return property(fget=lazy(prop.fget), fset=prop.fset)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def lazy(fun):
|
|
114
|
+
"""Converts a fun in a lazy fun, that will call _set_properties when queried the first time."""
|
|
115
|
+
|
|
116
|
+
@wraps(fun)
|
|
117
|
+
def new_fun(self, *args, **kwargs):
|
|
118
|
+
if not self._properties_set:
|
|
119
|
+
self._set_properties()
|
|
120
|
+
return fun(self, *args, **kwargs)
|
|
121
|
+
|
|
122
|
+
return new_fun
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _is_unpicklable_lambda(fn: Callable) -> bool:
|
|
126
|
+
"""Check if a callable is a lambda function that needs cloudpickle wrapping.
|
|
127
|
+
|
|
128
|
+
Lambda functions cannot be pickled with standard pickle, so they need to be
|
|
129
|
+
wrapped with EnvCreator (which uses CloudpickleWrapper) for multiprocessing.
|
|
130
|
+
functools.partial objects are picklable, so they don't need wrapping.
|
|
131
|
+
"""
|
|
132
|
+
if isinstance(fn, functools.partial):
|
|
133
|
+
return False
|
|
134
|
+
return callable(fn) and getattr(fn, "__name__", None) == "<lambda>"
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class _PEnvMeta(_EnvPostInit):
|
|
138
|
+
def __call__(cls, *args, **kwargs):
|
|
139
|
+
serial_for_single = kwargs.pop("serial_for_single", False)
|
|
140
|
+
if serial_for_single:
|
|
141
|
+
num_workers = kwargs.get("num_workers")
|
|
142
|
+
# Remove start method from kwargs
|
|
143
|
+
kwargs.pop("mp_start_method", None)
|
|
144
|
+
if num_workers is None:
|
|
145
|
+
num_workers = args[0]
|
|
146
|
+
if num_workers == 1:
|
|
147
|
+
# We still use a serial to keep the shape unchanged
|
|
148
|
+
return SerialEnv(*args, **kwargs)
|
|
149
|
+
|
|
150
|
+
# Wrap lambda functions with EnvCreator so they can be pickled for
|
|
151
|
+
# multiprocessing with the spawn start method. Lambda functions cannot
|
|
152
|
+
# be serialized with standard pickle, but EnvCreator uses cloudpickle.
|
|
153
|
+
auto_wrap_envs = kwargs.pop("auto_wrap_envs", True)
|
|
154
|
+
|
|
155
|
+
def _warn_lambda():
|
|
156
|
+
if rl_warnings():
|
|
157
|
+
warnings.warn(
|
|
158
|
+
"A lambda function was passed to ParallelEnv and will be wrapped "
|
|
159
|
+
"in an EnvCreator. This causes the environment to be instantiated "
|
|
160
|
+
"in the main process to extract metadata. Consider using "
|
|
161
|
+
"functools.partial instead, which is natively serializable and "
|
|
162
|
+
"avoids this overhead. To suppress this warning, set the "
|
|
163
|
+
"RL_WARNINGS=0 environment variable.",
|
|
164
|
+
category=UserWarning,
|
|
165
|
+
stacklevel=4,
|
|
166
|
+
)
|
|
167
|
+
|
|
168
|
+
def _wrap_lambdas(create_env_fn):
|
|
169
|
+
if callable(create_env_fn) and _is_unpicklable_lambda(create_env_fn):
|
|
170
|
+
_warn_lambda()
|
|
171
|
+
return EnvCreator(create_env_fn)
|
|
172
|
+
if isinstance(create_env_fn, Sequence):
|
|
173
|
+
# Reuse EnvCreator for identical function objects to preserve
|
|
174
|
+
# _single_task detection (e.g., when [lambda_fn] * 3 is passed)
|
|
175
|
+
wrapped = {}
|
|
176
|
+
result = []
|
|
177
|
+
warned = False
|
|
178
|
+
for fn in create_env_fn:
|
|
179
|
+
if _is_unpicklable_lambda(fn):
|
|
180
|
+
fn_id = id(fn)
|
|
181
|
+
if fn_id not in wrapped:
|
|
182
|
+
if not warned:
|
|
183
|
+
_warn_lambda()
|
|
184
|
+
warned = True
|
|
185
|
+
wrapped[fn_id] = EnvCreator(fn)
|
|
186
|
+
result.append(wrapped[fn_id])
|
|
187
|
+
else:
|
|
188
|
+
result.append(fn)
|
|
189
|
+
return result
|
|
190
|
+
return create_env_fn
|
|
191
|
+
|
|
192
|
+
if auto_wrap_envs:
|
|
193
|
+
if "create_env_fn" in kwargs:
|
|
194
|
+
kwargs["create_env_fn"] = _wrap_lambdas(kwargs["create_env_fn"])
|
|
195
|
+
elif len(args) >= 2:
|
|
196
|
+
args = (args[0], _wrap_lambdas(args[1])) + args[2:]
|
|
197
|
+
|
|
198
|
+
return super().__call__(*args, **kwargs)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
class BatchedEnvBase(EnvBase):
|
|
202
|
+
"""Batched environments allow the user to query an arbitrary method / attribute of the environment running remotely.
|
|
203
|
+
|
|
204
|
+
Those queries will return a list of length equal to the number of workers containing the
|
|
205
|
+
values resulting from those queries.
|
|
206
|
+
|
|
207
|
+
Example:
|
|
208
|
+
>>> env = ParallelEnv(3, my_env_fun)
|
|
209
|
+
>>> custom_attribute_list = env.custom_attribute
|
|
210
|
+
>>> custom_method_list = env.custom_method(*args)
|
|
211
|
+
|
|
212
|
+
Args:
|
|
213
|
+
num_workers: number of workers (i.e. env instances) to be deployed simultaneously;
|
|
214
|
+
create_env_fn (callable or list of callables): function (or list of functions) to be used for the environment
|
|
215
|
+
creation.
|
|
216
|
+
If a single task is used, a callable should be used and not a list of identical callables:
|
|
217
|
+
if a list of callable is provided, the environment will be executed as if multiple, diverse tasks were
|
|
218
|
+
needed, which comes with a slight compute overhead;
|
|
219
|
+
|
|
220
|
+
Keyword Args:
|
|
221
|
+
create_env_kwargs (dict or list of dicts, optional): kwargs to be used with the environments being created;
|
|
222
|
+
share_individual_td (bool, optional): if ``True``, a different tensordict is created for every process/worker and a lazy
|
|
223
|
+
stack is returned.
|
|
224
|
+
default = None (False if single task);
|
|
225
|
+
shared_memory (bool): whether the returned tensordict will be placed in shared memory;
|
|
226
|
+
memmap (bool): whether the returned tensordict will be placed in memory map.
|
|
227
|
+
policy_proof (callable, optional): if provided, it'll be used to get the list of
|
|
228
|
+
tensors to return through the :obj:`step()` and :obj:`reset()` methods, such as :obj:`"hidden"` etc.
|
|
229
|
+
device (str, int, torch.device): The device of the batched environment can be passed.
|
|
230
|
+
If not, it is inferred from the env. In this case, it is assumed that
|
|
231
|
+
the device of all environments match. If it is provided, it can differ
|
|
232
|
+
from the sub-environment device(s). In that case, the data will be
|
|
233
|
+
automatically cast to the appropriate device during collection.
|
|
234
|
+
This can be used to speed up collection in case casting to device
|
|
235
|
+
introduces an overhead (eg, numpy-based environents etc.): by using
|
|
236
|
+
a ``"cuda"`` device for the batched environment but a ``"cpu"``
|
|
237
|
+
device for the nested environments, one can keep the overhead to a
|
|
238
|
+
minimum.
|
|
239
|
+
num_threads (int, optional): number of threads for this process.
|
|
240
|
+
Should be equal to one plus the number of processes launched within
|
|
241
|
+
each subprocess (or one if a single process is launched).
|
|
242
|
+
Defaults to the number of workers + 1.
|
|
243
|
+
This parameter has no effect for the :class:`~SerialEnv` class.
|
|
244
|
+
num_sub_threads (int, optional): number of threads of the subprocesses.
|
|
245
|
+
Defaults to 1 for safety: if none is indicated, launching multiple
|
|
246
|
+
workers may charge the cpu load too much and harm performance.
|
|
247
|
+
This parameter has no effect for the :class:`~SerialEnv` class.
|
|
248
|
+
serial_for_single (bool, optional): if ``True``, creating a parallel environment
|
|
249
|
+
with a single worker will return a :class:`~SerialEnv` instead.
|
|
250
|
+
This option has no effect with :class:`~SerialEnv`. Defaults to ``False``.
|
|
251
|
+
non_blocking (bool, optional): if ``True``, device moves will be done using the
|
|
252
|
+
``non_blocking=True`` option. Defaults to ``True``.
|
|
253
|
+
mp_start_method (str, optional): the multiprocessing start method.
|
|
254
|
+
Uses the default start method if not indicated ('spawn' by default in
|
|
255
|
+
TorchRL if not initiated differently before first import).
|
|
256
|
+
To be used only with :class:`~torchrl.envs.ParallelEnv` subclasses.
|
|
257
|
+
use_buffers (bool, optional): whether communication between workers should
|
|
258
|
+
occur via circular preallocated memory buffers. Defaults to ``True`` unless
|
|
259
|
+
one of the environment has dynamic specs.
|
|
260
|
+
|
|
261
|
+
.. note:: Learn more about dynamic specs and environments :ref:`here <dynamic_envs>`.
|
|
262
|
+
daemon (bool, optional): whether the processes should be daemonized.
|
|
263
|
+
This is only applicable to parallel environments such as :class:`~torchrl.envs.ParallelEnv`.
|
|
264
|
+
Defaults to ``False``.
|
|
265
|
+
auto_wrap_envs (bool, optional): if ``True`` (default), lambda functions passed as
|
|
266
|
+
``create_env_fn`` will be automatically wrapped in an :class:`~torchrl.envs.EnvCreator`
|
|
267
|
+
to enable pickling for multiprocessing with the ``spawn`` start method.
|
|
268
|
+
This wrapping causes the environment to be instantiated once in the main process
|
|
269
|
+
(to extract metadata) before workers are started.
|
|
270
|
+
If this is undesirable, set ``auto_wrap_envs=False``. Otherwise, ensure your callable is
|
|
271
|
+
serializable (e.g., use :func:`functools.partial` instead of lambdas).
|
|
272
|
+
This parameter only affects :class:`~torchrl.envs.ParallelEnv`.
|
|
273
|
+
Defaults to ``True``.
|
|
274
|
+
|
|
275
|
+
.. note::
|
|
276
|
+
For :class:`~torchrl.envs.ParallelEnv`, it is recommended to use :func:`functools.partial`
|
|
277
|
+
instead of lambda functions when possible, as ``partial`` objects are natively serializable
|
|
278
|
+
and avoid the overhead of :class:`~torchrl.envs.EnvCreator` wrapping.
|
|
279
|
+
|
|
280
|
+
.. note::
|
|
281
|
+
One can pass keyword arguments to each sub-environments using the following
|
|
282
|
+
technique: every keyword argument in :meth:`reset` will be passed to each
|
|
283
|
+
environment except for the ``list_of_kwargs`` argument which, if present,
|
|
284
|
+
should contain a list of the same length as the number of workers with the
|
|
285
|
+
worker-specific keyword arguments stored in a dictionary.
|
|
286
|
+
If a partial reset is queried, the element of ``list_of_kwargs`` corresponding
|
|
287
|
+
to sub-environments that are not reset will be ignored.
|
|
288
|
+
|
|
289
|
+
Examples:
|
|
290
|
+
>>> from torchrl.envs import GymEnv, ParallelEnv, SerialEnv, EnvCreator
|
|
291
|
+
>>> make_env = EnvCreator(lambda: GymEnv("Pendulum-v1")) # EnvCreator ensures that the env is sharable. Optional in most cases.
|
|
292
|
+
>>> env = SerialEnv(2, make_env) # Makes 2 identical copies of the Pendulum env, runs them on the same process serially
|
|
293
|
+
>>> env = ParallelEnv(2, make_env) # Makes 2 identical copies of the Pendulum env, runs them on dedicated processes
|
|
294
|
+
>>> from torchrl.envs import DMControlEnv
|
|
295
|
+
>>> env = ParallelEnv(2, [
|
|
296
|
+
... lambda: DMControlEnv("humanoid", "stand"),
|
|
297
|
+
... lambda: DMControlEnv("humanoid", "walk")]) # Creates two independent copies of Humanoid, one that walks one that stands
|
|
298
|
+
>>> rollout = env.rollout(10) # executes 10 random steps in the environment
|
|
299
|
+
>>> rollout[0] # data for Humanoid stand
|
|
300
|
+
TensorDict(
|
|
301
|
+
fields={
|
|
302
|
+
action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
303
|
+
com_velocity: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
304
|
+
done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
305
|
+
extremities: Tensor(shape=torch.Size([10, 12]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
306
|
+
head_height: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
307
|
+
joint_angles: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
308
|
+
next: TensorDict(
|
|
309
|
+
fields={
|
|
310
|
+
com_velocity: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
311
|
+
done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
312
|
+
extremities: Tensor(shape=torch.Size([10, 12]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
313
|
+
head_height: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
314
|
+
joint_angles: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
315
|
+
reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
316
|
+
terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
317
|
+
torso_vertical: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
318
|
+
truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
319
|
+
velocity: Tensor(shape=torch.Size([10, 27]), device=cpu, dtype=torch.float64, is_shared=False)},
|
|
320
|
+
batch_size=torch.Size([10]),
|
|
321
|
+
device=cpu,
|
|
322
|
+
is_shared=False),
|
|
323
|
+
terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
324
|
+
torso_vertical: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
325
|
+
truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
326
|
+
velocity: Tensor(shape=torch.Size([10, 27]), device=cpu, dtype=torch.float64, is_shared=False)},
|
|
327
|
+
batch_size=torch.Size([10]),
|
|
328
|
+
device=cpu,
|
|
329
|
+
is_shared=False)
|
|
330
|
+
>>> rollout[1] # data for Humanoid walk
|
|
331
|
+
TensorDict(
|
|
332
|
+
fields={
|
|
333
|
+
action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
334
|
+
com_velocity: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
335
|
+
done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
336
|
+
extremities: Tensor(shape=torch.Size([10, 12]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
337
|
+
head_height: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
338
|
+
joint_angles: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
339
|
+
next: TensorDict(
|
|
340
|
+
fields={
|
|
341
|
+
com_velocity: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
342
|
+
done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
343
|
+
extremities: Tensor(shape=torch.Size([10, 12]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
344
|
+
head_height: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
345
|
+
joint_angles: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
346
|
+
reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
347
|
+
terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
348
|
+
torso_vertical: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
349
|
+
truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
350
|
+
velocity: Tensor(shape=torch.Size([10, 27]), device=cpu, dtype=torch.float64, is_shared=False)},
|
|
351
|
+
batch_size=torch.Size([10]),
|
|
352
|
+
device=cpu,
|
|
353
|
+
is_shared=False),
|
|
354
|
+
terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
355
|
+
torso_vertical: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
356
|
+
truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
357
|
+
velocity: Tensor(shape=torch.Size([10, 27]), device=cpu, dtype=torch.float64, is_shared=False)},
|
|
358
|
+
batch_size=torch.Size([10]),
|
|
359
|
+
device=cpu,
|
|
360
|
+
is_shared=False)
|
|
361
|
+
>>> # serial_for_single to avoid creating parallel envs if not necessary
|
|
362
|
+
>>> env = ParallelEnv(1, make_env, serial_for_single=True)
|
|
363
|
+
>>> assert isinstance(env, SerialEnv) # serial_for_single allows you to avoid creating parallel envs when not necessary
|
|
364
|
+
"""
|
|
365
|
+
|
|
366
|
+
_verbose: bool = VERBOSE
|
|
367
|
+
_excluded_wrapped_keys = [
|
|
368
|
+
"is_closed",
|
|
369
|
+
"parent_channels",
|
|
370
|
+
"batch_size",
|
|
371
|
+
"_dummy_env_str",
|
|
372
|
+
]
|
|
373
|
+
|
|
374
|
+
def __init__(
|
|
375
|
+
self,
|
|
376
|
+
num_workers: int,
|
|
377
|
+
create_env_fn: Callable[[], EnvBase] | Sequence[Callable[[], EnvBase]],
|
|
378
|
+
*,
|
|
379
|
+
create_env_kwargs: dict | Sequence[dict] = None,
|
|
380
|
+
pin_memory: bool = False,
|
|
381
|
+
share_individual_td: bool | None = None,
|
|
382
|
+
shared_memory: bool = True,
|
|
383
|
+
memmap: bool = False,
|
|
384
|
+
policy_proof: Callable | None = None,
|
|
385
|
+
device: DEVICE_TYPING | None = None,
|
|
386
|
+
allow_step_when_done: bool = False,
|
|
387
|
+
num_threads: int | None = None,
|
|
388
|
+
num_sub_threads: int = 1,
|
|
389
|
+
serial_for_single: bool = False,
|
|
390
|
+
non_blocking: bool = False,
|
|
391
|
+
mp_start_method: str | None = None,
|
|
392
|
+
use_buffers: bool | None = None,
|
|
393
|
+
consolidate: bool = True,
|
|
394
|
+
daemon: bool = False,
|
|
395
|
+
):
|
|
396
|
+
super().__init__(device=device)
|
|
397
|
+
self.serial_for_single = serial_for_single
|
|
398
|
+
self.is_closed = True
|
|
399
|
+
self.num_sub_threads = num_sub_threads
|
|
400
|
+
self.num_threads = num_threads
|
|
401
|
+
self._cache_in_keys = None
|
|
402
|
+
self._use_buffers = use_buffers
|
|
403
|
+
self.consolidate = consolidate
|
|
404
|
+
self.daemon = daemon
|
|
405
|
+
|
|
406
|
+
self._single_task = callable(create_env_fn) or (len(set(create_env_fn)) == 1)
|
|
407
|
+
if callable(create_env_fn):
|
|
408
|
+
create_env_fn = [create_env_fn for _ in range(num_workers)]
|
|
409
|
+
elif len(create_env_fn) != num_workers:
|
|
410
|
+
raise RuntimeError(
|
|
411
|
+
f"len(create_env_fn) and num_workers mismatch, "
|
|
412
|
+
f"got {len(create_env_fn)} and {num_workers}."
|
|
413
|
+
)
|
|
414
|
+
|
|
415
|
+
create_env_kwargs = {} if create_env_kwargs is None else create_env_kwargs
|
|
416
|
+
if isinstance(create_env_kwargs, Mapping):
|
|
417
|
+
create_env_kwargs = [
|
|
418
|
+
deepcopy(create_env_kwargs) for _ in range(num_workers)
|
|
419
|
+
]
|
|
420
|
+
elif len(create_env_kwargs) != num_workers:
|
|
421
|
+
raise RuntimeError(
|
|
422
|
+
f"len(create_env_kwargs) and num_workers mismatch, "
|
|
423
|
+
f"got {len(create_env_kwargs)} and {num_workers}."
|
|
424
|
+
)
|
|
425
|
+
|
|
426
|
+
self.policy_proof = policy_proof
|
|
427
|
+
self.num_workers = num_workers
|
|
428
|
+
self.create_env_fn = create_env_fn
|
|
429
|
+
self.create_env_kwargs = create_env_kwargs
|
|
430
|
+
self.pin_memory = pin_memory
|
|
431
|
+
if pin_memory:
|
|
432
|
+
raise ValueError("pin_memory for batched envs is deprecated")
|
|
433
|
+
|
|
434
|
+
# if share_individual_td is None, we will assess later if the output can be stacked
|
|
435
|
+
self.share_individual_td = share_individual_td
|
|
436
|
+
# self._batch_locked = batch_locked
|
|
437
|
+
self._share_memory = shared_memory
|
|
438
|
+
self._memmap = memmap
|
|
439
|
+
self.allow_step_when_done = allow_step_when_done
|
|
440
|
+
if allow_step_when_done:
|
|
441
|
+
raise ValueError("allow_step_when_done is deprecated")
|
|
442
|
+
if self._share_memory and self._memmap:
|
|
443
|
+
raise RuntimeError(
|
|
444
|
+
"memmap and shared memory are mutually exclusive features."
|
|
445
|
+
)
|
|
446
|
+
self._batch_size = None
|
|
447
|
+
self._device = (
|
|
448
|
+
_make_ordinal_device(torch.device(device)) if device is not None else device
|
|
449
|
+
)
|
|
450
|
+
self._dummy_env_str = None
|
|
451
|
+
self._seeds = None
|
|
452
|
+
self.__dict__["_input_spec"] = None
|
|
453
|
+
self.__dict__["_output_spec"] = None
|
|
454
|
+
# self._prepare_dummy_env(create_env_fn, create_env_kwargs)
|
|
455
|
+
self._properties_set = False
|
|
456
|
+
self._get_metadata(create_env_fn, create_env_kwargs)
|
|
457
|
+
self._non_blocking = non_blocking
|
|
458
|
+
if mp_start_method is not None and not isinstance(self, ParallelEnv):
|
|
459
|
+
raise TypeError(
|
|
460
|
+
f"Cannot use mp_start_method={mp_start_method} with envs of type {type(self)}."
|
|
461
|
+
)
|
|
462
|
+
self._mp_start_method = mp_start_method
|
|
463
|
+
|
|
464
|
+
is_spec_locked = EnvBase.is_spec_locked
|
|
465
|
+
|
|
466
|
+
def configure_parallel(
|
|
467
|
+
self,
|
|
468
|
+
*,
|
|
469
|
+
use_buffers: bool | None = None,
|
|
470
|
+
shared_memory: bool | None = None,
|
|
471
|
+
memmap: bool | None = None,
|
|
472
|
+
mp_start_method: str | None = None,
|
|
473
|
+
num_threads: int | None = None,
|
|
474
|
+
num_sub_threads: int | None = None,
|
|
475
|
+
non_blocking: bool | None = None,
|
|
476
|
+
daemon: bool | None = None,
|
|
477
|
+
) -> BatchedEnvBase:
|
|
478
|
+
"""Configure parallel execution parameters before the environment starts.
|
|
479
|
+
|
|
480
|
+
This method allows configuring parameters for parallel environment
|
|
481
|
+
execution. It must be called before the environment is started
|
|
482
|
+
(i.e., before accessing specs or calling reset/step).
|
|
483
|
+
|
|
484
|
+
Args:
|
|
485
|
+
use_buffers (bool, optional): whether communication between workers should
|
|
486
|
+
occur via circular preallocated memory buffers.
|
|
487
|
+
shared_memory (bool, optional): whether the returned tensordict will be
|
|
488
|
+
placed in shared memory.
|
|
489
|
+
memmap (bool, optional): whether the returned tensordict will be placed
|
|
490
|
+
in memory map.
|
|
491
|
+
mp_start_method (str, optional): the multiprocessing start method.
|
|
492
|
+
num_threads (int, optional): number of threads for this process.
|
|
493
|
+
num_sub_threads (int, optional): number of threads of the subprocesses.
|
|
494
|
+
non_blocking (bool, optional): if ``True``, device moves will be done using
|
|
495
|
+
the ``non_blocking=True`` option.
|
|
496
|
+
daemon (bool, optional): whether the processes should be daemonized.
|
|
497
|
+
|
|
498
|
+
Returns:
|
|
499
|
+
self: Returns self for method chaining.
|
|
500
|
+
|
|
501
|
+
Raises:
|
|
502
|
+
RuntimeError: If called after the environment has already started.
|
|
503
|
+
|
|
504
|
+
Example:
|
|
505
|
+
>>> env = ParallelEnv(4, lambda: GymEnv("Pendulum-v1"))
|
|
506
|
+
>>> env.configure_parallel(use_buffers=True, num_threads=2)
|
|
507
|
+
>>> env.reset() # Environment starts here
|
|
508
|
+
|
|
509
|
+
"""
|
|
510
|
+
if not self.is_closed:
|
|
511
|
+
raise RuntimeError(
|
|
512
|
+
"configure_parallel() cannot be called after the environment has started. "
|
|
513
|
+
"Call configure_parallel() before accessing specs or calling reset/step."
|
|
514
|
+
)
|
|
515
|
+
if use_buffers is not None:
|
|
516
|
+
self._use_buffers = use_buffers
|
|
517
|
+
if shared_memory is not None:
|
|
518
|
+
self._share_memory = shared_memory
|
|
519
|
+
if memmap is not None:
|
|
520
|
+
self._memmap = memmap
|
|
521
|
+
if mp_start_method is not None:
|
|
522
|
+
self._mp_start_method = mp_start_method
|
|
523
|
+
if num_threads is not None:
|
|
524
|
+
self.num_threads = num_threads
|
|
525
|
+
if num_sub_threads is not None:
|
|
526
|
+
self.num_sub_threads = num_sub_threads
|
|
527
|
+
if non_blocking is not None:
|
|
528
|
+
self._non_blocking = non_blocking
|
|
529
|
+
if daemon is not None:
|
|
530
|
+
self.daemon = daemon
|
|
531
|
+
return self
|
|
532
|
+
|
|
533
|
+
def select_and_clone(self, name, tensor, selected_keys=None):
|
|
534
|
+
if selected_keys is None:
|
|
535
|
+
selected_keys = self._selected_step_keys
|
|
536
|
+
if name in selected_keys:
|
|
537
|
+
if self.device is not None and tensor.device != self.device:
|
|
538
|
+
return tensor.to(self.device, non_blocking=self.non_blocking)
|
|
539
|
+
return tensor.clone()
|
|
540
|
+
|
|
541
|
+
@property
|
|
542
|
+
def non_blocking(self):
|
|
543
|
+
nb = self._non_blocking
|
|
544
|
+
if nb is None:
|
|
545
|
+
nb = True
|
|
546
|
+
self._non_blocking = nb
|
|
547
|
+
return nb
|
|
548
|
+
|
|
549
|
+
@property
|
|
550
|
+
def _sync_m2w(self) -> Callable:
|
|
551
|
+
sync_func = self.__dict__.get("_sync_m2w_value")
|
|
552
|
+
if sync_func is None:
|
|
553
|
+
sync_m2w, sync_w2m = self._find_sync_values()
|
|
554
|
+
self.__dict__["_sync_m2w_value"] = sync_m2w
|
|
555
|
+
self.__dict__["_sync_w2m_value"] = sync_w2m
|
|
556
|
+
return sync_m2w
|
|
557
|
+
return sync_func
|
|
558
|
+
|
|
559
|
+
@property
|
|
560
|
+
def _sync_w2m(self) -> Callable:
|
|
561
|
+
sync_func = self.__dict__.get("_sync_w2m_value")
|
|
562
|
+
if sync_func is None:
|
|
563
|
+
sync_m2w, sync_w2m = self._find_sync_values()
|
|
564
|
+
self.__dict__["_sync_m2w_value"] = sync_m2w
|
|
565
|
+
self.__dict__["_sync_w2m_value"] = sync_w2m
|
|
566
|
+
return sync_w2m
|
|
567
|
+
return sync_func
|
|
568
|
+
|
|
569
|
+
def _find_sync_values(self):
|
|
570
|
+
"""Returns the m2w and w2m sync values, in that order."""
|
|
571
|
+
if not self._use_buffers:
|
|
572
|
+
return _do_nothing, _do_nothing
|
|
573
|
+
# Simplest case: everything is on the same device
|
|
574
|
+
worker_device = self.shared_tensordict_parent.device
|
|
575
|
+
self_device = self.device
|
|
576
|
+
if not self.non_blocking or (
|
|
577
|
+
worker_device == self_device or self_device is None
|
|
578
|
+
):
|
|
579
|
+
# even if they're both None, there is no device-to-device movement
|
|
580
|
+
return _do_nothing, _do_nothing
|
|
581
|
+
|
|
582
|
+
if worker_device is None:
|
|
583
|
+
worker_not_main = False
|
|
584
|
+
|
|
585
|
+
def find_all_worker_devices(item):
|
|
586
|
+
nonlocal worker_not_main
|
|
587
|
+
if hasattr(item, "device"):
|
|
588
|
+
worker_not_main = worker_not_main or (item.device != self_device)
|
|
589
|
+
|
|
590
|
+
for td in self.shared_tensordicts:
|
|
591
|
+
td.apply(find_all_worker_devices, filter_empty=True)
|
|
592
|
+
if worker_not_main:
|
|
593
|
+
if torch.cuda.is_available():
|
|
594
|
+
worker_device = (
|
|
595
|
+
torch.device("cuda")
|
|
596
|
+
if self_device.type != "cuda"
|
|
597
|
+
else torch.device("cpu")
|
|
598
|
+
)
|
|
599
|
+
elif torch.backends.mps.is_available():
|
|
600
|
+
worker_device = (
|
|
601
|
+
torch.device("mps")
|
|
602
|
+
if self_device.type != "mps"
|
|
603
|
+
else torch.device("cpu")
|
|
604
|
+
)
|
|
605
|
+
else:
|
|
606
|
+
raise RuntimeError("Did not find a valid worker device")
|
|
607
|
+
else:
|
|
608
|
+
worker_device = self_device
|
|
609
|
+
|
|
610
|
+
if (
|
|
611
|
+
worker_device is not None
|
|
612
|
+
and worker_device.type == "cuda"
|
|
613
|
+
and self_device is not None
|
|
614
|
+
and self_device.type == "cpu"
|
|
615
|
+
):
|
|
616
|
+
return _do_nothing, _cuda_sync(worker_device)
|
|
617
|
+
if (
|
|
618
|
+
worker_device is not None
|
|
619
|
+
and worker_device.type == "mps"
|
|
620
|
+
and self_device is not None
|
|
621
|
+
and self_device.type == "cpu"
|
|
622
|
+
):
|
|
623
|
+
return _mps_sync(worker_device), _mps_sync(worker_device)
|
|
624
|
+
if (
|
|
625
|
+
worker_device is not None
|
|
626
|
+
and worker_device.type == "cpu"
|
|
627
|
+
and self_device is not None
|
|
628
|
+
and self_device.type == "cuda"
|
|
629
|
+
):
|
|
630
|
+
return _cuda_sync(self_device), _do_nothing
|
|
631
|
+
if (
|
|
632
|
+
worker_device is not None
|
|
633
|
+
and worker_device.type == "cpu"
|
|
634
|
+
and self_device is not None
|
|
635
|
+
and self_device.type == "mps"
|
|
636
|
+
):
|
|
637
|
+
return _mps_sync(self_device), _mps_sync(self_device)
|
|
638
|
+
return _do_nothing, _do_nothing
|
|
639
|
+
|
|
640
|
+
def __getstate__(self):
|
|
641
|
+
out = self.__dict__.copy()
|
|
642
|
+
out["_sync_m2w_value"] = None
|
|
643
|
+
out["_sync_w2m_value"] = None
|
|
644
|
+
return out
|
|
645
|
+
|
|
646
|
+
@property
|
|
647
|
+
def _has_dynamic_specs(self):
|
|
648
|
+
return not self._use_buffers
|
|
649
|
+
|
|
650
|
+
def _get_metadata(
|
|
651
|
+
self, create_env_fn: list[Callable], create_env_kwargs: list[dict]
|
|
652
|
+
):
|
|
653
|
+
if self._single_task:
|
|
654
|
+
# if EnvCreator, the metadata are already there
|
|
655
|
+
meta_data: EnvMetaData = get_env_metadata(
|
|
656
|
+
create_env_fn[0], create_env_kwargs[0]
|
|
657
|
+
)
|
|
658
|
+
self.meta_data = meta_data.expand(
|
|
659
|
+
*(self.num_workers, *meta_data.batch_size)
|
|
660
|
+
)
|
|
661
|
+
if self._use_buffers is not False:
|
|
662
|
+
_use_buffers = not self.meta_data.has_dynamic_specs
|
|
663
|
+
if self._use_buffers and not _use_buffers:
|
|
664
|
+
warn(
|
|
665
|
+
"A value of use_buffers=True was passed but this is incompatible "
|
|
666
|
+
"with the list of environments provided. Turning use_buffers to False."
|
|
667
|
+
)
|
|
668
|
+
self._use_buffers = _use_buffers
|
|
669
|
+
if self.share_individual_td is None:
|
|
670
|
+
self.share_individual_td = False
|
|
671
|
+
else:
|
|
672
|
+
n_tasks = len(create_env_fn)
|
|
673
|
+
self.meta_data: list[EnvMetaData] = []
|
|
674
|
+
for i in range(n_tasks):
|
|
675
|
+
self.meta_data.append(
|
|
676
|
+
get_env_metadata(create_env_fn[i], create_env_kwargs[i]).clone()
|
|
677
|
+
)
|
|
678
|
+
if self.share_individual_td is not True:
|
|
679
|
+
share_individual_td = not _stackable(
|
|
680
|
+
*[meta_data.tensordict for meta_data in self.meta_data]
|
|
681
|
+
)
|
|
682
|
+
if share_individual_td and self.share_individual_td is False:
|
|
683
|
+
raise ValueError(
|
|
684
|
+
"share_individual_td=False was provided but share_individual_td must "
|
|
685
|
+
"be True to accommodate non-stackable tensors."
|
|
686
|
+
)
|
|
687
|
+
self.share_individual_td = share_individual_td
|
|
688
|
+
_use_buffers = all(
|
|
689
|
+
not metadata.has_dynamic_specs for metadata in self.meta_data
|
|
690
|
+
)
|
|
691
|
+
if self._use_buffers and not _use_buffers:
|
|
692
|
+
warn(
|
|
693
|
+
"A value of use_buffers=True was passed but this is incompatible "
|
|
694
|
+
"with the list of environments provided. Turning use_buffers to False."
|
|
695
|
+
)
|
|
696
|
+
self._use_buffers = _use_buffers
|
|
697
|
+
|
|
698
|
+
self._set_properties()
|
|
699
|
+
|
|
700
|
+
def update_kwargs(self, kwargs: dict | list[dict]) -> None:
|
|
701
|
+
"""Updates the kwargs of each environment given a dictionary or a list of dictionaries.
|
|
702
|
+
|
|
703
|
+
Args:
|
|
704
|
+
kwargs (dict or list of dict): new kwargs to use with the environments
|
|
705
|
+
|
|
706
|
+
"""
|
|
707
|
+
if isinstance(kwargs, dict):
|
|
708
|
+
for _kwargs in self.create_env_kwargs:
|
|
709
|
+
_kwargs.update(kwargs)
|
|
710
|
+
else:
|
|
711
|
+
if len(kwargs) != self.num_workers:
|
|
712
|
+
raise RuntimeError(
|
|
713
|
+
f"len(kwargs) and num_workers mismatch, got {len(kwargs)} and {self.num_workers}."
|
|
714
|
+
)
|
|
715
|
+
for _kwargs, _new_kwargs in _zip_strict(self.create_env_kwargs, kwargs):
|
|
716
|
+
_kwargs.update(_new_kwargs)
|
|
717
|
+
|
|
718
|
+
def _get_in_keys_to_exclude(self, tensordict):
|
|
719
|
+
if self._cache_in_keys is None:
|
|
720
|
+
self._cache_in_keys = list(
|
|
721
|
+
set(self.input_spec.keys(True)).intersection(
|
|
722
|
+
tensordict.keys(True, True)
|
|
723
|
+
)
|
|
724
|
+
)
|
|
725
|
+
return self._cache_in_keys
|
|
726
|
+
|
|
727
|
+
def _set_properties(self):
|
|
728
|
+
|
|
729
|
+
cls = type(self)
|
|
730
|
+
|
|
731
|
+
def _check_for_empty_spec(specs: Composite):
|
|
732
|
+
for subspec in (
|
|
733
|
+
"full_state_spec",
|
|
734
|
+
"full_action_spec",
|
|
735
|
+
"full_done_spec",
|
|
736
|
+
"full_reward_spec",
|
|
737
|
+
"full_observation_spec",
|
|
738
|
+
):
|
|
739
|
+
for key, spec in reversed(
|
|
740
|
+
list(specs.get(subspec, default=Composite()).items(True))
|
|
741
|
+
):
|
|
742
|
+
if isinstance(spec, Composite) and spec.is_empty():
|
|
743
|
+
raise RuntimeError(
|
|
744
|
+
f"The environment passed to {cls.__name__} has empty specs in {key}. Consider using "
|
|
745
|
+
f"torchrl.envs.transforms.RemoveEmptySpecs to remove the empty specs."
|
|
746
|
+
)
|
|
747
|
+
return specs
|
|
748
|
+
|
|
749
|
+
meta_data = self.meta_data
|
|
750
|
+
self._properties_set = True
|
|
751
|
+
if self._single_task:
|
|
752
|
+
self._batch_size = meta_data.batch_size
|
|
753
|
+
device = meta_data.device
|
|
754
|
+
if self._device is None:
|
|
755
|
+
self._device = device
|
|
756
|
+
|
|
757
|
+
input_spec = _check_for_empty_spec(meta_data.specs["input_spec"].to(device))
|
|
758
|
+
output_spec = _check_for_empty_spec(
|
|
759
|
+
meta_data.specs["output_spec"].to(device)
|
|
760
|
+
)
|
|
761
|
+
|
|
762
|
+
self.action_spec = input_spec["full_action_spec"]
|
|
763
|
+
self.state_spec = input_spec["full_state_spec"]
|
|
764
|
+
self.observation_spec = output_spec["full_observation_spec"]
|
|
765
|
+
self.reward_spec = output_spec["full_reward_spec"]
|
|
766
|
+
self.done_spec = output_spec["full_done_spec"]
|
|
767
|
+
|
|
768
|
+
self._dummy_env_str = meta_data.env_str
|
|
769
|
+
self._env_tensordict = meta_data.tensordict
|
|
770
|
+
if device is None: # In other cases, the device will be mapped later
|
|
771
|
+
self._env_tensordict.clear_device_()
|
|
772
|
+
device_map = meta_data.device_map
|
|
773
|
+
|
|
774
|
+
def map_device(key, value, device_map=device_map):
|
|
775
|
+
return value.to(device_map[key])
|
|
776
|
+
|
|
777
|
+
self._env_tensordict.named_apply(
|
|
778
|
+
map_device, nested_keys=True, filter_empty=True
|
|
779
|
+
)
|
|
780
|
+
# if self._batch_locked is None:
|
|
781
|
+
# self._batch_locked = meta_data.batch_locked
|
|
782
|
+
else:
|
|
783
|
+
self._batch_size = torch.Size([self.num_workers, *meta_data[0].batch_size])
|
|
784
|
+
devices = set()
|
|
785
|
+
for _meta_data in meta_data:
|
|
786
|
+
device = _meta_data.device
|
|
787
|
+
devices.add(device)
|
|
788
|
+
if self._device is None:
|
|
789
|
+
if len(devices) > 1:
|
|
790
|
+
raise ValueError(
|
|
791
|
+
f"The device wasn't passed to {type(self)}, but more than one device was found in the sub-environments. "
|
|
792
|
+
f"Please indicate a device to be used for collection."
|
|
793
|
+
)
|
|
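# Sketch of the remedy suggested by the error above (constructor names and the
# device are illustrative only):
#   env = ParallelEnv(2, [make_env_a, make_env_b], device="cuda:0")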
794
|
+
device = list(devices)[0]
|
|
795
|
+
self._device = device
|
|
796
|
+
|
|
797
|
+
input_spec = []
|
|
798
|
+
for md in meta_data:
|
|
799
|
+
input_spec.append(_check_for_empty_spec(md.specs["input_spec"]))
|
|
800
|
+
input_spec = torch.stack(input_spec, 0)
|
|
801
|
+
output_spec = []
|
|
802
|
+
for md in meta_data:
|
|
803
|
+
output_spec.append(_check_for_empty_spec(md.specs["output_spec"]))
|
|
804
|
+
output_spec = torch.stack(output_spec, 0)
|
|
805
|
+
|
|
806
|
+
self.action_spec = input_spec["full_action_spec"]
|
|
807
|
+
self.state_spec = input_spec["full_state_spec"]
|
|
808
|
+
|
|
809
|
+
self.observation_spec = output_spec["full_observation_spec"]
|
|
810
|
+
self.reward_spec = output_spec["full_reward_spec"]
|
|
811
|
+
self.done_spec = output_spec["full_done_spec"]
|
|
812
|
+
|
|
813
|
+
self._dummy_env_str = str(meta_data[0])
|
|
814
|
+
if self.share_individual_td:
|
|
815
|
+
self._env_tensordict = LazyStackedTensorDict.lazy_stack(
|
|
816
|
+
[meta_data.tensordict for meta_data in meta_data], 0
|
|
817
|
+
)
|
|
818
|
+
else:
|
|
819
|
+
self._env_tensordict = torch.stack(
|
|
820
|
+
[meta_data.tensordict for meta_data in meta_data], 0
|
|
821
|
+
)
|
|
822
|
+
# if self._batch_locked is None:
|
|
823
|
+
# self._batch_locked = meta_data[0].batch_locked
|
|
824
|
+
self.has_lazy_inputs = contains_lazy_spec(self.input_spec)
|
|
825
|
+
|
|
826
|
+
def state_dict(self) -> OrderedDict:
|
|
827
|
+
raise NotImplementedError
|
|
828
|
+
|
|
829
|
+
def load_state_dict(self, state_dict: OrderedDict) -> None:
|
|
830
|
+
raise NotImplementedError
|
|
831
|
+
|
|
832
|
+
batch_size = lazy_property(EnvBase.batch_size)
|
|
833
|
+
device = lazy_property(EnvBase.device)
|
|
834
|
+
input_spec = lazy_property(EnvBase.input_spec)
|
|
835
|
+
output_spec = lazy_property(EnvBase.output_spec)
|
|
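# These properties are resolved lazily: the first access is expected to trigger
# _set_properties() (see the _properties_set flag above), so specs and batch size
# are only gathered from the sub-environments when actually needed.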
836
|
+
|
|
837
|
+
def _create_td(self) -> None:
|
|
838
|
+
"""Creates self.shared_tensordict_parent, a TensorDict used to store the most recent observations."""
|
|
839
|
+
if not self._use_buffers:
|
|
840
|
+
return
|
|
841
|
+
shared_tensordict_parent = self._env_tensordict.clone()
|
|
842
|
+
if self._env_tensordict.shape[0] != self.num_workers:
|
|
843
|
+
raise RuntimeError(
|
|
844
|
+
"batched environment base tensordict has the wrong shape"
|
|
845
|
+
)
|
|
846
|
+
|
|
847
|
+
# Non-tensor keys
|
|
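# Entries backed by NonTensor specs are gathered here so they can be filtered out
# of the tensor key lists below and exchanged separately (see the out_tds handling
# in _reset/_step).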
848
|
+
non_tensor_keys = []
|
|
849
|
+
for spec in (
|
|
850
|
+
self.full_action_spec,
|
|
851
|
+
self.full_state_spec,
|
|
852
|
+
self.full_observation_spec,
|
|
853
|
+
self.full_reward_spec,
|
|
854
|
+
self.full_done_spec,
|
|
855
|
+
):
|
|
856
|
+
for key, _spec in spec.items(True, True):
|
|
857
|
+
if isinstance(_spec, NonTensor):
|
|
858
|
+
non_tensor_keys.append(key)
|
|
859
|
+
self._non_tensor_keys = non_tensor_keys
|
|
860
|
+
|
|
861
|
+
if self._single_task:
|
|
862
|
+
self._env_input_keys = sorted(
|
|
863
|
+
list(self.input_spec["full_action_spec"].keys(True, True))
|
|
864
|
+
+ list(self.state_spec.keys(True, True)),
|
|
865
|
+
key=_sort_keys,
|
|
866
|
+
)
|
|
867
|
+
self._env_output_keys = []
|
|
868
|
+
self._env_obs_keys = []
|
|
869
|
+
for key in self.output_spec["full_observation_spec"].keys(True, True):
|
|
870
|
+
self._env_output_keys.append(key)
|
|
871
|
+
self._env_obs_keys.append(key)
|
|
872
|
+
self._env_output_keys += self.reward_keys + self.done_keys
|
|
873
|
+
else:
|
|
874
|
+
# this is only possible if _single_task=False
|
|
875
|
+
env_input_keys = set()
|
|
876
|
+
for meta_data in self.meta_data:
|
|
877
|
+
if meta_data.specs["input_spec", "full_state_spec"] is not None:
|
|
878
|
+
env_input_keys = env_input_keys.union(
|
|
879
|
+
meta_data.specs["input_spec", "full_state_spec"].keys(
|
|
880
|
+
True, True
|
|
881
|
+
)
|
|
882
|
+
)
|
|
883
|
+
env_input_keys = env_input_keys.union(
|
|
884
|
+
meta_data.specs["input_spec", "full_action_spec"].keys(True, True)
|
|
885
|
+
)
|
|
886
|
+
env_output_keys = set()
|
|
887
|
+
env_obs_keys = set()
|
|
888
|
+
for meta_data in self.meta_data:
|
|
889
|
+
keys = meta_data.specs["output_spec"]["full_observation_spec"].keys(
|
|
890
|
+
True, True
|
|
891
|
+
)
|
|
892
|
+
keys = list(keys)
|
|
893
|
+
env_obs_keys = env_obs_keys.union(keys)
|
|
894
|
+
|
|
895
|
+
env_output_keys = env_output_keys.union(keys)
|
|
896
|
+
env_output_keys = env_output_keys.union(self.reward_keys + self.done_keys)
|
|
897
|
+
self._env_obs_keys = sorted(env_obs_keys, key=_sort_keys)
|
|
898
|
+
self._env_input_keys = sorted(env_input_keys, key=_sort_keys)
|
|
899
|
+
self._env_output_keys = sorted(env_output_keys, key=_sort_keys)
|
|
900
|
+
|
|
901
|
+
self._env_obs_keys = [
|
|
902
|
+
key for key in self._env_obs_keys if key not in self._non_tensor_keys
|
|
903
|
+
]
|
|
904
|
+
self._env_input_keys = [
|
|
905
|
+
key for key in self._env_input_keys if key not in self._non_tensor_keys
|
|
906
|
+
]
|
|
907
|
+
self._env_output_keys = [
|
|
908
|
+
key for key in self._env_output_keys if key not in self._non_tensor_keys
|
|
909
|
+
]
|
|
910
|
+
|
|
911
|
+
reset_keys = self.reset_keys
|
|
912
|
+
self._selected_keys = (
|
|
913
|
+
set(self._env_output_keys)
|
|
914
|
+
.union(self._env_input_keys)
|
|
915
|
+
.union(self._env_obs_keys)
|
|
916
|
+
.union(set(self.done_keys))
|
|
917
|
+
)
|
|
918
|
+
self._selected_keys = self._selected_keys.union(reset_keys)
|
|
919
|
+
|
|
920
|
+
# input keys
|
|
921
|
+
self._selected_input_keys = {unravel_key(key) for key in self._env_input_keys}
|
|
922
|
+
# output keys after reset
|
|
923
|
+
self._selected_reset_keys = {
|
|
924
|
+
unravel_key(key) for key in self._env_obs_keys + self.done_keys + reset_keys
|
|
925
|
+
}
|
|
926
|
+
# output keys after reset, filtered
|
|
927
|
+
self._selected_reset_keys_filt = {
|
|
928
|
+
unravel_key(key) for key in self._env_obs_keys + self.done_keys
|
|
929
|
+
}
|
|
930
|
+
# output keys after step
|
|
931
|
+
self._selected_step_keys = {unravel_key(key) for key in self._env_output_keys}
|
|
932
|
+
|
|
933
|
+
if not self.share_individual_td:
|
|
934
|
+
shared_tensordict_parent = shared_tensordict_parent.filter_non_tensor_data()
|
|
935
|
+
shared_tensordict_parent = shared_tensordict_parent.select(
|
|
936
|
+
*self._selected_keys,
|
|
937
|
+
*(unravel_key(("next", key)) for key in self._env_output_keys),
|
|
938
|
+
strict=False,
|
|
939
|
+
)
|
|
940
|
+
self.shared_tensordict_parent = shared_tensordict_parent
|
|
941
|
+
else:
|
|
942
|
+
# Multi-task: we share tensordict that *may* have different keys
|
|
943
|
+
shared_tensordict_parent = [
|
|
944
|
+
tensordict.select(
|
|
945
|
+
*self._selected_keys,
|
|
946
|
+
*(unravel_key(("next", key)) for key in self._env_output_keys),
|
|
947
|
+
strict=False,
|
|
948
|
+
).filter_non_tensor_data()
|
|
949
|
+
for tensordict in shared_tensordict_parent
|
|
950
|
+
]
|
|
951
|
+
shared_tensordict_parent = LazyStackedTensorDict.lazy_stack(
|
|
952
|
+
shared_tensordict_parent,
|
|
953
|
+
0,
|
|
954
|
+
)
|
|
955
|
+
self.shared_tensordict_parent = shared_tensordict_parent
|
|
956
|
+
|
|
957
|
+
if self.share_individual_td:
|
|
958
|
+
if not isinstance(self.shared_tensordict_parent, LazyStackedTensorDict):
|
|
959
|
+
self.shared_tensordicts = [
|
|
960
|
+
td.clone() for td in self.shared_tensordict_parent.unbind(0)
|
|
961
|
+
]
|
|
962
|
+
self.shared_tensordict_parent = LazyStackedTensorDict.lazy_stack(
|
|
963
|
+
self.shared_tensordicts, 0
|
|
964
|
+
)
|
|
965
|
+
else:
|
|
966
|
+
# Multi-task: we share tensordict that *may* have different keys
|
|
967
|
+
# LazyStacked already stores this so we don't need to do anything
|
|
968
|
+
self.shared_tensordicts = self.shared_tensordict_parent
|
|
969
|
+
if self._share_memory:
|
|
970
|
+
self.shared_tensordict_parent.share_memory_()
|
|
971
|
+
elif self._memmap:
|
|
972
|
+
self.shared_tensordict_parent.memmap_()
|
|
973
|
+
else:
|
|
974
|
+
if self._share_memory:
|
|
975
|
+
self.shared_tensordict_parent.share_memory_()
|
|
976
|
+
if not self.shared_tensordict_parent.is_shared():
|
|
977
|
+
raise RuntimeError("share_memory_() failed")
|
|
978
|
+
elif self._memmap:
|
|
979
|
+
self.shared_tensordict_parent.memmap_()
|
|
980
|
+
if not self.shared_tensordict_parent.is_memmap():
|
|
981
|
+
raise RuntimeError("memmap_() failed")
|
|
982
|
+
self.shared_tensordicts = self.shared_tensordict_parent.unbind(0)
|
|
983
|
+
for td in self.shared_tensordicts:
|
|
984
|
+
td.lock_()
|
|
985
|
+
|
|
986
|
+
# we cache all the keys of the shared parent td for future use. This is
|
|
987
|
+
# safe since the td is locked.
|
|
988
|
+
self._cache_shared_keys = set(self.shared_tensordict_parent.keys(True, True))
|
|
989
|
+
|
|
990
|
+
self._shared_tensordict_parent_next = self.shared_tensordict_parent.get("next")
|
|
991
|
+
self._shared_tensordict_parent_root = self.shared_tensordict_parent.exclude(
|
|
992
|
+
"next", *self.reset_keys
|
|
993
|
+
)
|
|
994
|
+
|
|
995
|
+
def _start_workers(self) -> None:
|
|
996
|
+
"""Starts the various envs."""
|
|
997
|
+
raise NotImplementedError
|
|
998
|
+
|
|
999
|
+
def __repr__(self) -> str:
|
|
1000
|
+
if self._dummy_env_str is None:
|
|
1001
|
+
self._dummy_env_str = self._set_properties()
|
|
1002
|
+
return (
|
|
1003
|
+
f"{self.__class__.__name__}("
|
|
1004
|
+
f"\n\tenv={self._dummy_env_str}, "
|
|
1005
|
+
f"\n\tbatch_size={self.batch_size})"
|
|
1006
|
+
)
|
|
1007
|
+
|
|
1008
|
+
def close(self, *, raise_if_closed: bool = True) -> None:
|
|
1009
|
+
if self.is_closed:
|
|
1010
|
+
if raise_if_closed:
|
|
1011
|
+
raise RuntimeError("trying to close a closed environment")
|
|
1012
|
+
else:
|
|
1013
|
+
return
|
|
1014
|
+
if self._verbose:
|
|
1015
|
+
torchrl_logger.info(f"closing {self.__class__.__name__}")
|
|
1016
|
+
|
|
1017
|
+
self.__dict__["_input_spec"] = None
|
|
1018
|
+
self.__dict__["_output_spec"] = None
|
|
1019
|
+
self._properties_set = False
|
|
1020
|
+
|
|
1021
|
+
self._shutdown_workers()
|
|
1022
|
+
self.is_closed = True
|
|
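# Closing hands back the threads that were set aside for the workers at start-up
# (ParallelEnv._start_workers shrinks the pool by num_workers), capped by
# torchrl._THREAD_POOL_INIT.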
1023
|
+
import torchrl
|
|
1024
|
+
|
|
1025
|
+
num_threads = min(
|
|
1026
|
+
torchrl._THREAD_POOL_INIT, torch.get_num_threads() + self.num_workers
|
|
1027
|
+
)
|
|
1028
|
+
torch.set_num_threads(num_threads)
|
|
1029
|
+
|
|
1030
|
+
def _shutdown_workers(self) -> None:
|
|
1031
|
+
raise NotImplementedError
|
|
1032
|
+
|
|
1033
|
+
def _set_seed(self, seed: int | None) -> None:
|
|
1034
|
+
"""This method is not used in batched envs."""
|
|
1035
|
+
|
|
1036
|
+
@lazy
|
|
1037
|
+
def start(self) -> None:
|
|
1038
|
+
if not self.is_closed:
|
|
1039
|
+
raise RuntimeError("trying to start a environment that is not closed.")
|
|
1040
|
+
self._create_td()
|
|
1041
|
+
self._start_workers()
|
|
1042
|
+
|
|
1043
|
+
def to(self, device: DEVICE_TYPING):
|
|
1044
|
+
self._non_blocking = None
|
|
1045
|
+
device = _make_ordinal_device(torch.device(device))
|
|
1046
|
+
if device == self.device:
|
|
1047
|
+
return self
|
|
1048
|
+
self._device = device
|
|
1049
|
+
self.__dict__["_sync_m2w_value"] = None
|
|
1050
|
+
self.__dict__["_sync_w2m_value"] = None
|
|
1051
|
+
if self.__dict__["_input_spec"] is not None:
|
|
1052
|
+
self.__dict__["_input_spec"] = self.__dict__["_input_spec"].to(device)
|
|
1053
|
+
if self.__dict__["_output_spec"] is not None:
|
|
1054
|
+
self.__dict__["_output_spec"] = self.__dict__["_output_spec"].to(device)
|
|
1055
|
+
return self
|
|
1056
|
+
|
|
1057
|
+
def _reset_proc_data(self, tensordict, tensordict_reset):
|
|
1058
|
+
# since we call `reset` directly, all the postproc has been completed
|
|
1059
|
+
if tensordict is not None:
|
|
1060
|
+
if isinstance(tensordict_reset, LazyStackedTensorDict) and not isinstance(
|
|
1061
|
+
tensordict, LazyStackedTensorDict
|
|
1062
|
+
):
|
|
1063
|
+
tensordict = LazyStackedTensorDict(*tensordict.unbind(0))
|
|
1064
|
+
return _update_during_reset(tensordict_reset, tensordict, self.reset_keys)
|
|
1065
|
+
return tensordict_reset
|
|
1066
|
+
|
|
1067
|
+
def add_truncated_keys(self):
|
|
1068
|
+
raise RuntimeError(
|
|
1069
|
+
"Cannot add truncated keys to a batched environment. Please add these entries to "
|
|
1070
|
+
"the nested environments by calling sub_env.add_truncated_keys()"
|
|
1071
|
+
)
|
|
1072
|
+
|
|
1073
|
+
|
|
1074
|
+
class SerialEnv(BatchedEnvBase):
|
|
1075
|
+
"""Creates a series of environments in the same process."""
|
|
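# A minimal usage sketch (assumes a Gym backend is installed; the env name is
# illustrative only):
#   from torchrl.envs import SerialEnv, GymEnv
#   env = SerialEnv(3, lambda: GymEnv("Pendulum-v1"))
#   td = env.rollout(5)  # the three copies are stepped sequentially in this process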
1076
|
+
|
|
1077
|
+
__doc__ += BatchedEnvBase.__doc__
|
|
1078
|
+
|
|
1079
|
+
_share_memory = False
|
|
1080
|
+
|
|
1081
|
+
def _start_workers(self) -> None:
|
|
1082
|
+
_num_workers = self.num_workers
|
|
1083
|
+
|
|
1084
|
+
self._envs = []
|
|
1085
|
+
weakref_set = set()
|
|
1086
|
+
for idx in range(_num_workers):
|
|
1087
|
+
env = self.create_env_fn[idx](**self.create_env_kwargs[idx])
|
|
1088
|
+
# We want to avoid having the same env multiple times
|
|
1089
|
+
# so we try to deepcopy it if needed. If we can't, we make
|
|
1090
|
+
# the user aware that this isn't a very good idea
|
|
1091
|
+
wr = weakref.ref(env)
|
|
1092
|
+
if wr in weakref_set:
|
|
1093
|
+
try:
|
|
1094
|
+
env = deepcopy(env)
|
|
1095
|
+
except Exception:
|
|
1096
|
+
warn(
|
|
1097
|
+
"Deepcopying the env failed within SerialEnv "
|
|
1098
|
+
"but more than one copy of the same env was found. "
|
|
1099
|
+
"This is a dangerous situation if your env keeps track "
|
|
1100
|
+
"of some variables (e.g., state) in-place. "
|
|
1101
|
+
"We'll use the same copy of the environment be beaware that "
|
|
1102
|
+
"this may have important, unwanted issues for stateful "
|
|
1103
|
+
"environments!"
|
|
1104
|
+
)
|
|
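# The warning above typically means the very same env instance was passed several
# times. A sketch of the safer pattern is to hand over a constructor (EnvCreator is
# shown for illustration; make_env is a hypothetical factory):
#   from torchrl.envs import EnvCreator
#   env = SerialEnv(2, EnvCreator(make_env))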
1105
|
+
weakref_set.add(wr)
|
|
1106
|
+
self._envs.append(env.set_spec_lock_())
|
|
1107
|
+
self.is_closed = False
|
|
1108
|
+
self.set_spec_lock_()
|
|
1109
|
+
|
|
1110
|
+
@_check_start
|
|
1111
|
+
def state_dict(self) -> OrderedDict:
|
|
1112
|
+
state_dict = OrderedDict()
|
|
1113
|
+
for idx, env in enumerate(self._envs):
|
|
1114
|
+
state_dict[f"worker{idx}"] = env.state_dict()
|
|
1115
|
+
|
|
1116
|
+
return state_dict
|
|
1117
|
+
|
|
1118
|
+
@_check_start
|
|
1119
|
+
def load_state_dict(self, state_dict: OrderedDict) -> None:
|
|
1120
|
+
if "worker0" not in state_dict:
|
|
1121
|
+
state_dict = OrderedDict(
|
|
1122
|
+
**{f"worker{idx}": state_dict for idx in range(self.num_workers)}
|
|
1123
|
+
)
|
|
1124
|
+
for idx, env in enumerate(self._envs):
|
|
1125
|
+
env.load_state_dict(state_dict[f"worker{idx}"])
|
|
1126
|
+
|
|
1127
|
+
def _shutdown_workers(self) -> None:
|
|
1128
|
+
if not self.is_closed:
|
|
1129
|
+
for env in self._envs:
|
|
1130
|
+
env.close()
|
|
1131
|
+
del self._envs
|
|
1132
|
+
|
|
1133
|
+
@_check_start
|
|
1134
|
+
def set_seed(
|
|
1135
|
+
self, seed: int | None = None, static_seed: bool = False
|
|
1136
|
+
) -> int | None:
|
|
1137
|
+
for env in self._envs:
|
|
1138
|
+
new_seed = env.set_seed(seed, static_seed=static_seed)
|
|
1139
|
+
seed = new_seed
|
|
1140
|
+
return seed
|
|
1141
|
+
|
|
1142
|
+
@_check_start
|
|
1143
|
+
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
|
|
1144
|
+
list_of_kwargs = kwargs.pop("list_of_kwargs", [kwargs] * self.num_workers)
|
|
1145
|
+
if kwargs is not list_of_kwargs[0] and kwargs:
|
|
1146
|
+
# a list of kwargs was provided explicitly and extra keyword arguments remain: broadcast them to every worker's kwargs
|
|
1147
|
+
for elt in list_of_kwargs:
|
|
1148
|
+
elt.update(kwargs)
|
|
1149
|
+
if tensordict is not None:
|
|
1150
|
+
if "_reset" in tensordict.keys():
|
|
1151
|
+
needs_resetting = tensordict["_reset"]
|
|
1152
|
+
else:
|
|
1153
|
+
needs_resetting = _aggregate_end_of_traj(
|
|
1154
|
+
tensordict, reset_keys=self.reset_keys
|
|
1155
|
+
)
|
|
1156
|
+
if needs_resetting.ndim > 2:
|
|
1157
|
+
needs_resetting = needs_resetting.flatten(1, needs_resetting.ndim - 1)
|
|
1158
|
+
if needs_resetting.ndim > 1:
|
|
1159
|
+
needs_resetting = needs_resetting.any(-1)
|
|
1160
|
+
elif not needs_resetting.ndim:
|
|
1161
|
+
needs_resetting = needs_resetting.expand((self.num_workers,))
|
|
1162
|
+
tensordict = tensordict.unbind(0)
|
|
1163
|
+
else:
|
|
1164
|
+
needs_resetting = torch.ones(
|
|
1165
|
+
(self.num_workers,), device=self.device, dtype=torch.bool
|
|
1166
|
+
)
|
|
1167
|
+
|
|
1168
|
+
out_tds = None
|
|
1169
|
+
if not self._use_buffers or self._non_tensor_keys:
|
|
1170
|
+
out_tds = [None] * self.num_workers
|
|
1171
|
+
|
|
1172
|
+
tds = []
|
|
1173
|
+
for i, _env in enumerate(self._envs):
|
|
1174
|
+
if not needs_resetting[i]:
|
|
1175
|
+
if out_tds is not None and tensordict is not None:
|
|
1176
|
+
ftd = _env.observation_spec.zero()
|
|
1177
|
+
if self.device is None:
|
|
1178
|
+
ftd.clear_device_()
|
|
1179
|
+
else:
|
|
1180
|
+
ftd = ftd.to(self.device)
|
|
1181
|
+
out_tds[i] = ftd
|
|
1182
|
+
continue
|
|
1183
|
+
if tensordict is not None:
|
|
1184
|
+
tensordict_ = tensordict[i]
|
|
1185
|
+
if tensordict_.is_empty():
|
|
1186
|
+
tensordict_ = None
|
|
1187
|
+
else:
|
|
1188
|
+
env_device = _env.device
|
|
1189
|
+
if env_device != self.device:
|
|
1190
|
+
if env_device is not None:
|
|
1191
|
+
tensordict_ = tensordict_.to(
|
|
1192
|
+
env_device, non_blocking=self.non_blocking
|
|
1193
|
+
)
|
|
1194
|
+
else:
|
|
1195
|
+
tensordict_ = tensordict_.clear_device_()
|
|
1196
|
+
else:
|
|
1197
|
+
tensordict_ = tensordict_.clone(False)
|
|
1198
|
+
else:
|
|
1199
|
+
tensordict_ = None
|
|
1200
|
+
tds.append((i, tensordict_))
|
|
1201
|
+
|
|
1202
|
+
self._sync_m2w()
|
|
1203
|
+
for i, tensordict_ in tds:
|
|
1204
|
+
_env = self._envs[i]
|
|
1205
|
+
_td = _env.reset(tensordict=tensordict_, **list_of_kwargs[i])
|
|
1206
|
+
if self._use_buffers:
|
|
1207
|
+
try:
|
|
1208
|
+
self.shared_tensordicts[i].update_(
|
|
1209
|
+
_td,
|
|
1210
|
+
keys_to_update=list(self._selected_reset_keys_filt),
|
|
1211
|
+
non_blocking=self.non_blocking,
|
|
1212
|
+
)
|
|
1213
|
+
except RuntimeError as err:
|
|
1214
|
+
if "no_grad mode" in str(err):
|
|
1215
|
+
raise RuntimeError(
|
|
1216
|
+
"Cannot update a view of a tensordict when gradients are required. "
|
|
1217
|
+
"To collect gradient across sub-environments, please set the "
|
|
1218
|
+
"share_individual_td argument to True."
|
|
1219
|
+
)
|
|
1220
|
+
raise
|
|
1221
|
+
if out_tds is not None:
|
|
1222
|
+
out_tds[i] = _td
|
|
1223
|
+
|
|
1224
|
+
device = self.device
|
|
1225
|
+
if not self._use_buffers:
|
|
1226
|
+
result = LazyStackedTensorDict.maybe_dense_stack(out_tds)
|
|
1227
|
+
if result.device != device:
|
|
1228
|
+
if device is None:
|
|
1229
|
+
result = result.clear_device_()
|
|
1230
|
+
else:
|
|
1231
|
+
result = result.to(device, non_blocking=self.non_blocking)
|
|
1232
|
+
self._sync_w2m()
|
|
1233
|
+
return result
|
|
1234
|
+
|
|
1235
|
+
selected_output_keys = self._selected_reset_keys_filt
|
|
1236
|
+
|
|
1237
|
+
# select + clone creates 2 tds, but we can create one only
|
|
1238
|
+
out = self.shared_tensordict_parent.named_apply(
|
|
1239
|
+
lambda *args: self.select_and_clone(
|
|
1240
|
+
*args, selected_keys=selected_output_keys
|
|
1241
|
+
),
|
|
1242
|
+
nested_keys=True,
|
|
1243
|
+
filter_empty=True,
|
|
1244
|
+
)
|
|
1245
|
+
if out_tds is not None:
|
|
1246
|
+
out.update(
|
|
1247
|
+
LazyStackedTensorDict(*out_tds), keys_to_update=self._non_tensor_keys
|
|
1248
|
+
)
|
|
1249
|
+
|
|
1250
|
+
if out.device != device:
|
|
1251
|
+
if device is None:
|
|
1252
|
+
out = out.clear_device_()
|
|
1253
|
+
else:
|
|
1254
|
+
out = out.to(device, non_blocking=self.non_blocking)
|
|
1255
|
+
self._sync_w2m()
|
|
1256
|
+
return out
|
|
1257
|
+
|
|
1258
|
+
@_check_start
|
|
1259
|
+
def _step(
|
|
1260
|
+
self,
|
|
1261
|
+
tensordict: TensorDict,
|
|
1262
|
+
) -> TensorDict:
|
|
1263
|
+
partial_steps = tensordict.get("_step")
|
|
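# "_step" is an optional boolean mask over the batch: only the sub-envs where it is
# True are stepped below; for the others, the previous data is copied back as a
# placeholder before the result is re-assembled at the original batch size.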
1264
|
+
tensordict_save = tensordict
|
|
1265
|
+
if partial_steps is not None and partial_steps.all():
|
|
1266
|
+
partial_steps = None
|
|
1267
|
+
if partial_steps is not None:
|
|
1268
|
+
partial_steps = partial_steps.view(tensordict.shape)
|
|
1269
|
+
tensordict = tensordict[partial_steps]
|
|
1270
|
+
workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
|
|
1271
|
+
tensordict_in = tensordict
|
|
1272
|
+
else:
|
|
1273
|
+
workers_range = range(self.num_workers)
|
|
1274
|
+
tensordict_in = tensordict.copy()
|
|
1275
|
+
# if self._use_buffers:
|
|
1276
|
+
# shared_tensordict_parent = self.shared_tensordict_parent
|
|
1277
|
+
|
|
1278
|
+
data_in = []
|
|
1279
|
+
for i, td_ in zip(workers_range, tensordict_in):
|
|
1280
|
+
# shared_tensordicts are locked, and we need to select the keys since we update in-place.
|
|
1281
|
+
# There may be unexpected keys, such as "_reset", that we can safely ignore here.
|
|
1282
|
+
env_device = self._envs[i].device
|
|
1283
|
+
if env_device != self.device:
|
|
1284
|
+
if env_device is not None:
|
|
1285
|
+
data_in.append(td_.to(env_device, non_blocking=self.non_blocking))
|
|
1286
|
+
else:
|
|
1287
|
+
data_in.append(td_.clear_device_())
|
|
1288
|
+
else:
|
|
1289
|
+
data_in.append(td_)
|
|
1290
|
+
|
|
1291
|
+
self._sync_m2w()
|
|
1292
|
+
out_tds = None
|
|
1293
|
+
if not self._use_buffers or self._non_tensor_keys:
|
|
1294
|
+
out_tds = []
|
|
1295
|
+
|
|
1296
|
+
if self._use_buffers:
|
|
1297
|
+
next_td = self.shared_tensordict_parent.get("next")
|
|
1298
|
+
for i, _data_in in zip(workers_range, data_in):
|
|
1299
|
+
out_td = self._envs[i]._step(_data_in)
|
|
1300
|
+
next_td[i].update_(
|
|
1301
|
+
out_td,
|
|
1302
|
+
# _env_output_keys exclude non-tensor data
|
|
1303
|
+
keys_to_update=list(self._env_output_keys),
|
|
1304
|
+
non_blocking=self.non_blocking,
|
|
1305
|
+
)
|
|
1306
|
+
if out_tds is not None:
|
|
1307
|
+
# we store the non-tensor data here
|
|
1308
|
+
out_tds.append(out_td)
|
|
1309
|
+
|
|
1310
|
+
# We must pass a clone of the tensordict, as the values of this tensordict
|
|
1311
|
+
# will be modified in-place at further steps
|
|
1312
|
+
device = self.device
|
|
1313
|
+
|
|
1314
|
+
selected_keys = self._selected_step_keys
|
|
1315
|
+
|
|
1316
|
+
if partial_steps is not None:
|
|
1317
|
+
next_td = TensorDict.lazy_stack([next_td[i] for i in workers_range])
|
|
1318
|
+
out = next_td.named_apply(
|
|
1319
|
+
lambda *args: self.select_and_clone(*args, selected_keys),
|
|
1320
|
+
nested_keys=True,
|
|
1321
|
+
filter_empty=True,
|
|
1322
|
+
)
|
|
1323
|
+
if out_tds is not None:
|
|
1324
|
+
out.update(
|
|
1325
|
+
LazyStackedTensorDict(*out_tds),
|
|
1326
|
+
keys_to_update=self._non_tensor_keys,
|
|
1327
|
+
)
|
|
1328
|
+
|
|
1329
|
+
if out.device != device:
|
|
1330
|
+
if device is None:
|
|
1331
|
+
out = out.clear_device_()
|
|
1332
|
+
elif out.device != device:
|
|
1333
|
+
out = out.to(device, non_blocking=self.non_blocking)
|
|
1334
|
+
self._sync_w2m()
|
|
1335
|
+
else:
|
|
1336
|
+
for i, _data_in in zip(workers_range, data_in):
|
|
1337
|
+
out_td = self._envs[i]._step(_data_in)
|
|
1338
|
+
out_tds.append(out_td)
|
|
1339
|
+
out = LazyStackedTensorDict.maybe_dense_stack(out_tds)
|
|
1340
|
+
|
|
1341
|
+
if partial_steps is not None and not partial_steps.all():
|
|
1342
|
+
result = out.new_zeros(tensordict_save.shape)
|
|
1343
|
+
# Copy the observation data from the previous step as placeholder
|
|
1344
|
+
|
|
1345
|
+
def select_and_clone(x, y):
|
|
1346
|
+
if y is not None:
|
|
1347
|
+
if x.device != y.device:
|
|
1348
|
+
x = x.to(y.device)
|
|
1349
|
+
else:
|
|
1350
|
+
x = x.clone()
|
|
1351
|
+
return x
|
|
1352
|
+
|
|
1353
|
+
prev = tensordict_save._fast_apply(
|
|
1354
|
+
select_and_clone,
|
|
1355
|
+
result,
|
|
1356
|
+
filter_empty=True,
|
|
1357
|
+
device=result.device,
|
|
1358
|
+
batch_size=result.batch_size,
|
|
1359
|
+
is_leaf=_is_leaf_nontensor,
|
|
1360
|
+
default=None,
|
|
1361
|
+
)
|
|
1362
|
+
|
|
1363
|
+
result.update(prev)
|
|
1364
|
+
if partial_steps.any():
|
|
1365
|
+
result[partial_steps] = out
|
|
1366
|
+
return result
|
|
1367
|
+
|
|
1368
|
+
return out
|
|
1369
|
+
|
|
1370
|
+
def __getattr__(self, attr: str) -> Any:
|
|
1371
|
+
if attr in self.__dir__():
|
|
1372
|
+
return super().__getattr__(
|
|
1373
|
+
attr
|
|
1374
|
+
) # make sure that appropriate exceptions are raised
|
|
1375
|
+
elif attr.startswith("__"):
|
|
1376
|
+
raise AttributeError(
|
|
1377
|
+
"dispatching built-in private methods is "
|
|
1378
|
+
f"not permitted with type {type(self)}. "
|
|
1379
|
+
f"Got attribute {attr}."
|
|
1380
|
+
)
|
|
1381
|
+
else:
|
|
1382
|
+
if attr in self._excluded_wrapped_keys:
|
|
1383
|
+
raise AttributeError(f"Getting {attr} resulted in an exception")
|
|
1384
|
+
try:
|
|
1385
|
+
# determine if attr is a callable
|
|
1386
|
+
list_attr = [getattr(env, attr) for env in self._envs]
|
|
1387
|
+
callable_attr = callable(list_attr[0])
|
|
1388
|
+
if callable_attr:
|
|
1389
|
+
if self.is_closed:
|
|
1390
|
+
raise RuntimeError(
|
|
1391
|
+
"Trying to access attributes of closed/non started "
|
|
1392
|
+
"environments. Check that the batched environment "
|
|
1393
|
+
"has been started (e.g. by calling env.reset)"
|
|
1394
|
+
)
|
|
1395
|
+
return _dispatch_caller_serial(list_attr)
|
|
1396
|
+
else:
|
|
1397
|
+
return list_attr
|
|
1398
|
+
except AttributeError:
|
|
1399
|
+
raise AttributeError(
|
|
1400
|
+
f"attribute {attr} not found in " f"{self._dummy_env_str}"
|
|
1401
|
+
)
|
|
1402
|
+
|
|
1403
|
+
def to(self, device: DEVICE_TYPING):
|
|
1404
|
+
device = _make_ordinal_device(torch.device(device))
|
|
1405
|
+
if device == self.device:
|
|
1406
|
+
return self
|
|
1407
|
+
super().to(device)
|
|
1408
|
+
if not self.is_closed:
|
|
1409
|
+
self._envs = [env.to(device) for env in self._envs]
|
|
1410
|
+
return self
|
|
1411
|
+
|
|
1412
|
+
|
|
1413
|
+
class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta):
|
|
1414
|
+
"""Creates one environment per process.
|
|
1415
|
+
|
|
1416
|
+
TensorDicts are passed via shared memory or memory map.
|
|
1417
|
+
|
|
1418
|
+
"""
|
|
1419
|
+
|
|
1420
|
+
__doc__ += BatchedEnvBase.__doc__
|
|
1421
|
+
__doc__ += """
|
|
1422
|
+
|
|
1423
|
+
.. note:: ParallelEnv will time out after one of the workers has been idle for a given amount of time.
|
|
1424
|
+
This can be controlled via the BATCHED_PIPE_TIMEOUT environment variable, which in turn modifies
|
|
1425
|
+
the torchrl._utils.BATCHED_PIPE_TIMEOUT integer. The default timeout value is 10000 seconds.
|
|
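A minimal sketch of raising that limit (assuming the variable is read when
torchrl is imported, so it must be set beforehand; the value is illustrative):

>>> import os
>>> os.environ["BATCHED_PIPE_TIMEOUT"] = "300"  # seconds
>>> import torchrl  # torchrl._utils.BATCHED_PIPE_TIMEOUT now reflects the override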
1426
|
+
|
|
1427
|
+
.. warning::
|
|
1428
|
+
TorchRL's ParallelEnv is quite stringent when it comes to env specs, since
|
|
1429
|
+
these are used to build shared memory buffers for inter-process communication.
|
|
1430
|
+
As such, we encourage users to first run a check of the env specs with
|
|
1431
|
+
:func:`~torchrl.envs.utils.check_env_specs`:
|
|
1432
|
+
|
|
1433
|
+
>>> from torchrl.envs import check_env_specs
|
|
1434
|
+
>>> env = make_env()
|
|
1435
|
+
>>> check_env_specs(env) # if this passes without error you're good to go!
|
|
1436
|
+
>>> penv = ParallelEnv(2, make_env)
|
|
1437
|
+
|
|
1438
|
+
In particular, gym-like envs with info-dict readers may be difficult to
|
|
1439
|
+
share across processes if the spec is not properly set, which is hard to
|
|
1440
|
+
do automatically. Check :meth:`~torchrl.envs.GymLikeEnv.set_info_dict_reader`
|
|
1441
|
+
for more information. Here is a short example:
|
|
1442
|
+
|
|
1443
|
+
>>> from torchrl.envs import GymEnv, set_gym_backend, check_env_specs, TransformedEnv, TensorDictPrimer
|
|
1444
|
+
>>> import torch
|
|
1445
|
+
>>> env = GymEnv("HalfCheetah-v4")
|
|
1446
|
+
>>> env.rollout(10) # no info registered, this env passes check_env_specs
|
|
1447
|
+
TensorDict(
|
|
1448
|
+
fields={
|
|
1449
|
+
action: Tensor(shape=torch.Size([10, 6]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
1450
|
+
done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
1451
|
+
next: TensorDict(
|
|
1452
|
+
fields={
|
|
1453
|
+
done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
1454
|
+
observation: Tensor(shape=torch.Size([10, 17]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
1455
|
+
reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
1456
|
+
terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
1457
|
+
truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
1458
|
+
batch_size=torch.Size([10]),
|
|
1459
|
+
device=cpu,
|
|
1460
|
+
is_shared=False),
|
|
1461
|
+
observation: Tensor(shape=torch.Size([10, 17]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
1462
|
+
terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
1463
|
+
truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
1464
|
+
batch_size=torch.Size([10]),
|
|
1465
|
+
device=cpu,
|
|
1466
|
+
is_shared=False)
|
|
1467
|
+
>>> check_env_specs(env) # succeeds!
|
|
1468
|
+
>>> env.set_info_dict_reader() # sets the default info_dict reader
|
|
1469
|
+
>>> env.rollout(10) # because the info_dict is empty at reset time, we're missing the root infos!
|
|
1470
|
+
TensorDict(
|
|
1471
|
+
fields={
|
|
1472
|
+
action: Tensor(shape=torch.Size([10, 6]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
1473
|
+
done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
1474
|
+
next: TensorDict(
|
|
1475
|
+
fields={
|
|
1476
|
+
done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
1477
|
+
observation: Tensor(shape=torch.Size([10, 17]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
1478
|
+
reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
1479
|
+
reward_ctrl: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
1480
|
+
reward_run: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
1481
|
+
terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
1482
|
+
truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
1483
|
+
x_position: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
1484
|
+
x_velocity: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False)},
|
|
1485
|
+
batch_size=torch.Size([10]),
|
|
1486
|
+
device=cpu,
|
|
1487
|
+
is_shared=False),
|
|
1488
|
+
observation: Tensor(shape=torch.Size([10, 17]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
1489
|
+
terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
1490
|
+
truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
1491
|
+
batch_size=torch.Size([10]),
|
|
1492
|
+
device=cpu,
|
|
1493
|
+
is_shared=False)
|
|
1494
|
+
>>> check_env_specs(env) # This check now fails! We should not use an env constructed like this in a parallel env
|
|
1495
|
+
>>> # This ad-hoc fix registers the info-spec for reset. It is wrapped inside `env.auto_register_info_dict()`
|
|
1496
|
+
>>> env_fixed = TransformedEnv(env, TensorDictPrimer(env.info_dict_reader[0].info_spec))
|
|
1497
|
+
>>> env_fixed.rollout(10)
|
|
1498
|
+
TensorDict(
|
|
1499
|
+
fields={
|
|
1500
|
+
action: Tensor(shape=torch.Size([10, 6]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
1501
|
+
done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
1502
|
+
next: TensorDict(
|
|
1503
|
+
fields={
|
|
1504
|
+
done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
1505
|
+
observation: Tensor(shape=torch.Size([10, 17]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
1506
|
+
reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
1507
|
+
reward_ctrl: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
1508
|
+
reward_run: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
1509
|
+
terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
1510
|
+
truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
1511
|
+
x_position: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
1512
|
+
x_velocity: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False)},
|
|
1513
|
+
batch_size=torch.Size([10]),
|
|
1514
|
+
device=cpu,
|
|
1515
|
+
is_shared=False),
|
|
1516
|
+
observation: Tensor(shape=torch.Size([10, 17]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
1517
|
+
reward_ctrl: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
1518
|
+
reward_run: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
1519
|
+
terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
1520
|
+
truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
1521
|
+
x_position: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
1522
|
+
x_velocity: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False)},
|
|
1523
|
+
batch_size=torch.Size([10]),
|
|
1524
|
+
device=cpu,
|
|
1525
|
+
is_shared=False)
|
|
1526
|
+
>>> check_env_specs(env_fixed) # Succeeds! This env can be used within a parallel env!
|
|
1527
|
+
|
|
1528
|
+
Related classes and methods: :meth:`~torchrl.envs.GymLikeEnv.auto_register_info_dict`
|
|
1529
|
+
and :class:`~torchrl.envs.gym_like.default_info_dict_reader`.
|
|
1530
|
+
|
|
1531
|
+
.. warning::
|
|
1532
|
+
The choice of the devices where ParallelEnv needs to be executed can
|
|
1533
|
+
drastically influence its performance. The rule of thumb is:
|
|
1534
|
+
|
|
1535
|
+
- If the base environment (backend, e.g., Gym) is executed on CPU, the
|
|
1536
|
+
sub-environments should be executed on CPU and the data should be
|
|
1537
|
+
passed via shared physical memory.
|
|
1538
|
+
- If the base environment is (or can be) executed on CUDA, the sub-environments
|
|
1539
|
+
should be placed on CUDA too.
|
|
1540
|
+
- If a CUDA device is available and the policy is to be executed on CUDA,
|
|
1541
|
+
the ParallelEnv device should be set to CUDA.
|
|
1542
|
+
|
|
1543
|
+
Therefore, supposing a CUDA device is available, we have the following scenarios:
|
|
1544
|
+
|
|
1545
|
+
>>> # The sub-envs are executed on CPU, but the policy is on GPU
|
|
1546
|
+
>>> env = ParallelEnv(N, MyEnv(..., device="cpu"), device="cuda")
|
|
1547
|
+
>>> # The sub-envs are executed on CUDA
|
|
1548
|
+
>>> env = ParallelEnv(N, MyEnv(..., device="cuda"), device="cuda")
|
|
1549
|
+
>>> # this will create the exact same environment
|
|
1550
|
+
>>> env = ParallelEnv(N, MyEnv(..., device="cuda"))
|
|
1551
|
+
>>> # If no cuda device is available
|
|
1552
|
+
>>> env = ParallelEnv(N, MyEnv(..., device="cpu"))
|
|
1553
|
+
|
|
1554
|
+
.. warning::
|
|
1555
|
+
ParallelEnv disables gradients in all operations (:meth:`step`,
|
|
1556
|
+
:meth:`reset` and :meth:`step_and_maybe_reset`) because gradients
|
|
1557
|
+
cannot be passed through :class:`multiprocessing.Pipe` objects.
|
|
1558
|
+
Only :class:`~torchrl.envs.SerialEnv` will support backpropagation.
|
|
1559
|
+
|
|
1560
|
+
"""
|
|
1561
|
+
|
|
1562
|
+
def _start_workers(self) -> None:
|
|
1563
|
+
import torchrl
|
|
1564
|
+
|
|
1565
|
+
self._timeout = 10.0
|
|
1566
|
+
self.BATCHED_PIPE_TIMEOUT = torchrl._utils.BATCHED_PIPE_TIMEOUT
|
|
1567
|
+
|
|
1568
|
+
num_threads = max(
|
|
1569
|
+
1, torch.get_num_threads() - self.num_workers
|
|
1570
|
+
) # keep at least one thread for the main process
|
|
1571
|
+
|
|
1572
|
+
if self.num_threads is None:
|
|
1573
|
+
self.num_threads = num_threads
|
|
1574
|
+
|
|
1575
|
+
if self.num_threads != torch.get_num_threads():
|
|
1576
|
+
torch.set_num_threads(self.num_threads)
|
|
1577
|
+
|
|
1578
|
+
if self._mp_start_method is not None:
|
|
1579
|
+
ctx = mp.get_context(self._mp_start_method)
|
|
1580
|
+
else:
|
|
1581
|
+
ctx = mp.get_context(_get_default_mp_start_method())
|
|
1582
|
+
# Use ctx.Process directly to ensure all multiprocessing primitives
|
|
1583
|
+
# (Queue, Pipe, Process, Event) come from the same context.
|
|
1584
|
+
# Warning filtering and num_threads are handled in the worker functions.
|
|
1585
|
+
proc_fun = ctx.Process
|
|
1586
|
+
num_sub_threads = self.num_sub_threads
|
|
1587
|
+
|
|
1588
|
+
_num_workers = self.num_workers
|
|
1589
|
+
|
|
1590
|
+
self.parent_channels = []
|
|
1591
|
+
self._workers = []
|
|
1592
|
+
if self._use_buffers:
|
|
1593
|
+
func = _run_worker_pipe_shared_mem
|
|
1594
|
+
else:
|
|
1595
|
+
func = _run_worker_pipe_direct
|
|
1596
|
+
# We look for cuda tensors through the leaves
|
|
1597
|
+
# because the shared tensordict could be partially on cuda
|
|
1598
|
+
# and some leaves may be inaccessible through get (e.g., LazyStacked)
|
|
1599
|
+
has_cuda = [False]
|
|
1600
|
+
|
|
1601
|
+
def look_for_cuda(tensor, has_cuda=has_cuda):
|
|
1602
|
+
has_cuda[0] = has_cuda[0] or tensor.is_cuda
|
|
1603
|
+
|
|
1604
|
+
if self._use_buffers:
|
|
1605
|
+
self.shared_tensordict_parent.apply(look_for_cuda, filter_empty=True)
|
|
1606
|
+
has_cuda = has_cuda[0]
|
|
1607
|
+
if has_cuda:
|
|
1608
|
+
self.event = torch.cuda.Event()
|
|
1609
|
+
else:
|
|
1610
|
+
self.event = None
|
|
1611
|
+
self._events = [ctx.Event() for _ in range(_num_workers)]
|
|
1612
|
+
kwargs = [{"mp_event": self._events[i]} for i in range(_num_workers)]
|
|
1613
|
+
with clear_mpi_env_vars():
|
|
1614
|
+
for idx in range(_num_workers):
|
|
1615
|
+
if self._verbose:
|
|
1616
|
+
torchrl_logger.info(f"initiating worker {idx}")
|
|
1617
|
+
# We cannot be certain which module the multiprocessing context comes from, so the pipe is created from ctx directly
|
|
1618
|
+
parent_pipe, child_pipe = ctx.Pipe()
|
|
1619
|
+
env_fun = self.create_env_fn[idx]
|
|
1620
|
+
if not isinstance(env_fun, (EnvCreator, CloudpickleWrapper)):
|
|
1621
|
+
env_fun = CloudpickleWrapper(env_fun)
|
|
1622
|
+
|
|
1623
|
+
kwargs[idx].update(
|
|
1624
|
+
{
|
|
1625
|
+
"parent_pipe": parent_pipe,
|
|
1626
|
+
"child_pipe": child_pipe,
|
|
1627
|
+
"env_fun": env_fun,
|
|
1628
|
+
"env_fun_kwargs": self.create_env_kwargs[idx],
|
|
1629
|
+
"has_lazy_inputs": self.has_lazy_inputs,
|
|
1630
|
+
"num_threads": num_sub_threads,
|
|
1631
|
+
"non_blocking": self.non_blocking,
|
|
1632
|
+
"filter_warnings": self._filter_warnings_subprocess(),
|
|
1633
|
+
}
|
|
1634
|
+
)
|
|
1635
|
+
if self._use_buffers:
|
|
1636
|
+
kwargs[idx].update(
|
|
1637
|
+
{
|
|
1638
|
+
"shared_tensordict": self.shared_tensordicts[idx],
|
|
1639
|
+
"_selected_input_keys": self._selected_input_keys,
|
|
1640
|
+
"_selected_reset_keys": self._selected_reset_keys,
|
|
1641
|
+
"_selected_step_keys": self._selected_step_keys,
|
|
1642
|
+
"_non_tensor_keys": self._non_tensor_keys,
|
|
1643
|
+
}
|
|
1644
|
+
)
|
|
1645
|
+
else:
|
|
1646
|
+
kwargs[idx].update(
|
|
1647
|
+
{
|
|
1648
|
+
"consolidate": self.consolidate,
|
|
1649
|
+
}
|
|
1650
|
+
)
|
|
1651
|
+
process = proc_fun(target=func, kwargs=kwargs[idx])
|
|
1652
|
+
process.daemon = self.daemon
|
|
1653
|
+
process.start()
|
|
1654
|
+
child_pipe.close()
|
|
1655
|
+
self.parent_channels.append(parent_pipe)
|
|
1656
|
+
self._workers.append(process)
|
|
1657
|
+
|
|
1658
|
+
for parent_pipe in self.parent_channels:
|
|
1659
|
+
# use msg as sync point
|
|
1660
|
+
parent_pipe.recv()
|
|
1661
|
+
|
|
1662
|
+
# send shared tensordict to workers
|
|
1663
|
+
for channel in self.parent_channels:
|
|
1664
|
+
channel.send(("init", None))
|
|
1665
|
+
self.is_closed = False
|
|
1666
|
+
self.set_spec_lock_()
|
|
1667
|
+
|
|
1668
|
+
def _filter_warnings_subprocess(self) -> bool:
|
|
1669
|
+
from torchrl import filter_warnings_subprocess
|
|
1670
|
+
|
|
1671
|
+
return filter_warnings_subprocess
|
|
1672
|
+
|
|
1673
|
+
@_check_start
|
|
1674
|
+
def state_dict(self) -> OrderedDict:
|
|
1675
|
+
state_dict = OrderedDict()
|
|
1676
|
+
for channel in self.parent_channels:
|
|
1677
|
+
channel.send(("state_dict", None))
|
|
1678
|
+
for idx, channel in enumerate(self.parent_channels):
|
|
1679
|
+
msg, _state_dict = channel.recv()
|
|
1680
|
+
if msg != "state_dict":
|
|
1681
|
+
raise RuntimeError(f"Expected 'state_dict' but received {msg}")
|
|
1682
|
+
state_dict[f"worker{idx}"] = _state_dict
|
|
1683
|
+
|
|
1684
|
+
return state_dict
|
|
1685
|
+
|
|
1686
|
+
@_check_start
|
|
1687
|
+
def load_state_dict(self, state_dict: OrderedDict) -> None:
|
|
1688
|
+
if "worker0" not in state_dict:
|
|
1689
|
+
state_dict = OrderedDict(
|
|
1690
|
+
**{f"worker{idx}": state_dict for idx in range(self.num_workers)}
|
|
1691
|
+
)
|
|
1692
|
+
for i, channel in enumerate(self.parent_channels):
|
|
1693
|
+
channel.send(("load_state_dict", state_dict[f"worker{i}"]))
|
|
1694
|
+
for event in self._events:
|
|
1695
|
+
event.wait(self._timeout)
|
|
1696
|
+
event.clear()
|
|
1697
|
+
|
|
1698
|
+
def _step_and_maybe_reset_no_buffers(
|
|
1699
|
+
self, tensordict: TensorDictBase
|
|
1700
|
+
) -> tuple[TensorDictBase, TensorDictBase]:
|
|
1701
|
+
partial_steps = tensordict.get("_step", None)
|
|
1702
|
+
tensordict_save = tensordict
|
|
1703
|
+
if partial_steps is not None and partial_steps.all():
|
|
1704
|
+
partial_steps = None
|
|
1705
|
+
if partial_steps is not None:
|
|
1706
|
+
partial_steps = partial_steps.view(tensordict.shape)
|
|
1707
|
+
tensordict = tensordict[partial_steps]
|
|
1708
|
+
workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
|
|
1709
|
+
else:
|
|
1710
|
+
workers_range = range(self.num_workers)
|
|
1711
|
+
|
|
1712
|
+
if self.consolidate:
|
|
1713
|
+
try:
|
|
1714
|
+
td = tensordict.consolidate(
|
|
1715
|
+
# share_memory=False: avoid resource_sharer which causes
|
|
1716
|
+
# progressive slowdown with fork on Linux
|
|
1717
|
+
share_memory=False,
|
|
1718
|
+
inplace=True,
|
|
1719
|
+
num_threads=1,
|
|
1720
|
+
)
|
|
1721
|
+
except Exception as err:
|
|
1722
|
+
raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
|
|
1723
|
+
else:
|
|
1724
|
+
td = tensordict
|
|
1725
|
+
|
|
1726
|
+
for i in workers_range:
|
|
1727
|
+
# We send the same td multiple times as it is in shared mem and we just need to index it
|
|
1728
|
+
# in each process.
|
|
1729
|
+
# If we don't do this, we need to unbind it but then the custom pickler will require
|
|
1730
|
+
# some extra metadata to be collected.
|
|
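# Each worker thus receives the pair (consolidated td, worker index) and indexes
# into the tensordict itself on its side of the pipe.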
1731
|
+
self.parent_channels[i].send(("step_and_maybe_reset", (td, i)))
|
|
1732
|
+
|
|
1733
|
+
results = [None] * len(workers_range)
|
|
1734
|
+
|
|
1735
|
+
self._wait_for_workers(workers_range)
|
|
1736
|
+
|
|
1737
|
+
for i, w in enumerate(workers_range):
|
|
1738
|
+
results[i] = self.parent_channels[w].recv()
|
|
1739
|
+
|
|
1740
|
+
out_next, out_root = zip(*(future for future in results))
|
|
1741
|
+
out = TensorDict.maybe_dense_stack(out_next), TensorDict.maybe_dense_stack(
|
|
1742
|
+
out_root
|
|
1743
|
+
)
|
|
1744
|
+
if partial_steps is not None:
|
|
1745
|
+
result = out.new_zeros(tensordict_save.shape)
|
|
1746
|
+
|
|
1747
|
+
def select_and_clone(x, y):
|
|
1748
|
+
if y is not None:
|
|
1749
|
+
if x.device != y.device:
|
|
1750
|
+
x = x.to(y.device)
|
|
1751
|
+
else:
|
|
1752
|
+
x = x.clone()
|
|
1753
|
+
return x
|
|
1754
|
+
|
|
1755
|
+
prev = tensordict_save._fast_apply(
|
|
1756
|
+
select_and_clone,
|
|
1757
|
+
result,
|
|
1758
|
+
filter_empty=True,
|
|
1759
|
+
device=result.device,
|
|
1760
|
+
batch_size=result.batch_size,
|
|
1761
|
+
is_leaf=_is_leaf_nontensor,
|
|
1762
|
+
default=None,
|
|
1763
|
+
)
|
|
1764
|
+
|
|
1765
|
+
result.update(prev)
|
|
1766
|
+
|
|
1767
|
+
if partial_steps.any():
|
|
1768
|
+
result[partial_steps] = out
|
|
1769
|
+
return result
|
|
1770
|
+
return out
|
|
1771
|
+
|
|
1772
|
+
@torch.no_grad()
|
|
1773
|
+
@_check_start
|
|
1774
|
+
def step_and_maybe_reset(
|
|
1775
|
+
self, tensordict: TensorDictBase
|
|
1776
|
+
) -> tuple[TensorDictBase, TensorDictBase]:
|
|
1777
|
+
if not self._use_buffers:
|
|
1778
|
+
# Simply dispatch the input to the workers
|
|
1779
|
+
# return self._step_and_maybe_reset_no_buffers(tensordict)
|
|
1780
|
+
return super().step_and_maybe_reset(tensordict)
|
|
1781
|
+
|
|
1782
|
+
partial_steps = tensordict.get("_step")
|
|
1783
|
+
tensordict_save = tensordict
|
|
1784
|
+
if partial_steps is not None and partial_steps.all():
|
|
1785
|
+
partial_steps = None
|
|
1786
|
+
if partial_steps is not None:
|
|
1787
|
+
partial_steps = partial_steps.view(tensordict.shape)
|
|
1788
|
+
workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
|
|
1789
|
+
shared_tensordict_parent = TensorDict.lazy_stack(
|
|
1790
|
+
[self.shared_tensordict_parent[i] for i in workers_range]
|
|
1791
|
+
)
|
|
1792
|
+
next_td = TensorDict.lazy_stack(
|
|
1793
|
+
[self._shared_tensordict_parent_next[i] for i in workers_range]
|
|
1794
|
+
)
|
|
1795
|
+
tensordict_ = TensorDict.lazy_stack(
|
|
1796
|
+
[self._shared_tensordict_parent_root[i] for i in workers_range]
|
|
1797
|
+
)
|
|
1798
|
+
if self.shared_tensordict_parent.device is None:
|
|
1799
|
+
tensordict = tensordict._fast_apply(
|
|
1800
|
+
lambda x, y: x[partial_steps].to(y.device)
|
|
1801
|
+
if y is not None
|
|
1802
|
+
else x[partial_steps],
|
|
1803
|
+
self.shared_tensordict_parent,
|
|
1804
|
+
default=None,
|
|
1805
|
+
device=None,
|
|
1806
|
+
batch_size=shared_tensordict_parent.shape,
|
|
1807
|
+
)
|
|
1808
|
+
else:
|
|
1809
|
+
tensordict = tensordict[partial_steps].to(
|
|
1810
|
+
self.shared_tensordict_parent.device
|
|
1811
|
+
)
|
|
1812
|
+
else:
|
|
1813
|
+
workers_range = range(self.num_workers)
|
|
1814
|
+
shared_tensordict_parent = self.shared_tensordict_parent
|
|
1815
|
+
next_td = self._shared_tensordict_parent_next
|
|
1816
|
+
tensordict_ = self._shared_tensordict_parent_root
|
|
1817
|
+
|
|
1818
|
+
# We must use the in_keys and nothing else for the following reasons:
|
|
1819
|
+
# - efficiency: copying all the keys will in practice mean doing a lot
|
|
1820
|
+
# of writing operations since the input tensordict may (and often will)
|
|
1821
|
+
# contain all the previous output data.
|
|
1822
|
+
# - value mismatch: if the batched env is placed within a transform
|
|
1823
|
+
# and this transform overrides an observation key (eg, CatFrames)
|
|
1824
|
+
# the shape, dtype or device may not necessarily match and writing
|
|
1825
|
+
# the value in-place will fail.
|
|
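# Hence only self._env_input_keys (action and state inputs) are written into the
# shared buffer here; outputs come back through the workers' own writes.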
1826
|
+
shared_tensordict_parent.update_(
|
|
1827
|
+
tensordict,
|
|
1828
|
+
keys_to_update=self._env_input_keys,
|
|
1829
|
+
non_blocking=self.non_blocking,
|
|
1830
|
+
)
|
|
1831
|
+
next_td_passthrough = tensordict.get("next", default=None)
|
|
1832
|
+
if next_td_passthrough is not None:
|
|
1833
|
+
# if we have input "next" data (eg, RNNs which pass the next state)
|
|
1834
|
+
# the sub-envs will need to process them through step_and_maybe_reset.
|
|
1835
|
+
# We keep track of which keys are present to let the worker know what
|
|
1836
|
+
# should be passed to the env (we don't want to pass done states for instance)
|
|
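# Illustrative shape of such an input (hypothetical recurrent-state key):
#   TensorDict({"action": act, ("next", "recurrent_state"): h}, batch_size=[N])
# next_td_passthrough_keys would then contain "recurrent_state".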
1837
|
+
next_td_keys = list(next_td_passthrough.keys(True, True))
|
|
1838
|
+
data = [{"next_td_passthrough_keys": next_td_keys} for _ in workers_range]
|
|
1839
|
+
shared_tensordict_parent.get("next").update_(
|
|
1840
|
+
next_td_passthrough, non_blocking=self.non_blocking
|
|
1841
|
+
)
|
|
1842
|
+
else:
|
|
1843
|
+
# next_td_keys = None
|
|
1844
|
+
data = [{} for _ in workers_range]
|
|
1845
|
+
|
|
1846
|
+
if self._non_tensor_keys:
|
|
1847
|
+
for i, td in zip(
|
|
1848
|
+
workers_range,
|
|
1849
|
+
tensordict.select(*self._non_tensor_keys, strict=False).unbind(0),
|
|
1850
|
+
):
|
|
1851
|
+
data[i]["non_tensor_data"] = td
|
|
1852
|
+
|
|
1853
|
+
self._sync_m2w()
|
|
1854
|
+
for i, _data in zip(workers_range, data):
|
|
1855
|
+
self.parent_channels[i].send(("step_and_maybe_reset", _data))
|
|
1856
|
+
|
|
1857
|
+
self._wait_for_workers(workers_range)
|
|
1858
|
+
if self._non_tensor_keys:
|
|
1859
|
+
non_tensor_tds = []
|
|
1860
|
+
for i in workers_range:
|
|
1861
|
+
msg, non_tensor_td = self.parent_channels[i].recv()
|
|
1862
|
+
non_tensor_tds.append(non_tensor_td)
|
|
1863
|
+
|
|
1864
|
+
# We must pass a clone of the tensordict, as the values of this tensordict
|
|
1865
|
+
# will be modified in-place at further steps
|
|
1866
|
+
device = self.device
|
|
1867
|
+
if shared_tensordict_parent.device == device:
|
|
1868
|
+
next_td = next_td.clone()
|
|
1869
|
+
tensordict_ = tensordict_.clone()
|
|
1870
|
+
elif device is not None:
|
|
1871
|
+
next_td = next_td._fast_apply(
|
|
1872
|
+
lambda x: x.to(device, non_blocking=self.non_blocking)
|
|
1873
|
+
if x.device != device
|
|
1874
|
+
else x.clone(),
|
|
1875
|
+
device=device,
|
|
1876
|
+
filter_empty=True,
|
|
1877
|
+
)
|
|
1878
|
+
tensordict_ = tensordict_._fast_apply(
|
|
1879
|
+
lambda x: x.to(device, non_blocking=self.non_blocking)
|
|
1880
|
+
if x.device != device
|
|
1881
|
+
else x.clone(),
|
|
1882
|
+
device=device,
|
|
1883
|
+
filter_empty=True,
|
|
1884
|
+
)
|
|
1885
|
+
if tensordict.device != device:
|
|
1886
|
+
tensordict = tensordict._fast_apply(
|
|
1887
|
+
lambda x: x.to(device, non_blocking=self.non_blocking)
|
|
1888
|
+
if x.device != device
|
|
1889
|
+
else x,
|
|
1890
|
+
device=device,
|
|
1891
|
+
filter_empty=True,
|
|
1892
|
+
)
|
|
1893
|
+
self._sync_w2m()
|
|
1894
|
+
else:
|
|
1895
|
+
next_td = next_td.clone().clear_device_()
|
|
1896
|
+
tensordict_ = tensordict_.clone().clear_device_()
|
|
1897
|
+
tensordict.set("next", next_td)
|
|
1898
|
+
if self._non_tensor_keys:
|
|
1899
|
+
non_tensor_tds = LazyStackedTensorDict(*non_tensor_tds)
|
|
1900
|
+
tensordict.update(
|
|
1901
|
+
non_tensor_tds,
|
|
1902
|
+
keys_to_update=[("next", key) for key in self._non_tensor_keys],
|
|
1903
|
+
)
|
|
1904
|
+
tensordict_.update(non_tensor_tds, keys_to_update=self._non_tensor_keys)
|
|
1905
|
+
|
|
1906
|
+
if partial_steps is not None:
|
|
1907
|
+
result = tensordict.new_zeros(tensordict_save.shape)
|
|
1908
|
+
result_ = tensordict_.new_zeros(tensordict_save.shape)
|
|
1909
|
+
|
|
1910
|
+
def select_and_transfer(x, y):
|
|
1911
|
+
if y is not None:
|
|
1912
|
+
return (
|
|
1913
|
+
x.to(y.device, non_blocking=self.non_blocking)
|
|
1914
|
+
if x.device != y.device
|
|
1915
|
+
else x.clone()
|
|
1916
|
+
)
|
|
1917
|
+
|
|
1918
|
+
old_r_copy = tensordict_save._fast_apply(
|
|
1919
|
+
select_and_transfer,
|
|
1920
|
+
result,
|
|
1921
|
+
filter_empty=True,
|
|
1922
|
+
device=device,
|
|
1923
|
+
default=None,
|
|
1924
|
+
)
|
|
1925
|
+
old_r_copy.set(
|
|
1926
|
+
"next",
|
|
1927
|
+
tensordict_save._fast_apply(
|
|
1928
|
+
select_and_transfer,
|
|
1929
|
+
next_td,
|
|
1930
|
+
filter_empty=True,
|
|
1931
|
+
device=device,
|
|
1932
|
+
default=None,
|
|
1933
|
+
),
|
|
1934
|
+
)
|
|
1935
|
+
result.update(old_r_copy)
|
|
1936
|
+
result_.update(
|
|
1937
|
+
tensordict_save._fast_apply(
|
|
1938
|
+
select_and_transfer,
|
|
1939
|
+
result_,
|
|
1940
|
+
filter_empty=True,
|
|
1941
|
+
device=device,
|
|
1942
|
+
default=None,
|
|
1943
|
+
)
|
|
1944
|
+
)
|
|
1945
|
+
self._sync_w2m()
|
|
1946
|
+
|
|
1947
|
+
if partial_steps.any():
|
|
1948
|
+
result[partial_steps] = tensordict
|
|
1949
|
+
result_[partial_steps] = tensordict_
|
|
1950
|
+
return result, result_
|
|
1951
|
+
|
|
1952
|
+
return tensordict, tensordict_
|
|
1953
|
+
|
|
1954
|
+
def _wait_for_workers(self, workers_range):
|
|
1955
|
+
"""Wait for all workers to signal completion via their events.
|
|
1956
|
+
|
|
1957
|
+
Uses multiprocessing.connection.wait() for efficient OS-level
|
|
1958
|
+
waiting on multiple pipes simultaneously.
|
|
1959
|
+
"""
|
|
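# Outline of the two modes implemented below (a summary, not extra logic):
# - buffer mode: poll each worker's mp_event in short slices, periodically
#   checking that the worker process is still alive;
# - no-buffer mode: block on the parent ends of the pipes with
#   multiprocessing.connection.wait() until every pending pipe is readable.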
1960
|
+
timeout = self.BATCHED_PIPE_TIMEOUT
|
|
1961
|
+
t0 = time.time()
|
|
1962
|
+
|
|
1963
|
+
# In shared-memory/buffer mode, workers signal completion by setting
|
|
1964
|
+
# their `mp_event` (they may not send anything back on the pipe).
|
|
1965
|
+
if self._use_buffers:
|
|
1966
|
+
pending = set(workers_range)
|
|
1967
|
+
n_iter = 0
|
|
1968
|
+
while pending:
|
|
1969
|
+
n_iter += 1
|
|
1970
|
+
remaining = timeout - (time.time() - t0)
|
|
1971
|
+
if remaining <= 0:
|
|
1972
|
+
raise RuntimeError(
|
|
1973
|
+
f"Failed to run all workers within the {timeout} sec time limit. This "
|
|
1974
|
+
f"threshold can be increased via the BATCHED_PIPE_TIMEOUT env variable."
|
|
1975
|
+
)
|
|
1976
|
+
|
|
1977
|
+
# Wait in short slices so we can both harvest multiple events and
|
|
1978
|
+
# periodically check for dead workers without blocking forever.
|
|
1979
|
+
slice_timeout = min(0.1, remaining)
|
|
1980
|
+
progressed = False
|
|
1981
|
+
for wi in tuple(pending):
|
|
1982
|
+
if self._events[wi].wait(timeout=slice_timeout):
|
|
1983
|
+
self._events[wi].clear()
|
|
1984
|
+
pending.remove(wi)
|
|
1985
|
+
progressed = True
|
|
1986
|
+
|
|
1987
|
+
if not progressed and (n_iter % 50) == 0:
|
|
1988
|
+
for wi in pending:
|
|
1989
|
+
if not self._workers[wi].is_alive():
|
|
1990
|
+
try:
|
|
1991
|
+
self._shutdown_workers()
|
|
1992
|
+
finally:
|
|
1993
|
+
raise RuntimeError(f"Cannot proceed, worker {wi} dead.")
|
|
1994
|
+
return
|
|
1995
|
+
|
|
1996
|
+
# No-buffer mode: workers send back data on the pipe, so we can efficiently
|
|
1997
|
+
# block on readability.
|
|
1998
|
+
pipes_pending = {self.parent_channels[i]: i for i in workers_range}
|
|
1999
|
+
i = 0
|
|
2000
|
+
while pipes_pending:
|
|
2001
|
+
i += 1
|
|
2002
|
+
should_check_for_dead_workers = (i % 20) == 0
|
|
2003
|
+
remaining = timeout - (time.time() - t0)
|
|
2004
|
+
if remaining <= 0:
|
|
2005
|
+
raise RuntimeError(
|
|
2006
|
+
f"Failed to run all workers within the {timeout} sec time limit. This "
|
|
2007
|
+
f"threshold can be increased via the BATCHED_PIPE_TIMEOUT env variable."
|
|
2008
|
+
)
|
|
2009
|
+
|
|
2010
|
+
# Wait for any pipes to become readable (OS-level select/poll)
|
|
2011
|
+
ready = connection_wait(list(pipes_pending.keys()), timeout=remaining)
|
|
2012
|
+
|
|
2013
|
+
if not ready and should_check_for_dead_workers:
|
|
2014
|
+
# Timeout with no pipes ready - check for dead workers
|
|
2015
|
+
for wi in pipes_pending.values():
|
|
2016
|
+
if not self._workers[wi].is_alive():
|
|
2017
|
+
try:
|
|
2018
|
+
self._shutdown_workers()
|
|
2019
|
+
finally:
|
|
2020
|
+
raise RuntimeError(f"Cannot proceed, worker {wi} dead.")
|
|
2021
|
+
continue
|
|
2022
|
+
|
|
2023
|
+
# Clear events for ready workers (best-effort)
|
|
2024
|
+
for pipe in ready:
|
|
2025
|
+
wi = pipes_pending.pop(pipe)
|
|
2026
|
+
self._events[wi].clear()
|
|
2027
|
+
+     def _step_no_buffers(
+         self, tensordict: TensorDictBase
+     ) -> tuple[TensorDictBase, TensorDictBase]:
+         partial_steps = tensordict.get("_step")
+         tensordict_save = tensordict
+         if partial_steps is not None and partial_steps.all():
+             partial_steps = None
+         if partial_steps is not None:
+             partial_steps = partial_steps.view(tensordict.shape)
+             tensordict = tensordict[partial_steps]
+             workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
+         else:
+             workers_range = range(self.num_workers)
+
+         if self.consolidate:
+             try:
+                 data = tensordict.consolidate(
+                     # share_memory=False: avoid resource_sharer which causes
+                     # progressive slowdown with fork on Linux
+                     share_memory=False,
+                     inplace=False,
+                     num_threads=1,
+                 )
+             except Exception as err:
+                 raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
+         else:
+             data = tensordict
+
+         for i, local_data in zip(workers_range, data.unbind(0)):
+             env_device = (
+                 self.meta_data[i].device
+                 if isinstance(self.meta_data, list)
+                 else self.meta_data.device
+             )
+             if data.device != env_device:
+                 if env_device is None:
+                     local_data.clear_device_()
+                 else:
+                     local_data = local_data.to(env_device)
+             self.parent_channels[i].send(("step", local_data))
+
+         self._wait_for_workers(workers_range)
+
+         out_tds = []
+         for i in workers_range:
+             channel = self.parent_channels[i]
+             td = channel.recv()
+             out_tds.append(td)
+
+         out = LazyStackedTensorDict.maybe_dense_stack(out_tds)
+         if self.device is not None and out.device != self.device:
+             out = out.to(self.device, non_blocking=self.non_blocking)
+         if partial_steps is not None:
+             result = out.new_zeros(tensordict_save.shape)
+
+             def select_and_clone(x, y):
+                 if y is not None:
+                     if x.device != y.device:
+                         x = x.to(y.device)
+                     else:
+                         x = x.clone()
+                 return x
+
+             prev = tensordict_save._fast_apply(
+                 select_and_clone,
+                 result,
+                 filter_empty=True,
+                 device=result.device,
+                 batch_size=result.batch_size,
+                 is_leaf=_is_leaf_nontensor,
+                 default=None,
+             )
+
+             result.update(prev)
+
+             if partial_steps.any():
+                 result[partial_steps] = out
+             return result
+         return out
+
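Illustration, not part of the packaged file: the "_step" entry used above is a per-worker boolean mask; only the sub-environments where it is True receive a "step" message, and their outputs are scattered back into a result of the original batch shape.

import torch

partial_steps = torch.tensor([True, False, True, False])
workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
assert workers_range == [0, 2]  # only these workers are stepped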
+     @torch.no_grad()
+     @_check_start
+     def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
+         if not self._use_buffers:
+             return self._step_no_buffers(tensordict)
+         # We must use the in_keys and nothing else for the following reasons:
+         # - efficiency: copying all the keys will in practice mean doing a lot
+         # of writing operations since the input tensordict may (and often will)
+         # contain all the previous output data.
+         # - value mismatch: if the batched env is placed within a transform
+         # and this transform overrides an observation key (eg, CatFrames)
+         # the shape, dtype or device may not necessarily match and writing
+         # the value in-place will fail.
+         partial_steps = tensordict.get("_step")
+         tensordict_save = tensordict
+         if partial_steps is not None and partial_steps.all():
+             partial_steps = None
+         if partial_steps is not None:
+             partial_steps = partial_steps.view(tensordict.shape)
+             workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
+             shared_tensordict_parent = TensorDict.lazy_stack(
+                 [self.shared_tensordicts[i] for i in workers_range]
+             )
+             if self.shared_tensordict_parent.device is None:
+                 tensordict = tensordict._fast_apply(
+                     lambda x, y: x[partial_steps].to(y.device)
+                     if y is not None
+                     else x[partial_steps],
+                     self.shared_tensordict_parent,
+                     default=None,
+                     device=None,
+                     batch_size=shared_tensordict_parent.shape,
+                 )
+             else:
+                 tensordict = tensordict[partial_steps].to(
+                     self.shared_tensordict_parent.device
+                 )
+         else:
+             workers_range = range(self.num_workers)
+             shared_tensordict_parent = self.shared_tensordict_parent
+
+         shared_tensordict_parent.update_(
+             tensordict,
+             # We also update the output keys because they can be implicitly used, eg
+             # during partial steps to fill in values
+             keys_to_update=list(self._env_input_keys),
+             non_blocking=self.non_blocking,
+         )
+         next_td_passthrough = tensordict.get("next", None)
+         if next_td_passthrough is not None:
+             # if we have input "next" data (eg, RNNs which pass the next state)
+             # the sub-envs will need to process them through step_and_maybe_reset.
+             # We keep track of which keys are present to let the worker know what
+             # should be passed to the env (we don't want to pass done states for instance)
+             next_td_keys = list(next_td_passthrough.keys(True, True))
+             next_shared_tensordict_parent = shared_tensordict_parent.get("next")
+
+             # We separate keys that are and are not present in the buffer here and not in step_and_maybe_reset.
+             # The reason we do that is that the policy may write stuff in 'next' that is not part of the specs of
+             # the batched env but part of the specs of a transformed batched env.
+             # If that is the case, `update_` will fail to find the entries to update.
+             # What we do instead is keeping the tensors on the side and putting them back after completing _step.
+             keys_to_update, keys_to_copy = zip(
+                 *[
+                     (key, None)
+                     if key in next_shared_tensordict_parent.keys(True, True)
+                     else (None, key)
+                     for key in next_td_keys
+                 ]
+             )
+             keys_to_update = [key for key in keys_to_update if key is not None]
+             keys_to_copy = [key for key in keys_to_copy if key is not None]
+             data = [
+                 {"next_td_passthrough_keys": keys_to_update}
+                 for _ in range(self.num_workers)
+             ]
+             if keys_to_update:
+                 next_shared_tensordict_parent.update_(
+                     next_td_passthrough,
+                     non_blocking=self.non_blocking,
+                     keys_to_update=keys_to_update,
+                 )
+             if keys_to_copy:
+                 next_td_passthrough = next_td_passthrough.select(*keys_to_copy)
+             else:
+                 next_td_passthrough = None
+         else:
+             next_td_passthrough = None
+             data = [{} for _ in range(self.num_workers)]
+
+         if self._non_tensor_keys:
+             for i, td in zip(
+                 workers_range,
+                 tensordict.select(*self._non_tensor_keys, strict=False).unbind(0),
+             ):
+                 data[i]["non_tensor_data"] = td
+
+         self._sync_m2w()
+
+         if self.event is not None:
+             self.event.record()
+             self.event.synchronize()
+
+         for i in workers_range:
+             self.parent_channels[i].send(("step", data[i]))
+
+         self._wait_for_workers(workers_range)
+
+         if self._non_tensor_keys:
+             non_tensor_tds = []
+             for i in workers_range:
+                 msg, non_tensor_td = self.parent_channels[i].recv()
+                 non_tensor_tds.append(non_tensor_td)
+
+         # We must pass a clone of the tensordict, as the values of this tensordict
+         # will be modified in-place at further steps
+         next_td = shared_tensordict_parent.get("next")
+         device = self.device
+
+         out = next_td.named_apply(
+             self.select_and_clone,
+             nested_keys=True,
+             filter_empty=True,
+             device=device,
+         )
+         if self._non_tensor_keys:
+             out.update(
+                 LazyStackedTensorDict(*non_tensor_tds),
+                 keys_to_update=self._non_tensor_keys,
+             )
+         if next_td_passthrough is not None:
+             out.update(next_td_passthrough)
+
+         self._sync_w2m()
+         if partial_steps is not None:
+             result = out.new_zeros(tensordict_save.shape)
+
+             def select_and_clone(x, y):
+                 if y is not None:
+                     if x.device != y.device:
+                         x = x.to(y.device)
+                     else:
+                         x = x.clone()
+                 return x
+
+             prev = tensordict_save._fast_apply(
+                 select_and_clone,
+                 result,
+                 filter_empty=True,
+                 device=result.device,
+                 batch_size=result.batch_size,
+                 is_leaf=_is_leaf_nontensor,
+                 default=None,
+             )
+
+             result.update(prev)
+             if partial_steps.any():
+                 result[partial_steps] = out
+             return result
+         return out
+
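Illustration, not part of the packaged file: the zip(*...) idiom above splits the incoming "next" keys into those already present in the shared buffer (updated in place) and those carried on the side and merged back after the step. The key names below are made up for the example.

buffer_keys = {"observation", "reward"}
next_td_keys = ["observation", "hidden_state"]
keys_to_update, keys_to_copy = zip(
    *[(k, None) if k in buffer_keys else (None, k) for k in next_td_keys]
)
keys_to_update = [k for k in keys_to_update if k is not None]  # ["observation"]
keys_to_copy = [k for k in keys_to_copy if k is not None]  # ["hidden_state"]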
+     def _reset_no_buffers(
+         self,
+         tensordict: TensorDictBase,
+         reset_kwargs_list,
+         needs_resetting,
+     ) -> tuple[TensorDictBase, TensorDictBase]:
+         if is_tensor_collection(tensordict):
+             if self.consolidate:
+                 try:
+                     tensordict = tensordict.consolidate(
+                         # share_memory=False: avoid resource_sharer which causes
+                         # progressive slowdown with fork on Linux
+                         share_memory=False,
+                         num_threads=1,
+                     )
+                 except Exception as err:
+                     raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
+             tensordict = tensordict.unbind(0)
+         else:
+             tensordict = [None] * self.num_workers
+         out_tds = [None] * self.num_workers
+         needs_resetting_int = []
+         for i, (local_data, reset_kwargs) in enumerate(
+             zip(tensordict, reset_kwargs_list)
+         ):
+             if not needs_resetting[i]:
+                 localtd = local_data
+                 if localtd is not None:
+                     localtd = localtd.exclude(*self.reset_keys)
+                 out_tds[i] = localtd
+                 continue
+             needs_resetting_int.append(i)
+             self.parent_channels[i].send(("reset", (local_data, reset_kwargs)))
+
+         self._wait_for_workers(needs_resetting_int)
+
+         for i, channel in enumerate(self.parent_channels):
+             if not needs_resetting[i]:
+                 continue
+             td = channel.recv()
+             out_tds[i] = td
+         result = LazyStackedTensorDict.maybe_dense_stack(out_tds)
+         device = self.device
+         if device is not None and result.device != device:
+             return result.to(self.device, non_blocking=self.non_blocking)
+         return result
+
+     @torch.no_grad()
+     @_check_start
+     def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
+
+         list_of_kwargs = kwargs.pop("list_of_kwargs", [kwargs] * self.num_workers)
+         if kwargs is not list_of_kwargs[0] and kwargs:
+             # this means that kwargs had more than one element and that a list was provided
+             for elt in list_of_kwargs:
+                 elt.update(kwargs)
+
+         if tensordict is not None:
+             if "_reset" in tensordict.keys():
+                 needs_resetting = tensordict["_reset"]
+             else:
+                 needs_resetting = _aggregate_end_of_traj(
+                     tensordict, reset_keys=self.reset_keys
+                 )
+             if needs_resetting.ndim > 2:
+                 needs_resetting = needs_resetting.flatten(1, needs_resetting.ndim - 1)
+             if needs_resetting.ndim > 1:
+                 needs_resetting = needs_resetting.any(-1)
+             elif not needs_resetting.ndim:
+                 needs_resetting = needs_resetting.expand((self.num_workers,))
+         else:
+             needs_resetting = torch.ones(
+                 (self.num_workers,), device=self.device, dtype=torch.bool
+             )
+
+         if not self._use_buffers:
+             return self._reset_no_buffers(tensordict, list_of_kwargs, needs_resetting)
+
+         outs = []
+         for i in range(self.num_workers):
+             if tensordict is not None:
+                 tensordict_ = tensordict[i]
+                 if tensordict_.is_empty():
+                     tensordict_ = None
+                 elif self.device is not None and self.device.type == "mps":
+                     # copy_ fails when moving mps->cpu using copy_
+                     # in some cases when a view of an mps tensor is used.
+                     # We know the shared tensors are not MPS, so we can
+                     # safely assume that the shared tensors are on cpu
+                     tensordict_ = tensordict_.to("cpu")
+             else:
+                 tensordict_ = None
+             if not needs_resetting[i]:
+                 # We update the stored tensordict with the value of the "next"
+                 # key as one may be surprised to receive data that is not up-to-date
+                 # If we don't do this, the result of calling reset and skipping one env
+                 # will be that the env will have the data from the previous
+                 # step at the root (since the shared_tensordict did not go through
+                 # step_mdp).
+                 self.shared_tensordicts[i].update_(
+                     self.shared_tensordicts[i].get("next"),
+                     keys_to_update=list(self._selected_reset_keys),
+                     non_blocking=self.non_blocking,
+                 )
+                 if tensordict_ is not None:
+                     self.shared_tensordicts[i].update_(
+                         tensordict_,
+                         keys_to_update=list(self._selected_reset_keys),
+                         non_blocking=self.non_blocking,
+                     )
+                 continue
+             if tensordict_ is not None:
+                 tdkeys = list(tensordict_.keys(True, True))
+
+                 # This way we can avoid calling select over all the keys in the shared tensordict
+                 def tentative_update(val, other):
+                     if other is not None:
+                         val.copy_(other, non_blocking=self.non_blocking)
+                     return val
+
+                 self.shared_tensordicts[i].apply_(
+                     tentative_update, tensordict_, default=None
+                 )
+                 out = ("reset", (tdkeys, list_of_kwargs[i]))
+             else:
+                 out = ("reset", (False, list_of_kwargs[i]))
+             outs.append((i, out))
+
+         self._sync_m2w()
+
+         for i, out in outs:
+             self.parent_channels[i].send(out)
+
+         self._wait_for_workers(list(zip(*outs))[0])
+
+         workers_nontensor = []
+         if self._non_tensor_keys:
+             for i, _ in outs:
+                 msg, non_tensor_td = self.parent_channels[i].recv()
+                 workers_nontensor.append((i, non_tensor_td))
+
+         selected_output_keys = self._selected_reset_keys_filt
+         device = self.device
+
+         out = self.shared_tensordict_parent.named_apply(
+             lambda *args: self.select_and_clone(
+                 *args, selected_keys=selected_output_keys
+             ),
+             nested_keys=True,
+             filter_empty=True,
+             device=device,
+         )
+         if self._non_tensor_keys:
+             workers, nontensor = zip(*workers_nontensor)
+             out[torch.tensor(workers)] = LazyStackedTensorDict(*nontensor).select(
+                 *self._non_tensor_keys
+             )
+         self._sync_w2m()
+         return out
+
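Sketch, not part of the packaged file: a per-worker "_reset" mask drives partial resets above; workers where the mask is False keep their shared state and are not sent a message. Users normally go through env.reset rather than the private _reset; GymEnv and the environment name are assumptions.

import torch
from tensordict import TensorDict
from torchrl.envs import GymEnv, ParallelEnv

env = ParallelEnv(4, lambda: GymEnv("Pendulum-v1"))
env.reset()
td = TensorDict({"_reset": torch.tensor([True, False, False, True])}, batch_size=[4])
out = env.reset(td)  # only workers 0 and 3 are asked to reset
env.close()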
+     @_check_start
+     def _shutdown_workers(self) -> None:
+         try:
+             if self.is_closed:
+                 raise RuntimeError(
+                     f"calling {self.__class__.__name__}._shutdown_workers only allowed when env.is_closed = False"
+                 )
+             for i, channel in enumerate(self.parent_channels):
+                 if self._verbose:
+                     torchrl_logger.info(f"closing {i}")
+                 channel.send(("close", None))
+             for i in range(self.num_workers):
+                 self._events[i].wait(self._timeout)
+                 self._events[i].clear()
+             if self._use_buffers:
+                 del self.shared_tensordicts, self.shared_tensordict_parent
+
+             for channel in self.parent_channels:
+                 channel.close()
+             start_time = time.time()
+             while (
+                 any(proc.is_alive() for proc in self._workers)
+                 and (time.time() - start_time) < self._timeout
+             ):
+                 time.sleep(0.01)
+             for proc in self._workers:
+                 proc.join()
+         finally:
+             for proc in self._workers:
+                 if proc.is_alive():
+                     proc.terminate()
+             del self._workers
+             del self.parent_channels
+             self._cuda_events = None
+             self._events = None
+             self.event = None
+
+     @_check_start
+     def set_seed(
+         self, seed: int | None = None, static_seed: bool = False
+     ) -> int | None:
+         self._seeds = []
+         for channel in self.parent_channels:
+             channel.send(("seed", (seed, static_seed)))
+             self._seeds.append(seed)
+             msg, new_seed = channel.recv()
+             if msg != "seeded":
+                 raise RuntimeError(f"Expected 'seeded' but received {msg}")
+             seed = new_seed
+         return seed
+
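Note, not part of the packaged file: set_seed chains seeds across workers, each worker returning the seed to hand to the next one, so sub-environments receive distinct but reproducible seeds.

next_seed = env.set_seed(0)  # seed returned by the last worker; `env` is assumed to be a started ParallelEnv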
+     def __reduce__(self):
+         if not self.is_closed:
+             # ParallelEnv contains non-instantiated envs, thus it can be
+             # closed and serialized if the environment building functions
+             # permit it
+             self.close()
+         return super().__reduce__()
+
+     def __getattr__(self, attr: str) -> Any:
+         if attr in self.__dir__():
+             return super().__getattr__(
+                 attr
+             )  # make sure that appropriate exceptions are raised
+         elif attr.startswith("__"):
+             raise AttributeError(
+                 "dispatching built-in private methods is not permitted."
+             )
+         else:
+             if attr in self._excluded_wrapped_keys:
+                 raise AttributeError(f"Getting {attr} resulted in an exception")
+             try:
+                 # _ = getattr(self._dummy_env, attr)
+                 if self.is_closed:
+                     self.start()
+                     raise RuntimeError(
+                         "Trying to access attributes of closed/non started "
+                         "environments. Check that the batched environment "
+                         "has been started (e.g. by calling env.reset)"
+                     )
+                 # dispatch to workers
+                 return _dispatch_caller_parallel(attr, self)
+             except AttributeError:
+                 raise AttributeError(
+                     f"attribute {attr} not found in " f"{self._dummy_env_str}"
+                 )
+
+     def to(self, device: DEVICE_TYPING):
+         device = _make_ordinal_device(torch.device(device))
+         if device == self.device:
+             return self
+         super().to(device)
+         if self._seeds is not None:
+             warn(
+                 "Sending a seeded ParallelEnv to another device requires "
+                 f"re-seeding it. Re-seeding envs to {self._seeds}."
+             )
+             self.set_seed(self._seeds[0])
+         return self
+
+     @classmethod
+     def make_parallel(cls, *args, num_envs: int = 1, **parallel_kwargs) -> EnvBase:
+         """Backward-compatible factory matching EnvBase.make_parallel signature.
+
+         Supports calls like:
+             ParallelEnv.make_parallel(create_env_fn, num_envs=4, ...)
+         or the constructor form:
+             ParallelEnv.make_parallel(num_workers, create_env_fn, ...)
+         """
+         if len(args) >= 1 and isinstance(args[0], int):
+             return cls(*args, **parallel_kwargs)
+         if len(args) >= 1:
+             create_env_fn = args[0]
+             other_args = args[1:]
+             return cls(int(num_envs), create_env_fn, *other_args, **parallel_kwargs)
+         return cls(int(num_envs), **parallel_kwargs)
+
+
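Usage sketch based on the docstring above, not part of the packaged file; GymEnv and the environment name are assumptions (they require gymnasium to be installed).

from torchrl.envs import GymEnv, ParallelEnv

make_env = lambda: GymEnv("Pendulum-v1")
env_a = ParallelEnv.make_parallel(make_env, num_envs=4)  # factory form
env_b = ParallelEnv.make_parallel(4, make_env)  # constructor form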
+ def _recursively_strip_locks_from_state_dict(state_dict: OrderedDict) -> OrderedDict:
+     return OrderedDict(
+         **{
+             k: _recursively_strip_locks_from_state_dict(item)
+             if isinstance(item, OrderedDict)
+             else None
+             if isinstance(item, MpLock)
+             else item
+             for k, item in state_dict.items()
+         }
+     )
+
+
+ def _run_worker_pipe_shared_mem(
+     parent_pipe: connection.Connection,
+     child_pipe: connection.Connection,
+     env_fun: EnvBase | Callable,
+     env_fun_kwargs: dict[str, Any],
+     mp_event: mp.Event = None,
+     shared_tensordict: TensorDictBase = None,
+     _selected_input_keys=None,
+     _selected_reset_keys=None,
+     _selected_step_keys=None,
+     _non_tensor_keys=None,
+     non_blocking: bool = False,
+     has_lazy_inputs: bool = False,
+     verbose: bool = False,
+     num_threads: int | None = None,  # for fork start method
+     filter_warnings: bool = False,
+ ) -> None:
+     pid = os.getpid()
+     # Handle warning filtering (moved from _ProcessNoWarn)
+     if filter_warnings:
+         warnings.filterwarnings("ignore")
+     if num_threads is not None:
+         torch.set_num_threads(num_threads)
+     device = shared_tensordict.device
+     if device is None or device.type != "cuda":
+         # Check if some tensors are shared on cuda
+         has_cuda = [False]
+
+         def look_for_cuda(tensor, has_cuda=has_cuda):
+             has_cuda[0] = has_cuda[0] or tensor.is_cuda
+
+         shared_tensordict.apply(look_for_cuda, filter_empty=True)
+         has_cuda = has_cuda[0]
+     else:
+         has_cuda = device.type == "cuda"
+     if has_cuda:
+         event = torch.cuda.Event()
+     else:
+         event = None
+     parent_pipe.close()
+     if not isinstance(env_fun, EnvBase):
+         env = env_fun(**env_fun_kwargs)
+     else:
+         if env_fun_kwargs:
+             raise RuntimeError(
+                 "env_fun_kwargs must be empty if an environment is passed to a process."
+             )
+         env = env_fun
+         del env_fun
+     env.set_spec_lock_()
+
+     i = -1
+     import torchrl
+
+     _timeout = torchrl._utils.BATCHED_PIPE_TIMEOUT
+
+     initialized = False
+
+     child_pipe.send("started")
+     next_shared_tensordict, root_shared_tensordict = (None,) * 2
+     _cmd_count = 0
+     _last_cmd = "N/A"
+     # Create a timeit instance to track elapsed time since worker start
+     _worker_timer = timeit(f"batched_env_worker/{pid}/lifetime").start()
+     while True:
+         try:
+             if child_pipe.poll(_timeout):
+                 cmd, data = child_pipe.recv()
+                 _cmd_count += 1
+                 _last_cmd = cmd
+                 # Log every 1000 commands
+                 if _cmd_count % 1000 == 0:
+                     torchrl_logger.debug(
+                         f"batched_env worker {pid}: cmd_count={_cmd_count}, "
+                         f"elapsed={_worker_timer.elapsed():.1f}s, last_cmd={cmd}"
+                     )
+             else:
+                 torchrl_logger.debug(
+                     f"batched_env worker {pid}: TIMEOUT after {_timeout}s waiting for cmd, "
+                     f"elapsed_since_start={_worker_timer.elapsed():.1f}s, "
+                     f"last_cmd={_last_cmd}, cmd_count={_cmd_count}"
+                 )
+                 raise TimeoutError(
+                     f"Worker timed out after {_timeout}s, "
+                     f"increase timeout if needed through the BATCHED_PIPE_TIMEOUT environment variable."
+                 )
+         except EOFError as err:
+             torchrl_logger.debug(
+                 f"batched_env worker {pid}: EOFError - pipe closed, "
+                 f"elapsed_since_start={_worker_timer.elapsed():.1f}s, "
+                 f"last_cmd={_last_cmd}, cmd_count={_cmd_count}"
+             )
+             raise EOFError(f"proc {pid} failed, last command: {_last_cmd}.") from err
+         if cmd == "seed":
+             if not initialized:
+                 raise RuntimeError("call 'init' before closing")
+             torch.manual_seed(data[0])
+             new_seed = env.set_seed(data[0], static_seed=data[1])
+             child_pipe.send(("seeded", new_seed))
+
+         elif cmd == "init":
+             if verbose:
+                 torchrl_logger.info(f"initializing {pid}")
+             if initialized:
+                 raise RuntimeError("worker already initialized")
+             i = 0
+             next_shared_tensordict = shared_tensordict.get("next")
+             root_shared_tensordict = shared_tensordict.exclude("next")
+             # TODO: restore this
+             # if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()):
+             #     raise RuntimeError(
+             #         "tensordict must be placed in shared memory (share_memory_() or memmap_())"
+             #     )
+             shared_tensordict = shared_tensordict.clone(False).unlock_()
+
+             initialized = True
+
+         elif cmd == "reset":
+             if verbose:
+                 torchrl_logger.info(f"resetting worker {pid}")
+             if not initialized:
+                 raise RuntimeError("call 'init' before resetting")
+             # we use 'data' to pass the keys that we need to pass to reset,
+             # because passing the entire buffer may have unwanted consequences
+             selected_reset_keys, reset_kwargs = data
+             cur_td = env.reset(
+                 tensordict=root_shared_tensordict.select(
+                     *selected_reset_keys, strict=False
+                 )
+                 if selected_reset_keys
+                 else None,
+                 **reset_kwargs,
+             )
+             shared_tensordict.update_(
+                 cur_td,
+                 keys_to_update=list(_selected_reset_keys),
+                 non_blocking=non_blocking,
+             )
+             if event is not None:
+                 event.record()
+                 event.synchronize()
+
+             if _non_tensor_keys:
+                 # Set event BEFORE sending to avoid deadlocks when the pipe buffer
+                 # is full (the parent will start reading as soon as it observes
+                 # the event).
+                 mp_event.set()
+                 child_pipe.send(
+                     ("non_tensor", cur_td.select(*_non_tensor_keys, strict=False))
+                 )
+             else:
+                 mp_event.set()
+
+             del cur_td
+
+         elif cmd == "step":
+             if not initialized:
+                 raise RuntimeError("called 'init' before step")
+             i += 1
+             # No need to copy here since we don't write in-place
+             input = root_shared_tensordict.copy()
+             if data:
+                 next_td_passthrough_keys = data.get("next_td_passthrough_keys")
+                 if next_td_passthrough_keys is not None:
+                     input = input.set(
+                         "next", next_shared_tensordict.select(*next_td_passthrough_keys)
+                     )
+                 non_tensor_data = data.get("non_tensor_data")
+                 if non_tensor_data is not None:
+                     input.update(non_tensor_data)
+
+             input = env.step(input)
+             next_td = input.get("next")
+             next_shared_tensordict.update_(next_td, non_blocking=non_blocking)
+
+             if event is not None:
+                 event.record()
+                 event.synchronize()
+
+             # Make sure the root is updated
+             root_shared_tensordict.update_(env._step_mdp(input))
+
+             if _non_tensor_keys:
+                 # Set event BEFORE sending to avoid deadlocks when the pipe buffer
+                 # is full (the parent will start reading as soon as it observes
+                 # the event).
+                 mp_event.set()
+                 child_pipe.send(
+                     ("non_tensor", next_td.select(*_non_tensor_keys, strict=False))
+                 )
+             else:
+                 mp_event.set()
+
+             del next_td
+
+         elif cmd == "step_and_maybe_reset":
+             if not initialized:
+                 raise RuntimeError("called 'init' before step")
+             i += 1
+             # We must copy the root shared td here, or at least get rid of done:
+             # if we don't `td is root_shared_tensordict`
+             # which means that root_shared_tensordict will carry the content of next
+             # in the next iteration. When using StepCounter, it will look for an
+             # existing done state, find it and consider the env as done by input (not
+             # by output) of the step!
+             # Caveat: for RNN we may need some keys of the "next" TD so we pass the list
+             # through data
+             input = root_shared_tensordict
+             if data:
+                 next_td_passthrough_keys = data.get("next_td_passthrough_keys", None)
+                 if next_td_passthrough_keys is not None:
+                     input = input.set(
+                         "next", next_shared_tensordict.select(*next_td_passthrough_keys)
+                     )
+                 non_tensor_data = data.get("non_tensor_data", None)
+                 if non_tensor_data is not None:
+                     input.update(non_tensor_data)
+             td, root_next_td = env.step_and_maybe_reset(input)
+             td_next = td.pop("next")
+             next_shared_tensordict.update_(td_next, non_blocking=non_blocking)
+             root_shared_tensordict.update_(root_next_td, non_blocking=non_blocking)
+
+             if event is not None:
+                 event.record()
+                 event.synchronize()
+
+             if _non_tensor_keys:
+                 ntd = root_next_td.select(*_non_tensor_keys)
+                 ntd.set("next", td_next.select(*_non_tensor_keys))
+                 # Set event BEFORE sending to avoid deadlocks when the pipe buffer
+                 # is full (the parent will start reading as soon as it observes
+                 # the event).
+                 mp_event.set()
+                 child_pipe.send(("non_tensor", ntd))
+             else:
+                 mp_event.set()
+
+             del td, root_next_td
+
+         elif cmd == "close":
+             if not initialized:
+                 raise RuntimeError("call 'init' before closing")
+             env.close()
+             del (
+                 env,
+                 shared_tensordict,
+                 data,
+                 next_shared_tensordict,
+                 root_shared_tensordict,
+             )
+             mp_event.set()
+             child_pipe.close()
+             if verbose:
+                 torchrl_logger.info(f"{pid} closed")
+             gc.collect()
+             break
+
+         elif cmd == "load_state_dict":
+             env.load_state_dict(data)
+             mp_event.set()
+
+         elif cmd == "state_dict":
+             state_dict = _recursively_strip_locks_from_state_dict(env.state_dict())
+             msg = "state_dict"
+             child_pipe.send((msg, state_dict))
+             del state_dict
+
+         else:
+             err_msg = f"{cmd} from env"
+             try:
+                 attr = getattr(env, cmd)
+                 if callable(attr):
+                     args, kwargs = data
+                     args_replace = []
+                     for _arg in args:
+                         if isinstance(_arg, str) and _arg == "_self":
+                             continue
+                         else:
+                             args_replace.append(_arg)
+                     result = attr(*args_replace, **kwargs)
+                 else:
+                     result = attr
+             except Exception as err:
+                 raise AttributeError(
+                     f"querying {err_msg} resulted in an error."
+                 ) from err
+             if cmd not in ("to",):
+                 child_pipe.send(("_".join([cmd, "done"]), result))
+             else:
+                 # don't send env through pipe
+                 child_pipe.send(("_".join([cmd, "done"]), None))
+
+
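Informal protocol summary, not part of the packaged file: the parent sends (cmd, data) tuples over parent_channels[i]; the worker signals completion through mp_event and, for some commands, sends a payload back on the pipe. The payload shapes below mirror the branches above; `env` is assumed to be a started ParallelEnv and the reset key is illustrative.

channel = env.parent_channels[0]  # parent-side pipe of worker 0
channel.send(("seed", (42, False)))  # worker replies ("seeded", new_seed)
channel.send(("reset", (["done"], {})))  # buffered mode: keys to read from the buffer + reset kwargs
channel.send(("step", {}))  # buffered mode: optional passthrough / non-tensor payload
channel.send(("close", None))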
+ def _run_worker_pipe_direct(
+     parent_pipe: connection.Connection,
+     child_pipe: connection.Connection,
+     env_fun: EnvBase | Callable,
+     env_fun_kwargs: dict[str, Any],
+     mp_event: mp.Event = None,
+     non_blocking: bool = False,
+     has_lazy_inputs: bool = False,
+     verbose: bool = False,
+     num_threads: int | None = None,  # for fork start method
+     consolidate: bool = True,
+     filter_warnings: bool = False,
+ ) -> None:
+     # Handle warning filtering (moved from _ProcessNoWarn)
+     if filter_warnings:
+         warnings.filterwarnings("ignore")
+     if num_threads is not None:
+         torch.set_num_threads(num_threads)
+
+     parent_pipe.close()
+     pid = os.getpid()
+     if not isinstance(env_fun, EnvBase):
+         env = env_fun(**env_fun_kwargs)
+     else:
+         if env_fun_kwargs:
+             raise RuntimeError(
+                 "env_fun_kwargs must be empty if an environment is passed to a process."
+             )
+         env = env_fun
+         del env_fun
+     for spec in env.output_spec.values(True, True):
+         if spec.device is not None and spec.device.type == "cuda":
+             has_cuda = True
+             break
+     else:
+         for spec in env.input_spec.values(True, True):
+             if spec.device is not None and spec.device.type == "cuda":
+                 has_cuda = True
+                 break
+         else:
+             has_cuda = False
+     if has_cuda:
+         event = torch.cuda.Event()
+     else:
+         event = None
+
+     i = -1
+     import torchrl
+
+     _timeout = torchrl._utils.BATCHED_PIPE_TIMEOUT
+
+     initialized = False
+
+     child_pipe.send("started")
+     while True:
+         try:
+             if child_pipe.poll(_timeout):
+                 cmd, data = child_pipe.recv()
+             else:
+                 raise TimeoutError(
+                     f"Worker timed out after {_timeout}s, "
+                     f"increase timeout if needed through the BATCHED_PIPE_TIMEOUT environment variable."
+                 )
+         except EOFError as err:
+             raise EOFError(f"proc {pid} failed, last command: {cmd}.") from err
+         if cmd == "seed":
+             if not initialized:
+                 raise RuntimeError("call 'init' before closing")
+             # torch.manual_seed(data)
+             # np.random.seed(data)
+             new_seed = env.set_seed(data[0], static_seed=data[1])
+             child_pipe.send(("seeded", new_seed))
+
+         elif cmd == "init":
+             if verbose:
+                 torchrl_logger.info(f"initializing {pid}")
+             if initialized:
+                 raise RuntimeError("worker already initialized")
+             i = 0
+
+             initialized = True
+
+         elif cmd == "reset":
+             if verbose:
+                 torchrl_logger.info(f"resetting worker {pid}")
+             if not initialized:
+                 raise RuntimeError("call 'init' before resetting")
+             # we use 'data' to pass the keys that we need to pass to reset,
+             # because passing the entire buffer may have unwanted consequences
+             data, reset_kwargs = data
+             if data is not None:
+                 data.unlock_()
+                 data._fast_apply(
+                     lambda x: x.clone() if x.device.type == "cuda" else x, out=data
+                 )
+             cur_td = env.reset(
+                 tensordict=data,
+                 **reset_kwargs,
+             )
+             if event is not None:
+                 event.record()
+                 event.synchronize()
+             if consolidate:
+                 try:
+                     cur_td = cur_td.consolidate(
+                         # share_memory=False: avoid resource_sharer which causes
+                         # progressive slowdown with fork on Linux
+                         share_memory=False,
+                         inplace=True,
+                         num_threads=1,
+                     )
+                 except Exception as err:
+                     raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
+             # Set event BEFORE send so parent starts reading, which unblocks send
+             # if pipe buffer was full (prevents deadlock)
+             mp_event.set()
+             child_pipe.send(cur_td)
+
+             del cur_td
+
+         elif cmd == "step":
+             if not initialized:
+                 raise RuntimeError("called 'init' before step")
+             i += 1
+             next_td = env._step(data)
+             if event is not None:
+                 event.record()
+                 event.synchronize()
+             if consolidate:
+                 try:
+                     next_td = next_td.consolidate(
+                         # share_memory=False: avoid resource_sharer which causes
+                         # progressive slowdown with fork on Linux
+                         share_memory=False,
+                         inplace=True,
+                         num_threads=1,
+                     )
+                 except Exception as err:
+                     raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
+             # Set event BEFORE send so parent starts reading, which unblocks send
+             # if pipe buffer was full (prevents deadlock)
+             mp_event.set()
+             child_pipe.send(next_td)
+
+             del next_td
+
+         elif cmd == "step_and_maybe_reset":
+             if not initialized:
+                 raise RuntimeError("called 'init' before step")
+             i += 1
+             # data, idx = data
+             # data = data[idx]
+             data._fast_apply(
+                 lambda x: x.clone() if x.device.type == "cuda" else x, out=data
+             )
+             td, root_next_td = env.step_and_maybe_reset(data)
+
+             if event is not None:
+                 event.record()
+                 event.synchronize()
+             child_pipe.send((td, root_next_td))
+             mp_event.set()
+             del td, root_next_td
+
+         elif cmd == "close":
+             if not initialized:
+                 raise RuntimeError("call 'init' before closing")
+             env.close()
+             mp_event.set()
+             child_pipe.close()
+             if verbose:
+                 torchrl_logger.info(f"{pid} closed")
+             del (env, data, child_pipe, mp_event)
+             gc.collect()
+             return
+
+         elif cmd == "load_state_dict":
+             env.load_state_dict(data)
+             mp_event.set()
+
+         elif cmd == "state_dict":
+             state_dict = _recursively_strip_locks_from_state_dict(env.state_dict())
+             msg = "state_dict"
+             child_pipe.send((msg, state_dict))
+             del state_dict
+
+         else:
+             err_msg = f"{cmd} from env"
+             try:
+                 attr = getattr(env, cmd)
+                 if callable(attr):
+                     args, kwargs = data
+                     args_replace = []
+                     for _arg in args:
+                         if isinstance(_arg, str) and _arg == "_self":
+                             continue
+                         else:
+                             args_replace.append(_arg)
+                     result = attr(*args_replace, **kwargs)
+                 else:
+                     result = attr
+             except Exception as err:
+                 raise AttributeError(
+                     f"querying {err_msg} resulted in an error."
+                 ) from err
+             if cmd not in ("to",):
+                 child_pipe.send(("_".join([cmd, "done"]), result))
+             else:
+                 # don't send env through pipe
+                 child_pipe.send(("_".join([cmd, "done"]), None))
+
+
+ def _filter_empty(tensordict):
+     return tensordict.select(*tensordict.keys(True, True))
+
+
+ def _stackable(*tensordicts):
+     try:
+         ls = LazyStackedTensorDict(*tensordicts, stack_dim=0)
+         ls.contiguous()
+         return not ls._has_exclusive_keys
+     except RuntimeError:
+         return False
+
+
+ def _cuda_sync(device):
+     return functools.partial(torch.cuda.synchronize, device=device)
+
+
+ def _mps_sync(device):
+     return torch.mps.synchronize
+
+
+ # Create an alias for possible imports
+ _BatchedEnv = BatchedEnvBase
+
+ # legacy re-exports (must be at end of file to avoid circular imports)
+ from torchrl.envs.libs.envpool import (  # noqa: F401, E402
+     MultiThreadedEnv,
+     MultiThreadedEnvWrapper,
+ )
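Minimal end-to-end sketch, not part of the packaged file; GymEnv and the environment name are assumptions.

from torchrl.envs import GymEnv, ParallelEnv

env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))
td = env.reset()  # starts the workers and sends a "reset" to each of them
td = env.rand_step(td)  # sends a "step" command to every worker
env.close()  # sends "close" and joins the worker processes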