torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/libs/envpool.py (new file)

@@ -0,0 +1,447 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import importlib
from typing import Any

import numpy as np
import torch
from tensordict import TensorDict, TensorDictBase
from torchrl._utils import logger as torchrl_logger
from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec, Unbounded
from torchrl.envs.common import _EnvWrapper
from torchrl.envs.utils import _classproperty

_has_envpool = importlib.util.find_spec("envpool") is not None


class MultiThreadedEnvWrapper(_EnvWrapper):
    """Wrapper for envpool-based multithreaded environments.

    GitHub: https://github.com/sail-sg/envpool

    Paper: https://arxiv.org/abs/2206.10558

    EnvPool environments auto-reset internally when episodes end. This wrapper
    handles that behavior by caching the auto-reset observations and returning
    them appropriately in step_and_maybe_reset.

    Args:
        env (envpool.python.envpool.EnvPoolMixin): the envpool environment to wrap.
        categorical_action_encoding (bool, optional): if ``True``, categorical
            specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
            otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
            Defaults to ``False``.

    Keyword Args:
        disable_env_checker (bool, optional): for gym > 0.24 only. If ``True`` (default
            for these versions), the environment checker won't be run.
        frame_skip (int, optional): if provided, indicates for how many steps the
            same action is to be repeated. The observation returned will be the
            last observation of the sequence, whereas the reward will be the sum
            of rewards across steps.
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``torch.device("cpu")``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``False``.

    Attributes:
        batch_size: The number of envs run simultaneously.

    Examples:
        >>> import envpool
        >>> from torchrl.envs import MultiThreadedEnvWrapper
        >>> env_base = envpool.make(
        ...     task_id="Pong-v5", env_type="gym", num_envs=4, gym_reset_return_info=True
        ... )
        >>> env = MultiThreadedEnvWrapper(env_base)
        >>> env.reset()
        >>> env.rand_step()

    """

    _verbose: bool = False

    @_classproperty
    def lib(cls):
        import envpool

        return envpool

    def __init__(
        self,
        env: envpool.python.envpool.EnvPoolMixin | None = None,  # noqa: F821
        **kwargs,
    ):
        if not _has_envpool:
            raise ImportError(
                "envpool python package or one of its dependencies (gym, treevalue) were not found. Please install these dependencies."
            )
        if env is not None:
            kwargs["env"] = env
            self.num_workers = env.config["num_envs"]
            # For synchronous mode, batch size is equal to the number of workers
            self.batch_size = torch.Size([self.num_workers])
        super().__init__(**kwargs)

        # Buffer to keep the latest observation for each worker.
        # It's a TensorDict when the observation consists of several variables, e.g. "position" and "velocity".
        self.obs: torch.Tensor | TensorDict = self.observation_spec.zero()

    def _check_kwargs(self, kwargs: dict):
        if "env" not in kwargs:
            raise TypeError("Could not find environment key 'env' in kwargs.")
        env = kwargs["env"]
        import envpool

        if not isinstance(env, (envpool.python.envpool.EnvPoolMixin,)):
            raise TypeError("env is not of type 'envpool.python.envpool.EnvPoolMixin'.")

    def _build_env(self, env: envpool.python.envpool.EnvPoolMixin):  # noqa: F821
        return env

    def _make_specs(
        self, env: envpool.python.envpool.EnvPoolMixin  # noqa: F821
    ) -> None:
        from torchrl.envs.libs.gym import set_gym_backend

        with set_gym_backend("gym"):
            self.action_spec = self._get_action_spec()
            output_spec = self._get_output_spec()
            self.observation_spec = output_spec["full_observation_spec"]
            self.reward_spec = output_spec["full_reward_spec"]
            self.done_spec = output_spec["full_done_spec"]

    def _init_env(self) -> int | None:
        pass

    def _reset(self, tensordict: TensorDictBase) -> TensorDictBase:
        if tensordict is not None:
            reset_workers = tensordict.get("_reset", None)
        else:
            reset_workers = None
        if reset_workers is not None:
            reset_data = self._env.reset(np.where(reset_workers.cpu().numpy())[0])
        else:
            reset_data = self._env.reset()
        tensordict_out = self._transform_reset_output(reset_data, reset_workers)
        self.is_closed = False
        return tensordict_out

    @torch.no_grad()
    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        action = tensordict.get(self.action_key)
        # The action needs to be moved to CPU and converted to numpy before being passed to envpool
        action = action.to(torch.device("cpu"))
        step_output = self._env.step(action.numpy())
        tensordict_out = self._transform_step_output(step_output)
        return tensordict_out

    def step_and_maybe_reset(
        self, tensordict: TensorDictBase
    ) -> tuple[TensorDictBase, TensorDictBase]:
        """Runs a step and handles envpool's internal auto-reset.

        EnvPool auto-resets internally when episodes end. When done=True:

        - The observation returned is the final observation of the ending episode
        - The NEXT call to step() returns the first observation of a new episode

        This method handles that by skipping explicit reset() calls for done
        environments. EnvPool maintains its own internal state, so the next
        step() will automatically return the reset observation.

        Note: The observation in tensordict_ for done envs will be the final
        observation (not the reset observation). This is acceptable because
        envpool ignores the input observation and uses its internal state.
        """
        # Perform the step
        tensordict = self.step(tensordict)

        # Move data from "next" to root for the next iteration
        tensordict_ = self._step_mdp(tensordict)

        # EnvPool auto-resets internally, so we skip calling reset().
        # However, we need to clear the done flags in tensordict_ since envpool
        # has already reset those environments. The next step() will return
        # the reset observations automatically.
        for key in self.done_keys:
            if key in tensordict_.keys(True):
                tensordict_.set(key, torch.zeros_like(tensordict_.get(key)))

        return tensordict, tensordict_

    def _get_action_spec(self) -> TensorSpec:
        # Local import to avoid importing gym at module level
        from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform

        # EnvPool provides Gym-compatible specs as env.spec.action_space and
        # DM_Control-compatible specs as env.spec.action_spec(). We use the Gym ones.

        # Gym specs produced by EnvPool don't contain batch_size; we add it to satisfy checks in EnvBase.
        action_spec = _gym_to_torchrl_spec_transform(
            self._env.spec.action_space,
            device=self.device,
            categorical_action_encoding=True,
        )
        action_spec = self._add_shape_to_spec(action_spec)
        return action_spec

    def _get_output_spec(self) -> TensorSpec:
        return Composite(
            full_observation_spec=self._get_observation_spec(),
            full_reward_spec=self._get_reward_spec(),
            full_done_spec=self._get_done_spec(),
            shape=(self.num_workers,),
            device=self.device,
        )

    def _get_observation_spec(self) -> TensorSpec:
        # Local import to avoid importing gym at module level
        from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform

        # Gym specs produced by EnvPool don't contain batch_size; we add it to satisfy checks in EnvBase.
        observation_spec = _gym_to_torchrl_spec_transform(
            self._env.spec.observation_space,
            device=self.device,
            categorical_action_encoding=True,
        )
        observation_spec = self._add_shape_to_spec(observation_spec)
        if isinstance(observation_spec, Composite):
            return observation_spec
        return Composite(
            observation=observation_spec,
            shape=(self.num_workers,),
            device=self.device,
        )

    def _add_shape_to_spec(self, spec: TensorSpec) -> TensorSpec:
        return spec.expand((self.num_workers, *spec.shape))

    def _get_reward_spec(self) -> TensorSpec:
        return Unbounded(
            device=self.device,
            shape=self.batch_size,
        )

    def _get_done_spec(self) -> TensorSpec:
        spec = Categorical(
            2,
            device=self.device,
            shape=self.batch_size,
            dtype=torch.bool,
        )
        return Composite(
            done=spec,
            truncated=spec.clone(),
            terminated=spec.clone(),
            shape=self.batch_size,
            device=self.device,
        )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(num_workers={self.num_workers}, device={self.device})"

    def _transform_reset_output(
        self,
        envpool_output: tuple[treevalue.TreeValue | np.ndarray, Any],  # noqa: F821
        reset_workers: torch.Tensor | None,
    ):
        """Process output of envpool env.reset."""
        import treevalue

        observation, _ = envpool_output
        if reset_workers is not None:
            # Only specified workers were reset - set observation buffer values only for them
            if isinstance(observation, treevalue.TreeValue):
                # If the observation contains several fields, it is returned as a treevalue.TreeValue.
                # Convert to treevalue.FastTreeValue to allow indexing.
                observation = treevalue.FastTreeValue(observation)
            self.obs[reset_workers] = self._treevalue_or_numpy_to_tensor_or_dict(
                observation
            )
        else:
            # All workers were reset - rewrite the whole observation buffer
            self.obs = TensorDict(
                self._treevalue_or_numpy_to_tensor_or_dict(observation),
                self.batch_size,
                device=self.device,
            )

        obs = self.obs.clone(False)
        obs.update(self.full_done_spec.zero())
        return obs

    def _transform_step_output(
        self, envpool_output: tuple[Any, ...]
    ) -> TensorDict:
        """Process output of envpool env.step."""
        out = envpool_output
        if len(out) == 4:
            obs, reward, done, info = out
            terminated = done
            truncated = info.get("TimeLimit.truncated", done * 0)
        elif len(out) == 5:
            obs, reward, terminated, truncated, info = out
            done = terminated | truncated
        else:
            raise TypeError(
                f"The output of step had {len(out)} elements, but only 4 or 5 are supported."
            )
        obs = self._treevalue_or_numpy_to_tensor_or_dict(obs)
        reward_and_done = {self.reward_key: torch.as_tensor(reward)}
        reward_and_done["done"] = done
        reward_and_done["terminated"] = terminated
        reward_and_done["truncated"] = truncated
        obs.update(reward_and_done)
        self.obs = tensordict_out = TensorDict(
            obs,
            batch_size=self.batch_size,
            device=self.device,
        )
        return tensordict_out

    def _treevalue_or_numpy_to_tensor_or_dict(
        self, x: treevalue.TreeValue | np.ndarray  # noqa: F821
    ) -> torch.Tensor | dict[str, torch.Tensor]:
        """Converts an observation returned by EnvPool.

        EnvPool step and reset return the observation as a numpy array or a TreeValue of numpy arrays, which we
        convert to a tensor or a dictionary of tensors. Currently only supports depth-1 trees, but can easily be
        extended to arbitrary depth if necessary.
        """
        import treevalue

        if isinstance(x, treevalue.TreeValue):
            ret = self._treevalue_to_dict(x)
        elif not isinstance(x, dict):
            ret = {"observation": torch.as_tensor(x)}
        else:
            ret = x
        return ret

    def _treevalue_to_dict(
        self, tv: treevalue.TreeValue  # noqa: F821
    ) -> dict[str, Any]:
        """Converts a TreeValue to a dictionary.

        Currently only supports depth-1 trees, but can easily be extended to arbitrary depth if necessary.
        """
        import treevalue

        return {k[0]: torch.as_tensor(v) for k, v in treevalue.flatten(tv)}

    def _set_seed(self, seed: int | None) -> None:
        if seed is not None:
            torchrl_logger.info(
                "MultiThreadedEnvWrapper._set_seed ignored, as setting the seed of an existing environment is "
                "not supported by envpool. Please create a new environment, passing the seed to the constructor."
            )


