torchrl 0.11.0__cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,351 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import importlib
|
|
4
|
+
import warnings
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from tensordict import TensorDict
|
|
8
|
+
|
|
9
|
+
from torchrl.data.tensor_specs import Categorical, Composite, Unbounded
|
|
10
|
+
from torchrl.envs.common import _EnvWrapper
|
|
11
|
+
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform, set_gym_backend
|
|
12
|
+
from torchrl.envs.utils import _classproperty
|
|
13
|
+
|
|
14
|
+
__all__ = ["ProcgenWrapper", "ProcgenEnv"]
|
|
15
|
+
|
|
16
|
+
# ``import importlib`` alone does not guarantee that the ``importlib.util``
# submodule has been loaded; import it explicitly before calling
# ``find_spec`` so this module does not rely on another import's side effect.
import importlib.util  # noqa: E402

# Probe for the optional ``procgen`` dependency without importing it.
_has_procgen = importlib.util.find_spec("procgen") is not None

if _has_procgen:
    import procgen  # type: ignore
else:
    # Keep the name defined so module-level references (e.g. ``lib = procgen``)
    # do not raise ``NameError`` when the package is absent.
    procgen = None  # type: ignore
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _get_procgen_envs() -> list[str]:
    """Return the game names exposed by the installed ``procgen`` package.

    ``procgen.ENV_NAMES`` is used when it is present and non-empty. Otherwise
    the ``procgen.env`` submodule is imported and its ``ENV_NAMES`` attribute
    is read instead.

    Returns:
        list[str]: a new list of environment names. It is empty when neither
        location provides any name or when ``procgen.env`` cannot be imported.

    Raises:
        ImportError: if ``procgen`` is not installed.
    """
    if not _has_procgen:
        raise ImportError("procgen is not installed.")
    env_names = getattr(procgen, "ENV_NAMES", None)
    if env_names:
        return list(env_names)
    try:
        env_mod = importlib.import_module("procgen.env")
    except Exception:
        # Deliberate best-effort: listing the available games must not crash
        # because the submodule failed to load. ``procgen.ENV_NAMES`` was
        # already found to be missing or empty above, so there is nothing
        # left to fall back on.
        return []
    # ``or []`` guards against the attribute existing but being ``None``.
    return list(getattr(env_mod, "ENV_NAMES", None) or [])
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _get_num_envs(env) -> int | None:
|
|
38
|
+
"""Get the number of parallel environments from a procgen env."""
|
|
39
|
+
# procgen.ProcgenEnv returns a ToGymEnv wrapper; the num attribute
|
|
40
|
+
# may be on the wrapper, the inner env (.env), or as num_envs
|
|
41
|
+
return (
|
|
42
|
+
getattr(env, "num", None)
|
|
43
|
+
or getattr(env, "nenvs", None)
|
|
44
|
+
or getattr(env, "num_envs", None)
|
|
45
|
+
or getattr(getattr(env, "env", None), "num", None)
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class ProcgenWrapper(_EnvWrapper):
    """OpenAI Procgen environment wrapper.

    Wraps an existing :class:`procgen.ProcgenEnv` instance and exposes it
    under the TorchRL environment API.

    This wrapper is responsible for:
    - Converting Procgen observations (``{"rgb": np.ndarray}``) to Torch tensors
    - Handling vectorized Procgen semantics
    - Producing TorchRL-compliant ``TensorDict`` outputs

    Args:
        env (procgen.ProcgenEnv): an already constructed Procgen environment.

    Keyword Args:
        device (torch.device | str, optional): device on which tensors are placed.
        batch_size (torch.Size, optional): expected batch size. When omitted
            and the number of parallel environments can be read from ``env``,
            it is set to ``torch.Size([num_envs])``.
        allow_done_after_reset (bool, optional): tolerate done right after reset.

    Attributes:
        available_envs (List[str]): list of Procgen environment ids.

    Examples:
        >>> import procgen
        >>> from torchrl.envs.libs.procgen import ProcgenWrapper
        >>> env = procgen.ProcgenEnv(4, "coinrun")
        >>> env = ProcgenWrapper(env=env)
        >>> td = env.reset()
        >>> print(td["observation"].shape)
        torch.Size([4, 3, 64, 64])
        >>> print(td["done"].shape)
        torch.Size([4, 1])
        >>> print(env.batch_size)
        torch.Size([4])
    """

    git_url = "https://github.com/openai/procgen"
    lib = procgen

    @_classproperty
    def available_envs(cls) -> list[str]:
        """Names of the Procgen games, or an empty list if procgen is missing."""
        if not _has_procgen:
            return []
        return _get_procgen_envs()

    def __init__(self, env, **kwargs):
        # Detect num_envs before calling parent __init__ so batch_size is set
        # before _make_specs() is called
        n = _get_num_envs(env)
        if n is not None and "batch_size" not in kwargs:
            kwargs["batch_size"] = torch.Size([n])
        super().__init__(env=env, **kwargs)

    def _check_kwargs(self, kwargs: dict) -> None:
        """Ensure the mandatory ``env`` keyword argument was provided.

        Raises:
            TypeError: if ``"env"`` is not a key of ``kwargs``.
        """
        if "env" not in kwargs:
            raise TypeError("ProcgenWrapper requires an 'env' argument.")

    def _build_env(self, env, **_) -> procgen.ProcgenEnv:
        """Return the already constructed environment unchanged."""
        return env

    @property
    def observation_space(self):
        """Observation space of the wrapped env."""
        # gym3 uses ob_space instead of observation_space
        return getattr(self._env, "observation_space", None) or self._env.ob_space

    @property
    def action_space(self):
        """Action space of the wrapped env."""
        # gym3 uses ac_space instead of action_space
        return getattr(self._env, "action_space", None) or self._env.ac_space

    def _make_specs(self, env) -> None:
        """Build the observation, action, reward and done specs.

        Every leaf spec carries the env batch size as leading dimensions.
        """
        from torchrl.data.tensor_specs import Bounded

        batch_size = self.batch_size

        # Procgen observation is rgb with shape (64, 64, 3) per env
        # After permuting in _reset/_step it becomes (3, 64, 64) per env
        # With batch_size, full shape is (*batch_size, 3, 64, 64)
        self.observation_spec = Composite(
            observation=Bounded(
                low=0,
                high=255,
                shape=(*batch_size, 3, 64, 64),
                dtype=torch.uint8,
                device=self.device,
            ),
            shape=batch_size,
        )

        # Procgen has Discrete(15) action space
        with set_gym_backend("gym"):
            action_spec = _gym_to_torchrl_spec_transform(
                self.action_space,
                categorical_action_encoding=True,
                device=self.device,
            )
        # Expand action spec to include batch dimension
        if len(batch_size) > 0 and action_spec.shape[: len(batch_size)] != batch_size:
            action_spec = action_spec.expand(*batch_size, *action_spec.shape)
        self.action_spec = action_spec

        self.reward_spec = Composite(
            reward=Unbounded(
                shape=(*batch_size, 1), dtype=torch.float32, device=self.device
            ),
            shape=batch_size,
        )

        done_leaf = Categorical(
            n=2, shape=(*batch_size, 1), dtype=torch.bool, device=self.device
        )
        self.done_spec = Composite(
            done=done_leaf.clone(),
            terminated=done_leaf.clone(),
            truncated=done_leaf.clone(),
            shape=batch_size,
        )

    def _init_env(self) -> None:
        """Perform a best-effort initial reset of the wrapped env."""
        # batch_size is set in __init__ before _make_specs() is called
        try:
            self._env.reset()
        except Exception as err:
            # Still best-effort (construction must not fail here), but report
            # the problem instead of hiding it.
            warnings.warn(f"ProcgenWrapper: initial reset failed ({err!r}).")

    def _set_seed(self, seed: int | None) -> None:
        """Try to seed the wrapped env; a failure only emits a warning.

        The first available of ``seed()``, ``set_seed()`` or the ``rand_seed``
        attribute on the wrapped env is used. Nothing happens when ``seed``
        is ``None``.
        """
        if seed is None:
            return
        try:
            if hasattr(self._env, "seed"):
                self._env.seed(seed)
            elif hasattr(self._env, "set_seed"):
                self._env.set_seed(seed)
            elif hasattr(self._env, "rand_seed"):
                self._env.rand_seed = seed
        except Exception:
            warnings.warn("ProcgenWrapper: seeding failed (best-effort).")

    def _procgen_rgb_to_tensor(self, obs) -> torch.Tensor:
        """Convert the ``"rgb"`` entry of a Procgen observation to a tensor.

        The array is moved to ``self.device`` and permuted from channels-last
        (dims ``0, 1, 2, 3``) to channels-first (dims ``0, 3, 1, 2``).
        """
        return torch.from_numpy(obs["rgb"]).to(self.device).permute(0, 3, 1, 2)

    def _reset(self, tensordict=None, **kwargs) -> TensorDict:
        """Reset the wrapped env and return the first observation.

        Returns:
            TensorDict: holds ``"observation"`` plus ``"done"``,
            ``"terminated"`` and ``"truncated"`` flags, all ``False``.
        """
        obs = self._env.reset()
        # Some reset() variants return a sequence; its first item is the obs.
        if isinstance(obs, (tuple, list)):
            obs = obs[0]

        rgb = self._procgen_rgb_to_tensor(obs)

        td = TensorDict(
            {"observation": rgb},
            batch_size=self.batch_size,
            device=self.device,
        )

        # Set done flags (required by TorchRL)
        zeros = torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.bool)
        td.set("done", zeros)
        td.set("terminated", zeros.clone())
        td.set("truncated", zeros.clone())

        return td

    def _step(self, tensordict: TensorDict, **kwargs) -> TensorDict:
        """Step the wrapped env with the ``"action"`` entry of ``tensordict``.

        Returns:
            TensorDict: holds ``"observation"``, ``"reward"``, ``"done"``,
            ``"terminated"`` (a copy of ``"done"``) and ``"truncated"``
            (always ``False``). When the env returns a non-empty ``dict`` as
            info, its entries are added as tensors as well.
        """
        action = tensordict.get("action")
        # Procgen expects numpy arrays with shape (num_envs,)
        action_np = action.cpu().numpy().flatten()
        obs, reward, done, info = self._env.step(action_np)

        rgb = self._procgen_rgb_to_tensor(obs)
        # Cast explicitly so the reward always matches the float32 dtype
        # declared in reward_spec, whatever dtype the env returns.
        reward = torch.as_tensor(
            reward, dtype=torch.float32, device=self.device
        ).view(-1, 1)
        done = torch.as_tensor(done, device=self.device).view(-1, 1).bool()

        td = TensorDict(
            {
                "observation": rgb,
                "reward": reward,
                "done": done,
                "terminated": done.clone(),
                "truncated": torch.zeros_like(done),
            },
            batch_size=self.batch_size,
            device=self.device,
        )

        # Expose info dict fields (e.g., level_seed, prev_level_complete)
        # Note: procgen may return info as a list of dicts or a single dict;
        # only the single-dict form is exposed here, a list is ignored.
        if info and isinstance(info, dict):
            for key, val in info.items():
                td.set(key, torch.as_tensor(val, device=self.device))

        return td
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
class ProcgenEnv(ProcgenWrapper):
    """OpenAI Procgen environment.

    Convenience class that constructs a Procgen environment by name.

    See https://github.com/openai/procgen for more details on Procgen.

    Args:
        env_name (str): name of the Procgen game (e.g. ``"coinrun"``).
            Available games: bigfish, bossfight, caveflyer, chaser, climber,
            coinrun, dodgeball, fruitbot, heist, jumper, leaper, maze, miner,
            ninja, plunder, starpilot.

    Keyword Args:
        num_envs (int, optional): number of parallel environments. Defaults to 1.
        distribution_mode (str, optional): Procgen distribution mode. One of
            ``"easy"``, ``"hard"``, ``"extreme"``, ``"memory"``, ``"exploration"``.
            Defaults to ``"hard"``.
        start_level (int, optional): the level id to start from. Defaults to 0.
        num_levels (int, optional): the number of unique levels that can be
            generated. Set to 0 for unlimited levels. Defaults to 0.
        use_sequential_levels (bool, optional): if ``True``, levels are played
            sequentially rather than randomly. Defaults to ``False``.
        center_agent (bool, optional): if ``True``, observations are centered
            on the agent. Defaults to ``True``.
        use_backgrounds (bool, optional): if ``True``, include background
            assets. Defaults to ``True``.
        use_monochrome_assets (bool, optional): if ``True``, use monochrome
            assets for simpler visuals. Defaults to ``False``.
        restrict_themes (bool, optional): if ``True``, restrict visual themes.
            Defaults to ``False``.
        use_generated_assets (bool, optional): if ``True``, use procedurally
            generated assets. Defaults to ``False``.
        paint_vel_info (bool, optional): if ``True``, paint velocity info on
            observations. Defaults to ``False``.
        seed (int, optional): random seed for the environment. Note that procgen
            environments must be seeded at construction time; calling ``set_seed()``
            after construction will not work reliably. It is forwarded to
            procgen as ``rand_seed`` and replaces an explicit ``rand_seed``
            keyword argument when both are given.
        render_mode (str, optional): render mode for the environment.
        device (torch.device | str, optional): device for tensors.
        batch_size (torch.Size, optional): expected batch size. Defaults to
            ``torch.Size([num_envs])``.
        allow_done_after_reset (bool, optional): tolerate done after reset.

    Examples:
        >>> from torchrl.envs.libs.procgen import ProcgenEnv
        >>> env = ProcgenEnv("coinrun", num_envs=8)
        >>> td = env.reset()
        >>> print(td["observation"].shape)
        torch.Size([8, 3, 64, 64])
        >>> print(td["done"].shape)
        torch.Size([8, 1])
        >>> print(env.batch_size)
        torch.Size([8])
        >>> print(env.available_envs)
        ['bigfish', 'bossfight', 'caveflyer', 'chaser', 'climber', 'coinrun', 'dodgeball', 'fruitbot', 'heist', 'jumper', 'leaper', 'maze', 'miner', 'ninja', 'plunder', 'starpilot']
    """

    def __init__(self, env_name: str, **kwargs):
        if not _has_procgen:
            raise ImportError(
                "procgen python package was not found. "
                "Install it from https://github.com/openai/procgen."
            )

        # Read the property once: every access queries procgen again.
        available = self.available_envs
        if env_name not in available:
            raise ValueError(
                f"Unknown Procgen environment '{env_name}'. "
                f"Available envs: {available}"
            )

        num_envs = kwargs.pop("num_envs", 1)
        # Procgen uses rand_seed for seeding at construction time
        seed = kwargs.pop("seed", None)
        if seed is not None:
            kwargs["rand_seed"] = seed
        # Split the procgen-specific kwargs from those meant for the parent
        # class; whatever stays in ``kwargs`` is passed to ProcgenWrapper.
        procgen_keys = (
            "distribution_mode",
            "start_level",
            "num_levels",
            "use_sequential_levels",
            "center_agent",
            "use_backgrounds",
            "use_monochrome_assets",
            "restrict_themes",
            "use_generated_assets",
            "paint_vel_info",
            "render_mode",
            "rand_seed",
        )
        procgen_kwargs = {
            key: kwargs.pop(key) for key in procgen_keys if key in kwargs
        }
        env = procgen.ProcgenEnv(num_envs, env_name, **procgen_kwargs)
        # Pass batch_size to parent; it will be set before _make_specs()
        kwargs.setdefault("batch_size", torch.Size([num_envs]))
        super().__init__(env=env, **kwargs)
|