torchrl-0.11.0-cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0

torchrl/envs/libs/pettingzoo.py
@@ -0,0 +1,1042 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import copy
import importlib
import warnings

import numpy as np
import packaging
import torch
from tensordict import TensorDictBase

from torchrl.data.tensor_specs import Categorical, Composite, OneHot, Unbounded
from torchrl.envs.common import _EnvWrapper
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform, set_gym_backend
from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType

_has_pettingzoo = importlib.util.find_spec("pettingzoo") is not None


def _get_envs():
    if not _has_pettingzoo:
        raise ImportError("PettingZoo is not installed in your virtual environment.")
    try:
        from pettingzoo.utils.all_modules import all_environments
    except ModuleNotFoundError as err:
        warnings.warn(
            f"PettingZoo failed to load all modules with error message {err}, trying to load individual modules."
        )
        all_environments = _load_available_envs()

    return list(all_environments.keys())


def _load_available_envs() -> dict:
    all_environments = {}
    try:
        from pettingzoo.mpe.all_modules import mpe_environments

        all_environments.update(mpe_environments)
    except ModuleNotFoundError as err:
        warnings.warn(f"MPE environments failed to load with error message {err}.")
    try:
        from pettingzoo.sisl.all_modules import sisl_environments

        all_environments.update(sisl_environments)
    except ModuleNotFoundError as err:
        warnings.warn(f"SISL environments failed to load with error message {err}.")
    try:
        from pettingzoo.classic.all_modules import classic_environments

        all_environments.update(classic_environments)
    except ModuleNotFoundError as err:
        warnings.warn(f"Classic environments failed to load with error message {err}.")
    try:
        from pettingzoo.atari.all_modules import atari_environments

        all_environments.update(atari_environments)
    except ModuleNotFoundError as err:
        warnings.warn(f"Atari environments failed to load with error message {err}.")
    try:
        from pettingzoo.butterfly.all_modules import butterfly_environments

        all_environments.update(butterfly_environments)
    except ModuleNotFoundError as err:
        warnings.warn(
            f"Butterfly environments failed to load with error message {err}."
        )
    return all_environments


def _extract_nested_with_index(data: np.ndarray | dict[str, np.ndarray], index: int):
    if isinstance(data, np.ndarray):
        return data[index]
    elif isinstance(data, dict):
        return {
            key: _extract_nested_with_index(value, index) for key, value in data.items()
        }
    else:
        raise NotImplementedError(f"Invalid type of data {data}")


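# Illustrative sketch (not part of the packaged file): given a batched dict of
# arrays such as
#     data = {"obs": np.zeros((3, 4)), "energy": np.ones(3)}
# _extract_nested_with_index(data, 0) returns the per-agent slice
#     {"obs": data["obs"][0], "energy": data["energy"][0]}
# by recursing through nested dicts and indexing every leaf array.
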
class PettingZooWrapper(_EnvWrapper):
    """PettingZoo environment wrapper.

    To install PettingZoo, follow the guide `here <https://github.com/Farama-Foundation/PettingZoo#installation>`__.

    This class is a general torchrl wrapper for all PettingZoo environments.
    It is able to wrap both ``pettingzoo.AECEnv`` and ``pettingzoo.ParallelEnv``.

    Let's see how this works in more detail:

    In wrapped ``pettingzoo.ParallelEnv`` all agents will step at each environment step.
    If the number of agents during the task varies, please set ``use_mask=True``.
    ``"mask"`` will be provided
    as an output in each group and should be used to mask out dead agents.
    The environment will be reset as soon as one agent is done (unless ``done_on_any`` is ``False``).

    In wrapped ``pettingzoo.AECEnv``, at each step only one agent will act.
    For this reason, it is compulsory to set ``use_mask=True`` for this type of environment.
    ``"mask"`` will be provided as an output for each group and can be used to mask out non-acting agents.
    The environment will be reset only when all agents are done (unless ``done_on_any`` is ``True``).

    If there are any unavailable actions for an agent,
    the environment will also automatically update the mask of its ``action_spec`` and output an ``"action_mask"``
    for each group to reflect the latest available actions. This should be passed to a masked distribution during
    training.

    As a feature of torchrl multiagent, you are able to control the grouping of agents in your environment.
    You can group agents together (stacking their tensors) to leverage vectorization when passing them through the same
    neural network. You can split agents in different groups where they are heterogeneous or should be processed by
    different neural networks. To group, you just need to pass a ``group_map`` at env construction time.

    By default, agents in pettingzoo will be grouped by name.
    For example, with agents ``["agent_0","agent_1","agent_2","adversary_0"]``, the tensordicts will look like:

    >>> print(env.rand_action(env.reset()))
    TensorDict(
        fields={
            agent: TensorDict(
                fields={
                    action: Tensor(shape=torch.Size([3, 9]), device=cpu, dtype=torch.int64, is_shared=False),
                    action_mask: Tensor(shape=torch.Size([3, 9]), device=cpu, dtype=torch.bool, is_shared=False),
                    done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    observation: Tensor(shape=torch.Size([3, 3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False),
                    terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([3]))},
            adversary: TensorDict(
                fields={
                    action: Tensor(shape=torch.Size([1, 9]), device=cpu, dtype=torch.int64, is_shared=False),
                    action_mask: Tensor(shape=torch.Size([1, 9]), device=cpu, dtype=torch.bool, is_shared=False),
                    done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    observation: Tensor(shape=torch.Size([1, 3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False),
                    terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    truncated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([1]))},
        batch_size=torch.Size([]))
    >>> print(env.group_map)
    {"agent": ["agent_0", "agent_1", "agent_2"], "adversary": ["adversary_0"]}

    Otherwise, a group map can be specified or selected from some premade options.
    See :class:`torchrl.envs.utils.MarlGroupMapType` for more info.
    For example, you can provide ``MarlGroupMapType.ONE_GROUP_PER_AGENT``, telling that each agent should
    have its own tensordict (similar to the pettingzoo parallel API).

    Grouping is useful for leveraging vectorization among agents whose data goes through the same
    neural network.

    Args:
        env (``pettingzoo.utils.env.ParallelEnv`` or ``pettingzoo.utils.env.AECEnv``): the pettingzoo environment to wrap.
        return_state (bool, optional): whether to return the global state from pettingzoo
            (not available in all environments). Defaults to ``False``.
        group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to group agents in tensordicts for
            input/output. By default, agents will be grouped by their name. Otherwise, a group map can be specified
            or selected from some premade options. See :class:`torchrl.envs.utils.MarlGroupMapType` for more info.
        use_mask (bool, optional): whether the environment should output a ``"mask"``. This is compulsory in
            wrapped ``pettingzoo.AECEnv`` to mask out non-acting agents and should be also used
            for ``pettingzoo.ParallelEnv`` when the number of agents can vary. Defaults to ``False``.
        categorical_actions (bool, optional): if the environment's actions are discrete, whether to transform
            them to categorical or one-hot.
        seed (int, optional): the seed. Defaults to ``None``.
        done_on_any (bool, optional): whether the environment's done keys are set by aggregating the agent keys
            using ``any()`` (when ``True``) or ``all()`` (when ``False``). Default (``None``) is to use ``any()`` for
            parallel environments and ``all()`` for AEC ones.

    Examples:
        >>> # Parallel env
        >>> from torchrl.envs.libs.pettingzoo import PettingZooWrapper
        >>> from pettingzoo.butterfly import pistonball_v6
        >>> kwargs = {"n_pistons": 21, "continuous": True}
        >>> env = PettingZooWrapper(
        ...     env=pistonball_v6.parallel_env(**kwargs),
        ...     return_state=True,
        ...     group_map=None,  # Use default for parallel (all pistons grouped together)
        ... )
        >>> print(env.group_map)
        ... {'piston': ['piston_0', 'piston_1', ..., 'piston_20']}
        >>> env.rollout(10)
        >>> # AEC env
        >>> from pettingzoo.classic import tictactoe_v3
        >>> from torchrl.envs.libs.pettingzoo import PettingZooWrapper
        >>> from torchrl.envs.utils import MarlGroupMapType
        >>> env = PettingZooWrapper(
        ...     env=tictactoe_v3.env(),
        ...     use_mask=True,  # Must use it since one player plays at a time
        ...     group_map=None,  # Use default for AEC (one group per player)
        ... )
        >>> print(env.group_map)
        ... {'player_1': ['player_1'], 'player_2': ['player_2']}
        >>> env.rollout(10)
    """

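    # Illustrative sketch (not part of the packaged file): the per-group
    # "action_mask" emitted by this wrapper is meant to feed a masked
    # distribution during training; group/key names below follow the
    # tic-tac-toe example above and may differ per task:
    #
    #     >>> from torchrl.modules import MaskedCategorical
    #     >>> td = env.reset()
    #     >>> mask = td["player_1", "action_mask"]
    #     >>> logits = torch.zeros(mask.shape, dtype=torch.float)  # stand-in for a policy output
    #     >>> dist = MaskedCategorical(logits=logits, mask=mask)
    #     >>> action = dist.sample()  # only samples actions allowed by the mask
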
    git_url = "https://github.com/Farama-Foundation/PettingZoo"
    libname = "pettingzoo"

    @_classproperty
    def available_envs(cls):
        if not _has_pettingzoo:
            return []
        return list(_get_envs())

    def __init__(
        self,
        env: (
            pettingzoo.utils.env.ParallelEnv  # noqa: F821
            | pettingzoo.utils.env.AECEnv  # noqa: F821
        ) = None,
        return_state: bool = False,
        group_map: MarlGroupMapType | dict[str, list[str]] | None = None,
        use_mask: bool = False,
        categorical_actions: bool = True,
        seed: int | None = None,
        done_on_any: bool | None = None,
        **kwargs,
    ):
        if env is not None:
            kwargs["env"] = env

        self.group_map = group_map
        self.return_state = return_state
        self.seed = seed
        self.use_mask = use_mask
        self.categorical_actions = categorical_actions
        self.done_on_any = done_on_any

        super().__init__(**kwargs, allow_done_after_reset=True)

    def _get_default_group_map(self, agent_names: list[str]):
        # This function performs the default grouping in pettingzoo
        if not self.parallel:
            # In AEC envs we will have one group per agent by default
            group_map = MarlGroupMapType.ONE_GROUP_PER_AGENT.get_group_map(agent_names)
        else:
            # In parallel envs, by default
            # Agents with names "str_int" will be grouped in group name "str"
            group_map = {}
            for agent_name in agent_names:
                # See if the agent follows the convention "name_int"
                follows_convention = True
                agent_name_split = agent_name.split("_")
                if len(agent_name_split) == 1:
                    follows_convention = False
                try:
                    int(agent_name_split[-1])
                except ValueError:
                    follows_convention = False

                # If not, just put it in a single group
                if not follows_convention:
                    group_map[agent_name] = [agent_name]
                # Otherwise, group it with other agents that follow the same convention
                else:
                    group_name = "_".join(agent_name_split[:-1])
                    if group_name in group_map:
                        group_map[group_name].append(agent_name)
                    else:
                        group_map[group_name] = [agent_name]

        return group_map

    @property
    def lib(self):
        import pettingzoo

        return pettingzoo

    def _build_env(
        self,
        env: (
            pettingzoo.utils.env.ParallelEnv  # noqa: F821
            | pettingzoo.utils.env.AECEnv  # noqa: F821
        ),
    ):
        import pettingzoo

        if packaging.version.parse(pettingzoo.__version__).base_version != "1.24.3":
            warnings.warn(
                "PettingZoo in TorchRL is tested using version == 1.24.3 , "
                "If you are using a different version and are experiencing compatibility issues,"
                "please raise an issue in the TorchRL github."
            )

        self.parallel = isinstance(env, pettingzoo.utils.env.ParallelEnv)
        if not self.parallel and not self.use_mask:
            raise ValueError("For AEC environments you need to set use_mask=True")
        if len(self.batch_size):
            raise RuntimeError(
                f"PettingZoo does not support custom batch_size {self.batch_size}."
            )

        return env

    @set_gym_backend("gymnasium")
    def _make_specs(
        self,
        env: (
            pettingzoo.utils.env.ParallelEnv  # noqa: F821
            | pettingzoo.utils.env.AECEnv  # noqa: F821
        ),
    ) -> None:
        # Set default for done on any or all
        if self.done_on_any is None:
            self.done_on_any = self.parallel

        # Create and check group map
        if self.group_map is None:
            self.group_map = self._get_default_group_map(self.possible_agents)
        elif isinstance(self.group_map, MarlGroupMapType):
            self.group_map = self.group_map.get_group_map(self.possible_agents)
        check_marl_grouping(self.group_map, self.possible_agents)
        self.has_action_mask = {group: False for group in self.group_map.keys()}

        action_spec = Composite()
        observation_spec = Composite()
        reward_spec = Composite()
        done_spec = Composite(
            {
                "done": Categorical(
                    n=2,
                    shape=torch.Size((1,)),
                    dtype=torch.bool,
                    device=self.device,
                ),
                "terminated": Categorical(
                    n=2,
                    shape=torch.Size((1,)),
                    dtype=torch.bool,
                    device=self.device,
                ),
                "truncated": Categorical(
                    n=2,
                    shape=torch.Size((1,)),
                    dtype=torch.bool,
                    device=self.device,
                ),
            },
        )
        for group, agents in self.group_map.items():
            (
                group_observation_spec,
                group_action_spec,
                group_reward_spec,
                group_done_spec,
            ) = self._make_group_specs(group_name=group, agent_names=agents)
            action_spec[group] = group_action_spec
            observation_spec[group] = group_observation_spec
            reward_spec[group] = group_reward_spec
            done_spec[group] = group_done_spec

        self.action_spec = action_spec
        self.observation_spec = observation_spec
        self.reward_spec = reward_spec
        self.done_spec = done_spec

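    # Illustrative note (not part of the packaged file): after _make_specs runs,
    # the specs are Composite objects with one sub-spec per group, each carrying a
    # leading n_agents dimension; with the default grouping from the class
    # docstring one would expect, e.g.:
    #     env.full_action_spec["agent", "action"]            # leading shape (3, ...)
    #     env.observation_spec["adversary", "observation"]   # leading shape (1, ...)
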
    def _make_group_specs(self, group_name: str, agent_names: list[str]):
        n_agents = len(agent_names)
        action_specs = []
        observation_specs = []
        for agent in agent_names:
            action_specs.append(
                Composite(
                    {
                        "action": _gym_to_torchrl_spec_transform(
                            self.action_space(agent),
                            remap_state_to_observation=False,
                            categorical_action_encoding=self.categorical_actions,
                            device=self.device,
                        )
                    },
                )
            )
            observation_specs.append(
                Composite(
                    {
                        "observation": _gym_to_torchrl_spec_transform(
                            self.observation_space(agent),
                            remap_state_to_observation=False,
                            device=self.device,
                        )
                    }
                )
            )
        group_action_spec = torch.stack(action_specs, dim=0)
        group_observation_spec = torch.stack(observation_specs, dim=0)

        # Sometimes the observation spec contains an action mask.
        # Or sometimes the info spec contains an action mask.
        # We make this uniform by removing it from both places and optionally setting it in a standard location.
        group_observation_inner_spec = group_observation_spec["observation"]
        if (
            isinstance(group_observation_inner_spec, Composite)
            and "action_mask" in group_observation_inner_spec.keys()
        ):
            self.has_action_mask[group_name] = True
            del group_observation_inner_spec["action_mask"]
            group_observation_spec["action_mask"] = Categorical(
                n=2,
                shape=group_action_spec["action"].shape
                if not self.categorical_actions
                else group_action_spec["action"].to_one_hot_spec().shape,
                dtype=torch.bool,
                device=self.device,
            )

        if self.use_mask:
            group_observation_spec["mask"] = Categorical(
                n=2,
                shape=torch.Size((n_agents,)),
                dtype=torch.bool,
                device=self.device,
            )

        group_reward_spec = Composite(
            {
                "reward": Unbounded(
                    shape=torch.Size((n_agents, 1)),
                    device=self.device,
                    dtype=torch.float32,
                )
            },
            shape=torch.Size((n_agents,)),
        )
        group_done_spec = Composite(
            {
                "done": Categorical(
                    n=2,
                    shape=torch.Size((n_agents, 1)),
                    dtype=torch.bool,
                    device=self.device,
                ),
                "terminated": Categorical(
                    n=2,
                    shape=torch.Size((n_agents, 1)),
                    dtype=torch.bool,
                    device=self.device,
                ),
                "truncated": Categorical(
                    n=2,
                    shape=torch.Size((n_agents, 1)),
                    dtype=torch.bool,
                    device=self.device,
                ),
            },
            shape=torch.Size((n_agents,)),
        )
        return (
            group_observation_spec,
            group_action_spec,
            group_reward_spec,
            group_done_spec,
        )

    def _check_kwargs(self, kwargs: dict):
        import pettingzoo

        if "env" not in kwargs:
            raise TypeError("Could not find environment key 'env' in kwargs.")
        env = kwargs["env"]
        if not isinstance(
            env, (pettingzoo.utils.env.ParallelEnv, pettingzoo.utils.env.AECEnv)
        ):
            raise TypeError("env is not of type expected.")

    def _init_env(self):
        # Add info
        if self.parallel:
            _, info_dict = self._reset_parallel(seed=self.seed)
        else:
            _, info_dict = self._reset_aec(seed=self.seed)

        for group, agents in self.group_map.items():
            info_specs = []
            for agent in agents:
                info_specs.append(
                    Composite(
                        {
                            "info": Composite(
                                {
                                    key: Unbounded(
                                        shape=torch.as_tensor(value).shape,
                                        device=self.device,
                                    )
                                    for key, value in info_dict[agent].items()
                                }
                            )
                        },
                        device=self.device,
                    )
                )
            info_specs = torch.stack(info_specs, dim=0)
            if ("info", "action_mask") in info_specs.keys(True, True):
                if not self.has_action_mask[group]:
                    self.has_action_mask[group] = True
                    group_action_spec = self.input_spec[
                        "full_action_spec", group, "action"
                    ]
                    self.observation_spec[group]["action_mask"] = Categorical(
                        n=2,
                        shape=group_action_spec.shape
                        if not self.categorical_actions
                        else group_action_spec.to_one_hot_spec().shape,
                        dtype=torch.bool,
                        device=self.device,
                    )
                group_inner_info_spec = info_specs["info"]
                del group_inner_info_spec["action_mask"]

            if len(info_specs["info"].keys()):
                self.observation_spec[group].update(info_specs)

        if self.return_state:
            try:
                state_spec = _gym_to_torchrl_spec_transform(
                    self.state_space,
                    remap_state_to_observation=False,
                    device=self.device,
                )
            except AttributeError:
                state_example = torch.as_tensor(self.state(), device=self.device)
                state_spec = Unbounded(
                    shape=state_example.shape,
                    dtype=state_example.dtype,
                    device=self.device,
                )
            self.observation_spec["state"] = state_spec

        # Caching
        self.cached_reset_output_zero = self.observation_spec.zero()
        self.cached_reset_output_zero.update(self.output_spec["full_done_spec"].zero())

        self.cached_step_output_zero = self.observation_spec.zero()
        self.cached_step_output_zero.update(self.output_spec["full_reward_spec"].zero())
        self.cached_step_output_zero.update(self.output_spec["full_done_spec"].zero())

    def _set_seed(self, seed: int | None) -> None:
        self.seed = seed
        self.reset(seed=self.seed)

    def _reset(
        self, tensordict: TensorDictBase | None = None, **kwargs
    ) -> TensorDictBase:
        if tensordict is not None:
            _reset = tensordict.get("_reset", None)
            if _reset is not None and not _reset.all():
                raise RuntimeError(
                    f"An attempt to call {type(self)}._reset was made when no "
                    f"reset signal could be found. Expected '_reset' entry to "
                    f"be `tensor(True)` or `None` but got `{_reset}`."
                )
        if self.parallel:
            # This resets when any is done
            observation_dict, info_dict = self._reset_parallel(**kwargs)
        else:
            # This resets when all are done
            observation_dict, info_dict = self._reset_aec(**kwargs)

        # We start with zeroed data and fill in the data for alive agents
        tensordict_out = self.cached_reset_output_zero.clone()
        # Update the "mask" for non-acting agents
        self._update_agent_mask(tensordict_out)
        # Update the "action_mask" for non-available actions
        observation_dict, info_dict = self._update_action_mask(
            tensordict_out, observation_dict, info_dict
        )

        # Now we get the data (obs and info)
        for group, agent_names in self.group_map.items():
            group_observation = tensordict_out.get((group, "observation"))
            group_info = tensordict_out.get((group, "info"), None)

            for index, agent in enumerate(agent_names):
                group_observation[index] = self.observation_spec[group, "observation"][
                    index
                ].encode(observation_dict[agent])
                if group_info is not None:
                    agent_info_dict = info_dict[agent]
                    for agent_info, value in agent_info_dict.items():
                        group_info.get(agent_info)[index] = torch.as_tensor(
                            value, device=self.device
                        )

        if self.return_state:
            state = torch.as_tensor(self.state(), device=self.device)
            tensordict_out.set("state", state)

        return tensordict_out

    def _reset_aec(self, **kwargs) -> tuple[dict, dict]:
        self._env.reset(**kwargs)

        observation_dict = {
            agent: self._env.observe(agent) for agent in self.possible_agents
        }
        info_dict = self._env.infos
        return observation_dict, info_dict

    def _reset_parallel(self, **kwargs) -> tuple[dict, dict]:
        return self._env.reset(**kwargs)

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        if self.parallel:
            (
                observation_dict,
                rewards_dict,
                terminations_dict,
                truncations_dict,
                info_dict,
            ) = self._step_parallel(tensordict)
        else:
            (
                observation_dict,
                rewards_dict,
                terminations_dict,
                truncations_dict,
                info_dict,
            ) = self._step_aec(tensordict)

        # We start with zeroed data and fill in the data for alive agents
        tensordict_out = self.cached_step_output_zero.clone()
        # Update the "mask" for non-acting agents
        self._update_agent_mask(tensordict_out)
        # Update the "action_mask" for non-available actions
        observation_dict, info_dict = self._update_action_mask(
            tensordict_out, observation_dict, info_dict
        )

        # Now we get the data
        for group, agent_names in self.group_map.items():
            group_observation = tensordict_out.get((group, "observation"))
            group_reward = tensordict_out.get((group, "reward"))
            group_done = tensordict_out.get((group, "done"))
            group_terminated = tensordict_out.get((group, "terminated"))
            group_truncated = tensordict_out.get((group, "truncated"))
            group_info = tensordict_out.get((group, "info"), None)

            for index, agent in enumerate(agent_names):
                if agent in observation_dict:  # Live agents
                    group_observation[index] = self.observation_spec[
                        group, "observation"
                    ][index].encode(observation_dict[agent])
                    group_reward[index] = torch.tensor(
                        rewards_dict[agent],
                        device=self.device,
                        dtype=torch.float32,
                    )
                    group_done[index] = torch.tensor(
                        terminations_dict[agent] or truncations_dict[agent],
                        device=self.device,
                        dtype=torch.bool,
                    )
                    group_truncated[index] = torch.tensor(
                        truncations_dict[agent],
                        device=self.device,
                        dtype=torch.bool,
                    )
                    group_terminated[index] = torch.tensor(
                        terminations_dict[agent],
                        device=self.device,
                        dtype=torch.bool,
                    )

                    if group_info is not None:
                        agent_info_dict = info_dict[agent]
                        for agent_info, value in agent_info_dict.items():
                            group_info.get(agent_info)[index] = torch.tensor(
                                value, device=self.device
                            )

                elif self.use_mask:
                    if agent in self.agents:
                        raise ValueError(
                            f"Dead agent {agent} not found in step observation but still available in {self.agents}"
                        )
                    # Dead agent
                    terminated = (
                        terminations_dict[agent] if agent in terminations_dict else True
                    )
                    truncated = (
                        truncations_dict[agent] if agent in truncations_dict else True
                    )
                    done = terminated or truncated
                    group_done[index] = done
                    group_terminated[index] = terminated
                    group_truncated[index] = truncated

                else:
                    # Dead agent, if we are not masking it out, this is not allowed
                    raise ValueError(
                        "Dead agents found in the environment,"
                        " you need to set use_mask=True to allow this."
                    )

        # set done values
        done, terminated, truncated = self._aggregate_done(
            tensordict_out, use_any=self.done_on_any
        )

        tensordict_out.set("done", done)
        tensordict_out.set("terminated", terminated)
        tensordict_out.set("truncated", truncated)

        if self.return_state:
            state = torch.as_tensor(self.state(), device=self.device)
            tensordict_out.set("state", state)

        return tensordict_out

    def _aggregate_done(self, tensordict_out, use_any):
        done = False if use_any else True
        truncated = False if use_any else True
        terminated = False if use_any else True
        for key in self.done_keys:
            if isinstance(key, tuple):  # Only look at group keys
                if use_any:
                    if key[-1] == "done":
                        done = done | tensordict_out.get(key).any()
                    if key[-1] == "terminated":
                        terminated = terminated | tensordict_out.get(key).any()
                    if key[-1] == "truncated":
                        truncated = truncated | tensordict_out.get(key).any()
                    if done and terminated and truncated:
                        # no need to proceed further, all values are flipped
                        break
                else:
                    if key[-1] == "done":
                        done = done & tensordict_out.get(key).all()
                    if key[-1] == "terminated":
                        terminated = terminated & tensordict_out.get(key).all()
                    if key[-1] == "truncated":
                        truncated = truncated & tensordict_out.get(key).all()
                    if not done and not terminated and not truncated:
                        # no need to proceed further, all values are flipped
                        break
        return (
            torch.tensor([done], device=self.device),
            torch.tensor([terminated], device=self.device),
            torch.tensor([truncated], device=self.device),
        )

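    # Illustrative note (not part of the packaged file): _aggregate_done above
    # implements the done_on_any switch described in the class docstring; with
    # use_any=True the root "done"/"terminated"/"truncated" flags are the OR over
    # every group's per-agent flags, with use_any=False they are the AND.
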
    def _step_parallel(
        self,
        tensordict: TensorDictBase,
    ) -> tuple[dict, dict, dict, dict, dict]:
        action_dict = {}
        for group, agents in self.group_map.items():
            group_action = tensordict.get((group, "action"))
            group_action_np = self.input_spec[
                "full_action_spec", group, "action"
            ].to_numpy(group_action)
            for index, agent in enumerate(agents):
                # group_action_np can be a dict or an array. We need to recursively index it
                action = _extract_nested_with_index(group_action_np, index)
                action_dict[agent] = action

        return self._env.step(action_dict)

    def _step_aec(
        self,
        tensordict: TensorDictBase,
    ) -> tuple[dict, dict, dict, dict, dict]:
        for group, agents in self.group_map.items():
            if self.agent_selection in agents:
                agent_index = agents.index(self._env.agent_selection)
                group_action = tensordict.get((group, "action"))
                group_action_np = self.input_spec[
                    "full_action_spec", group, "action"
                ].to_numpy(group_action)
                # group_action_np can be a dict or an array. We need to recursively index it
                action = _extract_nested_with_index(group_action_np, agent_index)
                break

        self._env.step(action)
        terminations_dict = self._env.terminations
        truncations_dict = self._env.truncations
        info_dict = self._env.infos
        rewards_dict = self._env.rewards
        observation_dict = {
            agent: self._env.observe(agent) for agent in self.possible_agents
        }
        return (
            observation_dict,
            rewards_dict,
            terminations_dict,
            truncations_dict,
            info_dict,
        )

    def _update_action_mask(self, td, observation_dict, info_dict):
        # Since we remove the action_mask keys we need to copy the data
        observation_dict = copy.deepcopy(observation_dict)
        info_dict = copy.deepcopy(info_dict)
        # In AEC only one agent acts, in parallel env self.agents contains the agents alive
        agents_acting = self.agents if self.parallel else [self.agent_selection]

        for group, agents in self.group_map.items():
            if self.has_action_mask[group]:
                group_mask = td.get((group, "action_mask"))
                group_mask += True
                for index, agent in enumerate(agents):
                    agent_obs = observation_dict[agent]
                    agent_info = info_dict[agent]
                    if isinstance(agent_obs, dict) and "action_mask" in agent_obs:
                        if agent in agents_acting:
                            group_mask[index] = torch.tensor(
                                agent_obs["action_mask"],
                                device=self.device,
                                dtype=torch.bool,
                            )
                        del agent_obs["action_mask"]
                    elif isinstance(agent_info, dict) and "action_mask" in agent_info:
                        if agent in agents_acting:
                            group_mask[index] = torch.tensor(
                                agent_info["action_mask"],
                                device=self.device,
                                dtype=torch.bool,
                            )
                        del agent_info["action_mask"]

                group_action_spec = self.input_spec["full_action_spec", group, "action"]
                if isinstance(group_action_spec, (Categorical, OneHot)):
                    # We update the mask for available actions
                    group_action_spec.update_mask(group_mask.clone())

        return observation_dict, info_dict

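    # Illustrative note (not part of the packaged file): the "action_mask"
    # handled above flags which actions are currently available to each agent,
    # while the "mask" handled below flags which agents are currently
    # acting/alive; the two outputs are independent.
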
    def _update_agent_mask(self, td):
        if self.use_mask:
            # In AEC only one agent acts, in parallel env self.agents contains the agents alive
            agents_acting = self.agents if self.parallel else [self.agent_selection]
            for group, agents in self.group_map.items():
                group_mask = td.get((group, "mask"))
                group_mask += True

                # We now add dead agents to the mask
                for index, agent in enumerate(agents):
                    if agent not in agents_acting:
                        group_mask[index] = False

    def close(self, *, raise_if_closed: bool = True) -> None:
        self._env.close()


class PettingZooEnv(PettingZooWrapper):
    """PettingZoo Environment.

    To install PettingZoo, follow the guide `here <https://github.com/Farama-Foundation/PettingZoo#installation>`__.

    This class is a general torchrl wrapper for all PettingZoo environments.
    It is able to wrap both ``pettingzoo.AECEnv`` and ``pettingzoo.ParallelEnv``.

    Let's see how this works in more detail:

    For wrapping ``pettingzoo.ParallelEnv`` provide the name of your petting zoo task (in the ``task`` argument)
    and specify ``parallel=True``. This will construct the ``pettingzoo.ParallelEnv`` version of that task
    (if it is supported in pettingzoo) and wrap it for torchrl.
    In wrapped ``pettingzoo.ParallelEnv`` all agents will step at each environment step.
    If the number of agents during the task varies, please set ``use_mask=True``.
    ``"mask"`` will be provided
    as an output in each group and should be used to mask out dead agents.
    The environment will be reset as soon as one agent is done (unless ``done_on_any`` is ``False``).

    For wrapping ``pettingzoo.AECEnv`` provide the name of your petting zoo task (in the ``task`` argument)
    and specify ``parallel=False``. This will construct the ``pettingzoo.AECEnv`` version of that task
    and wrap it for torchrl.
    In wrapped ``pettingzoo.AECEnv``, at each step only one agent will act.
    For this reason, it is compulsory to set ``use_mask=True`` for this type of environment.
    ``"mask"`` will be provided as an output for each group and can be used to mask out non-acting agents.
    The environment will be reset only when all agents are done (unless ``done_on_any`` is ``True``).

    If there are any unavailable actions for an agent,
    the environment will also automatically update the mask of its ``action_spec`` and output an ``"action_mask"``
    for each group to reflect the latest available actions. This should be passed to a masked distribution during
    training.

    As a feature of torchrl multiagent, you are able to control the grouping of agents in your environment.
    You can group agents together (stacking their tensors) to leverage vectorization when passing them through the same
    neural network. You can split agents in different groups where they are heterogeneous or should be processed by
    different neural networks. To group, you just need to pass a ``group_map`` at env construction time.

    By default, agents in pettingzoo will be grouped by name.
    For example, with agents ``["agent_0","agent_1","agent_2","adversary_0"]``, the tensordicts will look like:

    >>> print(env.rand_action(env.reset()))
    TensorDict(
        fields={
            agent: TensorDict(
                fields={
                    action: Tensor(shape=torch.Size([3, 9]), device=cpu, dtype=torch.int64, is_shared=False),
                    action_mask: Tensor(shape=torch.Size([3, 9]), device=cpu, dtype=torch.bool, is_shared=False),
                    done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    observation: Tensor(shape=torch.Size([3, 3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False),
                    terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([3]))},
            adversary: TensorDict(
                fields={
                    action: Tensor(shape=torch.Size([1, 9]), device=cpu, dtype=torch.int64, is_shared=False),
                    action_mask: Tensor(shape=torch.Size([1, 9]), device=cpu, dtype=torch.bool, is_shared=False),
                    done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    observation: Tensor(shape=torch.Size([1, 3, 3, 2]), device=cpu, dtype=torch.int8, is_shared=False),
                    terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                    truncated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                batch_size=torch.Size([1]))},
        batch_size=torch.Size([]))
    >>> print(env.group_map)
    {"agent": ["agent_0", "agent_1", "agent_2"], "adversary": ["adversary_0"]}

    Otherwise, a group map can be specified or selected from some premade options.
    See :class:`torchrl.envs.utils.MarlGroupMapType` for more info.
    For example, you can provide ``MarlGroupMapType.ONE_GROUP_PER_AGENT``, telling that each agent should
    have its own tensordict (similar to the pettingzoo parallel API).

    Grouping is useful for leveraging vectorization among agents whose data goes through the same
    neural network.

    Args:
        task (str): the name of the pettingzoo task to create in the "<env>/<task>" format (for example, "sisl/multiwalker_v9")
|
|
925
|
+
or "<task>" format (for example, "multiwalker_v9").
|
|
926
|
+
parallel (bool): if to construct the ``pettingzoo.ParallelEnv`` version of the task or the ``pettingzoo.AECEnv``.
|
|
927
|
+
return_state (bool, optional): whether to return the global state from pettingzoo
|
|
928
|
+
(not available in all environments). Defaults to ``False``.
|
|
929
|
+
group_map (MarlGroupMapType or Dict[str, List[str]]], optional): how to group agents in tensordicts for
|
|
930
|
+
input/output. By default, agents will be grouped by their name. Otherwise, a group map can be specified
|
|
931
|
+
or selected from some premade options. See :class:`torchrl.envs.utils.MarlGroupMapType` for more info.
|
|
932
|
+
use_mask (bool, optional): whether the environment should output an ``"mask"``. This is compulsory in
|
|
933
|
+
wrapped ``pettingzoo.AECEnv`` to mask out non-acting agents and should be also used
|
|
934
|
+
for ``pettingzoo.ParallelEnv`` when the number of agents can vary. Defaults to ``False``.
|
|
935
|
+
categorical_actions (bool, optional): if the environments actions are discrete, whether to transform
|
|
936
|
+
them to categorical or one-hot.
|
|
937
|
+
seed (int, optional): the seed. Defaults to ``None``.
|
|
938
|
+
done_on_any (bool, optional): whether the environment's done keys are set by aggregating the agent keys
|
|
939
|
+
using ``any()`` (when ``True``) or ``all()`` (when ``False``). Default (``None``) is to use ``any()`` for
|
|
940
|
+
parallel environments and ``all()`` for AEC ones.
|
|
941
|
+
|
|
942
|
+
Examples:
|
|
943
|
+
>>> # Parallel env
|
|
944
|
+
>>> from torchrl.envs.libs.pettingzoo import PettingZooEnv
|
|
945
|
+
>>> kwargs = {"n_pistons": 21, "continuous": True}
|
|
946
|
+
>>> env = PettingZooEnv(
|
|
947
|
+
... task="pistonball_v6",
|
|
948
|
+
... parallel=True,
|
|
949
|
+
... return_state=True,
|
|
950
|
+
... group_map=None, # Use default (all pistons grouped together)
|
|
951
|
+
... **kwargs,
|
|
952
|
+
... )
|
|
953
|
+
>>> print(env.group_map)
|
|
954
|
+
... {'piston': ['piston_0', 'piston_1', ..., 'piston_20']}
|
|
955
|
+
>>> env.rollout(10)
|
|
956
|
+
>>> # AEC env
|
|
957
|
+
>>> from torchrl.envs.libs.pettingzoo import PettingZooEnv
|
|
958
|
+
>>> from torchrl.envs.utils import MarlGroupMapType
|
|
959
|
+
>>> env = PettingZooEnv(
|
|
960
|
+
... task="tictactoe_v3",
|
|
961
|
+
... parallel=False,
|
|
962
|
+
... use_mask=True, # Must use it since one player plays at a time
|
|
963
|
+
... group_map=None # # Use default for AEC (one group per player)
|
|
964
|
+
... )
|
|
965
|
+
>>> print(env.group_map)
|
|
966
|
+
... {'player_1': ['player_1'], 'player_2': ['player_2']}
|
|
967
|
+
>>> env.rollout(10)
|
|
968
|
+
"""
|
|
969
|
+
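The docstring above notes that the per-group ``"action_mask"`` should be fed to a masked distribution during training. A minimal sketch of that wiring, assuming ``torchrl.modules.MaskedCategorical`` and the ``td`` rollout from the tic-tac-toe sketch earlier (the zero logits simply stand in for a policy network's output):

    >>> import torch
    >>> from torchrl.modules import MaskedCategorical
    >>> action_mask = td.get(("player_1", "action_mask"))      # boolean, last dim = 9 actions
    >>> logits = torch.zeros(action_mask.shape)                 # stand-in for policy logits
    >>> dist = MaskedCategorical(logits=logits, mask=action_mask)
    >>> action = dist.sample()                                  # never selects a masked-out move
    >>> log_prob = dist.log_prob(action)                        # usable in a policy-gradient loss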
+    def __init__(
+        self,
+        task: str,
+        parallel: bool,
+        return_state: bool = False,
+        group_map: MarlGroupMapType | dict[str, list[str]] | None = None,
+        use_mask: bool = False,
+        categorical_actions: bool = True,
+        seed: int | None = None,
+        done_on_any: bool | None = None,
+        **kwargs,
+    ):
+        if not _has_pettingzoo:
+            raise ImportError(
+                f"pettingzoo python package was not found. Please install this dependency. "
+                f"More info: {self.git_url}."
+            )
+        kwargs["task"] = task
+        kwargs["parallel"] = parallel
+        kwargs["return_state"] = return_state
+        kwargs["group_map"] = group_map
+        kwargs["use_mask"] = use_mask
+        kwargs["categorical_actions"] = categorical_actions
+        kwargs["seed"] = seed
+        kwargs["done_on_any"] = done_on_any
+
+        super().__init__(**kwargs)
+
+    def _check_kwargs(self, kwargs: dict):
+        if "task" not in kwargs:
+            raise TypeError("Could not find environment key 'task' in kwargs.")
+        if "parallel" not in kwargs:
+            raise TypeError("Could not find environment key 'parallel' in kwargs.")
+
+    def _build_env(
+        self,
+        task: str,
+        parallel: bool,
+        **kwargs,
+    ) -> (
+        pettingzoo.utils.env.ParallelEnv  # noqa: F821
+        | pettingzoo.utils.env.AECEnv  # noqa: F821
+    ):
+        self.task_name = task
+
+        try:
+            from pettingzoo.utils.all_modules import all_environments
+        except ModuleNotFoundError as err:
+            warnings.warn(
+                f"PettingZoo failed to load all modules with error message {err}, trying to load individual modules."
+            )
+            all_environments = _load_available_envs()
+
+        if task not in all_environments:
+            # Try looking at the literal translation of values
+            task_module = None
+            for value in all_environments.values():
+                if value.__name__.split(".")[-1] == task:
+                    task_module = value
+                    break
+            if task_module is None:
+                raise RuntimeError(
+                    f"Specified task not in available environments {all_environments}"
+                )
+        else:
+            task_module = all_environments[task]
+
+        if parallel:
+            petting_zoo_env = task_module.parallel_env(**kwargs)
+        else:
+            petting_zoo_env = task_module.env(**kwargs)
+
+        return super()._build_env(env=petting_zoo_env)
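``_build_env`` above resolves the ``task`` string against PettingZoo's registry and then builds either the parallel or the AEC variant before handing it to the wrapper. A minimal sketch of combining that with an explicit ``group_map`` dict instead of the default name-based grouping, reusing the ``pistonball_v6`` task from the docstring (the two-group split is illustrative and must cover every agent):

    >>> from torchrl.envs.libs.pettingzoo import PettingZooEnv
    >>> group_map = {
    ...     "left": [f"piston_{i}" for i in range(10)],
    ...     "right": [f"piston_{i}" for i in range(10, 21)],
    ... }
    >>> env = PettingZooEnv(task="pistonball_v6", parallel=True, n_pistons=21, group_map=group_map)
    >>> env.rollout(5)  # the output tensordict now carries "left" and "right" sub-tensordicts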