class MultiThreadedEnv(MultiThreadedEnvWrapper):
    """Multithreaded execution of environments based on EnvPool.

    GitHub: https://github.com/sail-sg/envpool

    Paper: https://arxiv.org/abs/2206.10558

    An alternative to ParallelEnv based on multithreading. It's faster, as it doesn't require new process
    spawning, but less flexible, as it only supports environments implemented in the EnvPool library.
    Currently only supports the synchronous execution mode, where the batch size is equal to the number of
    workers; see https://envpool.readthedocs.io/en/latest/content/python_interface.html#batch-size.

    Args:
        num_workers (int): The number of envs to run simultaneously. Will be
            identical to the content of `~.batch_size`.
        env_name (str): name of the environment to build.

    Keyword Args:
        create_env_kwargs (Dict[str, Any], optional): kwargs to be passed to the envpool
            environment constructor.
        categorical_action_encoding (bool, optional): if ``True``, categorical
            specs will be converted to the TorchRL equivalent (:class:`torchrl.data.Categorical`),
            otherwise a one-hot encoding will be used (:class:`torchrl.data.OneHot`).
            Defaults to ``False``.
        disable_env_checker (bool, optional): for gym > 0.24 only. If ``True`` (default
            for these versions), the environment checker won't be run.
        frame_skip (int, optional): if provided, indicates for how many steps the
            same action is to be repeated. The observation returned will be the
            last observation of the sequence, whereas the reward will be the sum
            of rewards across steps.
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``torch.device("cpu")``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``False``.

    Examples:
        >>> env = MultiThreadedEnv(num_workers=3, env_name="Pendulum-v1")
        >>> env.reset()
        >>> env.rand_step()
        >>> env.rollout(5)
        >>> env.close()

    """

    def __init__(
        self,
        num_workers: int,
        env_name: str,
        *,
        create_env_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ):
        self.env_name = env_name.replace("ALE/", "")  # Naming convention of EnvPool
        self.num_workers = num_workers
        self.batch_size = torch.Size([num_workers])
        self.create_env_kwargs = create_env_kwargs or {}

        kwargs["num_workers"] = num_workers
        kwargs["env_name"] = self.env_name
        kwargs["create_env_kwargs"] = create_env_kwargs
        super().__init__(**kwargs)

    def _build_env(
        self,
        env_name: str,
        num_workers: int,
        create_env_kwargs: dict[str, Any] | None,
    ) -> Any:
        import envpool

        create_env_kwargs = create_env_kwargs or {}
        # EnvPool requires max_num_players to be set for single-player environments
        if "max_num_players" not in create_env_kwargs:
            create_env_kwargs["max_num_players"] = 1
        env = envpool.make(
            task_id=env_name,
            env_type="gym",
            num_envs=num_workers,
            gym_reset_return_info=True,
            **create_env_kwargs,
        )
        return super()._build_env(env)

    def _set_seed(self, seed: int | None) -> None:
        """EnvPool only supports setting a seed by recreating the environment."""
        if seed is not None:
            torchrl_logger.debug("Recreating EnvPool environment to set seed.")
            self.create_env_kwargs["seed"] = seed
            self._env = self._build_env(
                env_name=self.env_name,
                num_workers=self.num_workers,
                create_env_kwargs=self.create_env_kwargs,
            )

    def _check_kwargs(self, kwargs: dict):
        for arg in ["num_workers", "env_name", "create_env_kwargs"]:
            if arg not in kwargs:
                raise TypeError(f"Expected '{arg}' to be part of kwargs")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(env={self.env_name}, num_workers={self.num_workers}, device={self.device})"