torchrl-0.11.0-cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/libs/meltingpot.py
@@ -0,0 +1,599 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import importlib
from collections.abc import Mapping, Sequence

import torch
from tensordict import TensorDict, TensorDictBase

from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec
from torchrl.envs.common import _EnvWrapper
from torchrl.envs.libs.dm_control import _dmcontrol_to_torchrl_spec_transform
from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType

_has_meltingpot = importlib.util.find_spec("meltingpot") is not None

PLAYER_STR_FORMAT = "player_{index}"
_WORLD_PREFIX = "WORLD."


def _get_envs():
    if not _has_meltingpot:
        raise ImportError("meltingpot is not installed in your virtual environment.")
    from meltingpot.configs import substrates as substrate_configs

    return list(substrate_configs.SUBSTRATES)


def _filter_global_state_from_dict(obs_dict: dict, world: bool) -> dict:  # noqa
    return {
        key: value
        for key, value in obs_dict.items()
        if ((_WORLD_PREFIX not in key) if not world else (_WORLD_PREFIX in key))
    }


def _remove_world_observations_from_obs_spec(
    observation_spec: Sequence[Mapping[str, dm_env.specs.Array]],  # noqa
) -> Sequence[Mapping[str, dm_env.specs.Array]]:  # noqa
    return [
        _filter_global_state_from_dict(agent_obs, world=False)
        for agent_obs in observation_spec
    ]


def _global_state_spec_from_obs_spec(
    observation_spec: Sequence[Mapping[str, dm_env.specs.Array]]  # noqa
) -> Mapping[str, dm_env.specs.Array]:  # noqa
    # We only look at agent 0 since world entries are the same for all agents
    world_entries = _filter_global_state_from_dict(observation_spec[0], world=True)
    if len(world_entries) != 1 and _WORLD_PREFIX + "RGB" not in world_entries:
        raise ValueError(
            f"Expected only one world entry named {_WORLD_PREFIX}RGB in observation_spec, but got {world_entries}"
        )
    return _remove_world_prefix(world_entries)


def _remove_world_prefix(world_entries: dict) -> dict:
    return {key[len(_WORLD_PREFIX) :]: value for key, value in world_entries.items()}

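# A minimal illustration (editor's note, not part of the shipped file) of what
# the helpers above do; the toy dict entries are assumptions chosen purely for
# demonstration:
#   _filter_global_state_from_dict({"RGB": a, "WORLD.RGB": b}, world=False)
#       -> {"RGB": a}           # per-agent entries only
#   _filter_global_state_from_dict({"RGB": a, "WORLD.RGB": b}, world=True)
#       -> {"WORLD.RGB": b}     # global (world) entries only
#   _remove_world_prefix({"WORLD.RGB": b})
#       -> {"RGB": b}           # prefix stripped for the global state spec
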
class MeltingpotWrapper(_EnvWrapper):
    """Meltingpot environment wrapper.

    GitHub: https://github.com/google-deepmind/meltingpot

    Paper: https://arxiv.org/abs/2211.13746

    Melting Pot assesses generalization to novel social situations involving both familiar and unfamiliar
    individuals, and has been designed to test a broad range of social interactions, such as cooperation,
    competition, deception, reciprocation, trust, and stubbornness. Melting Pot offers researchers a set of
    over 50 multi-agent reinforcement learning substrates (multi-agent games) on which to train agents, and
    over 256 unique test scenarios on which to evaluate these trained agents.

    Args:
        env (``meltingpot.utils.substrates.substrate.Substrate``): the meltingpot substrate to wrap.

    Keyword Args:
        max_steps (int, optional): Horizon of the task. Defaults to ``None`` (infinite horizon).
            Each Meltingpot substrate can be terminating or not. If ``max_steps`` is specified,
            the scenario is also terminated (and the ``"terminated"`` flag is set) whenever this
            horizon is reached. Unlike gym's ``TimeLimit`` transform or torchrl's
            :class:`~torchrl.envs.transforms.StepCounter`, this argument will not set the
            ``"truncated"`` entry in the tensordict.
        categorical_actions (bool, optional): if the environment actions are discrete, whether to transform
            them to categorical or one-hot. Defaults to ``True``.
        group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to group agents in tensordicts for
            input/output. By default, they will all be put in one group named ``"agents"``.
            Otherwise, a group map can be specified or selected from some premade options.
            See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.

    Attributes:
        group_map (Dict[str, List[str]]): how to group agents in tensordicts for
            input/output. See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.
        agent_names (list of str): names of the agents in the environment.
        agent_names_to_indices_map (Dict[str, int]): dictionary mapping agent names to their index in the environment.
        available_envs (List[str]): the list of the scenarios available to build.

    .. warning::
        Meltingpot returns a single ``done`` flag which does not distinguish between
        the env reaching ``max_steps`` and termination.
        If you deem the ``truncation`` signal necessary, set ``max_steps`` to
        ``None`` and use a :class:`~torchrl.envs.transforms.StepCounter` transform.

    Examples:
        >>> from meltingpot import substrate
        >>> from torchrl.envs.libs.meltingpot import MeltingpotWrapper
        >>> substrate_config = substrate.get_config("commons_harvest__open")
        >>> mp_env = substrate.build_from_config(
        ...     substrate_config, roles=substrate_config.default_player_roles
        ... )
        >>> env_torchrl = MeltingpotWrapper(env=mp_env)
        >>> print(env_torchrl.rollout(max_steps=5))
        TensorDict(
            fields={
                RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                agents: TensorDict(
                    fields={
                        action: Tensor(shape=torch.Size([5, 7]), device=cpu, dtype=torch.int64, is_shared=False),
                        observation: TensorDict(
                            fields={
                                COLLECTIVE_REWARD: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
                                READY_TO_SHOOT: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
                                RGB: Tensor(shape=torch.Size([5, 7, 88, 88, 3]), device=cpu, dtype=torch.uint8, is_shared=False)},
                            batch_size=torch.Size([5, 7]),
                            device=cpu,
                            is_shared=False)},
                    batch_size=torch.Size([5, 7]),
                    device=cpu,
                    is_shared=False),
                done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                next: TensorDict(
                    fields={
                        RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                        agents: TensorDict(
                            fields={
                                observation: TensorDict(
                                    fields={
                                        COLLECTIVE_REWARD: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
                                        READY_TO_SHOOT: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
                                        RGB: Tensor(shape=torch.Size([5, 7, 88, 88, 3]), device=cpu, dtype=torch.uint8, is_shared=False)},
                                    batch_size=torch.Size([5, 7]),
                                    device=cpu,
                                    is_shared=False),
                                reward: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False)},
                            batch_size=torch.Size([5, 7]),
                            device=cpu,
                            is_shared=False),
                        done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([5]),
                    device=cpu,
                    is_shared=False),
                terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([5]),
            device=cpu,
            is_shared=False)

    """

    git_url = "https://github.com/google-deepmind/meltingpot"
    libname = "meltingpot"

    @property
    def lib(self):
        import meltingpot

        return meltingpot

    @_classproperty
    def available_envs(cls):
        if not _has_meltingpot:
            return []
        return _get_envs()

    def __init__(
        self,
        env: meltingpot.utils.substrates.substrate.Substrate = None,  # noqa
        categorical_actions: bool = True,
        group_map: MarlGroupMapType
        | dict[str, list[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
        max_steps: int | None = None,
        **kwargs,
    ):
        if env is not None:
            kwargs["env"] = env
        self.group_map = group_map
        self.categorical_actions = categorical_actions
        self.max_steps = max_steps
        self.num_cycles = 0
        super().__init__(**kwargs)

    def _build_env(
        self,
        env: meltingpot.utils.substrates.substrate.Substrate,  # noqa
    ):
        return env

    def _make_group_map(self):
        if isinstance(self.group_map, MarlGroupMapType):
            self.group_map = self.group_map.get_group_map(self.agent_names)
        check_marl_grouping(self.group_map, self.agent_names)

    def _make_specs(
        self, env: meltingpot.utils.substrates.substrate.Substrate  # noqa
    ) -> None:
        mp_obs_spec = self._env.observation_spec()  # List of dict of arrays
        mp_obs_spec_no_world = _remove_world_observations_from_obs_spec(
            mp_obs_spec
        )  # List of dict of arrays
        mp_global_state_spec = _global_state_spec_from_obs_spec(
            mp_obs_spec
        )  # Dict of arrays
        mp_act_spec = self._env.action_spec()  # List of discrete arrays
        mp_rew_spec = self._env.reward_spec()  # List of arrays

        torchrl_agent_obs_specs = [
            _dmcontrol_to_torchrl_spec_transform(agent_obs_spec)
            for agent_obs_spec in mp_obs_spec_no_world
        ]
        torchrl_agent_act_specs = [
            _dmcontrol_to_torchrl_spec_transform(
                agent_act_spec, categorical_discrete_encoding=self.categorical_actions
            )
            for agent_act_spec in mp_act_spec
        ]
        torchrl_state_spec = _dmcontrol_to_torchrl_spec_transform(mp_global_state_spec)
        torchrl_rew_spec = [
            _dmcontrol_to_torchrl_spec_transform(agent_rew_spec)
            for agent_rew_spec in mp_rew_spec
        ]

        # Create and check group map
        _num_players = len(torchrl_rew_spec)
        self.agent_names = [
            PLAYER_STR_FORMAT.format(index=index) for index in range(_num_players)
        ]
        self.agent_names_to_indices_map = {
            agent_name: i for i, agent_name in enumerate(self.agent_names)
        }
        self._make_group_map()

        action_spec = Composite()
        observation_spec = Composite()
        reward_spec = Composite()

        for group in self.group_map.keys():
            (
                group_observation_spec,
                group_action_spec,
                group_reward_spec,
            ) = self._make_group_specs(
                group,
                torchrl_agent_obs_specs,
                torchrl_agent_act_specs,
                torchrl_rew_spec,
            )
            action_spec[group] = group_action_spec
            observation_spec[group] = group_observation_spec
            reward_spec[group] = group_reward_spec

        observation_spec.update(torchrl_state_spec)
        self.done_spec = Composite(
            {
                "done": Categorical(n=2, shape=torch.Size((1,)), dtype=torch.bool),
            },
        )
        self.action_spec = action_spec
        self.observation_spec = observation_spec
        self.reward_spec = reward_spec

    def _make_group_specs(
        self,
        group: str,
        torchrl_agent_obs_specs: list[TensorSpec],
        torchrl_agent_act_specs: list[TensorSpec],
        torchrl_rew_spec: list[TensorSpec],
    ):
        # Agent specs
        action_specs = []
        observation_specs = []
        reward_specs = []

        for agent_name in self.group_map[group]:
            agent_index = self.agent_names_to_indices_map[agent_name]
            action_specs.append(
                Composite(
                    {
                        "action": torchrl_agent_act_specs[
                            agent_index
                        ]  # shape = (n_actions_per_agent,)
                    },
                )
            )
            observation_specs.append(
                Composite(
                    {
                        "observation": torchrl_agent_obs_specs[
                            agent_index
                        ]  # shape = (n_obs_per_agent,)
                    },
                )
            )
            reward_specs.append(
                Composite({"reward": torchrl_rew_spec[agent_index]})  # shape = (1,)
            )

        # Create multi-agent specs
        group_action_spec = torch.stack(
            action_specs, dim=0
        )  # shape = (n_agents_in_group, n_actions_per_agent)
        group_observation_spec = torch.stack(
            observation_specs, dim=0
        )  # shape = (n_agents_in_group, n_obs_per_agent)
        group_reward_spec = torch.stack(
            reward_specs, dim=0
        )  # shape = (n_agents_in_group, 1)
        return (
            group_observation_spec,
            group_action_spec,
            group_reward_spec,
        )

    def _check_kwargs(self, kwargs: dict):
        meltingpot = self.lib

        if "env" not in kwargs:
            raise TypeError("Could not find environment key 'env' in kwargs.")
        env = kwargs["env"]
        if not isinstance(env, meltingpot.utils.substrates.substrate.Substrate):
            raise TypeError(
                "env is not of type 'meltingpot.utils.substrates.substrate.Substrate'."
            )

    def _init_env(self):
        # Caching
        self.cached_full_done_spec_zero = self.full_done_spec.zero()

    def _set_seed(self, seed: int | None) -> None:
        raise NotImplementedError(
            "It is currently unclear how to set a seed in Meltingpot. "
            "See https://github.com/google-deepmind/meltingpot/issues/129 to track the issue."
        )

    def _reset(
        self, tensordict: TensorDictBase | None = None, **kwargs
    ) -> TensorDictBase:
        self.num_cycles = 0
        timestep = self._env.reset()
        obs = timestep.observation

        td = self.cached_full_done_spec_zero.clone()

        for group, agent_names in self.group_map.items():
            agent_tds = []
            for index_in_group, agent_name in enumerate(agent_names):
                global_index = self.agent_names_to_indices_map[agent_name]
                agent_obs = self.observation_spec[group, "observation"][
                    index_in_group
                ].encode(_filter_global_state_from_dict(obs[global_index], world=False))
                agent_td = TensorDict(
                    source={
                        "observation": agent_obs,
                    },
                    batch_size=self.batch_size,
                    device=self.device,
                )

                agent_tds.append(agent_td)
            agent_tds = torch.stack(agent_tds, dim=0)
            td.set(group, agent_tds)

        # Global state
        td.update(
            _remove_world_prefix(_filter_global_state_from_dict(obs[0], world=True))
        )

        tensordict_out = TensorDict(
            source=td,
            batch_size=self.batch_size,
            device=self.device,
        )
        return tensordict_out

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        action_dict = {}
        for group, agents in self.group_map.items():
            group_action = tensordict.get((group, "action"))
            group_action_np = self.full_action_spec[group, "action"].to_numpy(
                group_action
            )
            for index, agent in enumerate(agents):
                action_dict[agent] = group_action_np[index]

        actions = [action_dict[agent] for agent in self.agent_names]
        timestep = self._env.step(actions)
        self.num_cycles += 1

        rewards = timestep.reward
        done = timestep.last() or (
            (self.num_cycles >= self.max_steps) if self.max_steps is not None else False
        )
        obs = timestep.observation

        td = TensorDict(
            {
                "done": self.full_done_spec["done"].encode(done),
                "terminated": self.full_done_spec["terminated"].encode(done),
            },
            batch_size=self.batch_size,
        )
        # Global state
        td.update(
            _remove_world_prefix(_filter_global_state_from_dict(obs[0], world=True))
        )

        for group, agent_names in self.group_map.items():
            agent_tds = []
            for index_in_group, agent_name in enumerate(agent_names):
                global_index = self.agent_names_to_indices_map[agent_name]
                agent_obs = self.observation_spec[group, "observation"][
                    index_in_group
                ].encode(_filter_global_state_from_dict(obs[global_index], world=False))
                agent_reward = self.full_reward_spec[group, "reward"][
                    index_in_group
                ].encode(rewards[global_index])
                agent_td = TensorDict(
                    source={
                        "observation": agent_obs,
                        "reward": agent_reward,
                    },
                    batch_size=self.batch_size,
                    device=self.device,
                )

                agent_tds.append(agent_td)
            agent_tds = torch.stack(agent_tds, dim=0)
            td.set(group, agent_tds)

        return td

    def get_rgb_image(self) -> torch.Tensor:
        """Returns an RGB image of the environment.

        Returns:
            a ``torch.Tensor`` containing the image in WHC format.

        """
        return torch.from_numpy(self._env.observation()[0][_WORLD_PREFIX + "RGB"])

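# For orientation (editor's note, not in the original file): with the default
# MarlGroupMapType.ALL_IN_ONE_GROUP, a two-player substrate yields
#   group_map == {"agents": ["player_0", "player_1"]}
# so _step reads actions from ("agents", "action") and each agent's
# observation/reward is stacked along the first dimension of the "agents"
# sub-tensordict, as in the docstring rollout above.
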
class MeltingpotEnv(MeltingpotWrapper):
    """Meltingpot environment wrapper.

    GitHub: https://github.com/google-deepmind/meltingpot

    Paper: https://arxiv.org/abs/2211.13746

    Melting Pot assesses generalization to novel social situations involving both familiar and unfamiliar
    individuals, and has been designed to test a broad range of social interactions, such as cooperation,
    competition, deception, reciprocation, trust, and stubbornness. Melting Pot offers researchers a set of
    over 50 multi-agent reinforcement learning substrates (multi-agent games) on which to train agents, and
    over 256 unique test scenarios on which to evaluate these trained agents.

    Args:
        substrate (str or ml_collections.config_dict.ConfigDict): the meltingpot substrate to build.
            Can be a string from :attr:`~.available_envs` or a ConfigDict for the substrate.

    Keyword Args:
        max_steps (int, optional): Horizon of the task. Defaults to ``None`` (infinite horizon).
            Each Meltingpot substrate can be terminating or not. If ``max_steps`` is specified,
            the scenario is also terminated (and the ``"terminated"`` flag is set) whenever this
            horizon is reached. Unlike gym's ``TimeLimit`` transform or torchrl's
            :class:`~torchrl.envs.transforms.StepCounter`, this argument will not set the
            ``"truncated"`` entry in the tensordict.
        categorical_actions (bool, optional): if the environment actions are discrete, whether to transform
            them to categorical or one-hot. Defaults to ``True``.
        group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to group agents in tensordicts for
            input/output. By default, they will all be put in one group named ``"agents"``.
            Otherwise, a group map can be specified or selected from some premade options.
            See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.

    Attributes:
        group_map (Dict[str, List[str]]): how to group agents in tensordicts for
            input/output. See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.
        agent_names (list of str): names of the agents in the environment.
        agent_names_to_indices_map (Dict[str, int]): dictionary mapping agent names to their index in the environment.
        available_envs (List[str]): the list of the scenarios available to build.

    .. warning::
        Meltingpot returns a single ``done`` flag which does not distinguish between
        the env reaching ``max_steps`` and termination.
        If you deem the ``truncation`` signal necessary, set ``max_steps`` to
        ``None`` and use a :class:`~torchrl.envs.transforms.StepCounter` transform.

    Examples:
        >>> from torchrl.envs.libs.meltingpot import MeltingpotEnv
        >>> env_torchrl = MeltingpotEnv("commons_harvest__open")
        >>> print(env_torchrl.rollout(max_steps=5))
        TensorDict(
            fields={
                RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                agents: TensorDict(
                    fields={
                        action: Tensor(shape=torch.Size([5, 7]), device=cpu, dtype=torch.int64, is_shared=False),
                        observation: TensorDict(
                            fields={
                                COLLECTIVE_REWARD: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
                                READY_TO_SHOOT: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
                                RGB: Tensor(shape=torch.Size([5, 7, 88, 88, 3]), device=cpu, dtype=torch.uint8, is_shared=False)},
                            batch_size=torch.Size([5, 7]),
                            device=cpu,
                            is_shared=False)},
                    batch_size=torch.Size([5, 7]),
                    device=cpu,
                    is_shared=False),
                done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                next: TensorDict(
                    fields={
                        RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                        agents: TensorDict(
                            fields={
                                observation: TensorDict(
                                    fields={
                                        COLLECTIVE_REWARD: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
                                        READY_TO_SHOOT: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False),
                                        RGB: Tensor(shape=torch.Size([5, 7, 88, 88, 3]), device=cpu, dtype=torch.uint8, is_shared=False)},
                                    batch_size=torch.Size([5, 7]),
                                    device=cpu,
                                    is_shared=False),
                                reward: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False)},
                            batch_size=torch.Size([5, 7]),
                            device=cpu,
                            is_shared=False),
                        done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([5]),
                    device=cpu,
                    is_shared=False),
                terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([5]),
            device=cpu,
            is_shared=False)

    """

    def __init__(
        self,
        substrate: str | ml_collections.config_dict.ConfigDict,  # noqa
        *,
        max_steps: int | None = None,
        categorical_actions: bool = True,
        group_map: MarlGroupMapType
        | dict[str, list[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
        **kwargs,
    ):
        if not _has_meltingpot:
            raise ImportError(
                "meltingpot python package was not found. Please install this dependency. "
                f"More info: {self.git_url}."
            )
        super().__init__(
            substrate=substrate,
            max_steps=max_steps,
            categorical_actions=categorical_actions,
            group_map=group_map,
            **kwargs,
        )

    def _check_kwargs(self, kwargs: dict):
        if "substrate" not in kwargs:
            raise TypeError("Could not find environment key 'substrate' in kwargs.")

    def _build_env(
        self,
        substrate: str | ml_collections.config_dict.ConfigDict,  # noqa
    ) -> meltingpot.utils.substrates.substrate.Substrate:  # noqa
        from meltingpot import substrate as mp_substrate

        if isinstance(substrate, str):
            substrate_config = mp_substrate.get_config(substrate)
        else:
            substrate_config = substrate

        return super()._build_env(
            env=mp_substrate.build_from_config(
                substrate_config, roles=substrate_config.default_player_roles
            )
        )
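The docstrings above recommend a StepCounter transform over ``max_steps`` when a distinct truncation signal is needed. A minimal sketch of that pattern follows; it assumes meltingpot and torchrl are installed, reuses the ``commons_harvest__open`` substrate from the docstring examples, and the 100-step budget is an arbitrary choice:

from torchrl.envs import TransformedEnv
from torchrl.envs.libs.meltingpot import MeltingpotEnv
from torchrl.envs.transforms import StepCounter

# Leave max_steps=None so Meltingpot's own "done"/"terminated" flags only
# reflect genuine termination of the substrate.
base_env = MeltingpotEnv("commons_harvest__open")

# StepCounter sets "truncated" (and folds it into "done") once 100 steps have
# elapsed, keeping termination and truncation distinguishable downstream.
env = TransformedEnv(base_env, StepCounter(max_steps=100))

rollout = env.rollout(max_steps=5)
# The rollout now carries a truncation signal alongside "terminated".
print(rollout["next", "truncated"].shape)  # expected: torch.Size([5, 1])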