torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/libs/openspiel.py
@@ -0,0 +1,652 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import importlib.util

import torch
from tensordict import TensorDict, TensorDictBase

from torchrl.data.tensor_specs import (
    Categorical,
    Composite,
    NonTensor,
    OneHot,
    Unbounded,
)
from torchrl.envs.common import _EnvWrapper
from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType

_has_pyspiel = importlib.util.find_spec("pyspiel") is not None


def _get_envs():
    if not _has_pyspiel:
        raise ImportError(
            "open_spiel not found. Consider downloading and installing "
            f"open_spiel from {OpenSpielWrapper.git_url}."
        )

    import pyspiel

    return [game.short_name for game in pyspiel.registered_games()]


class OpenSpielWrapper(_EnvWrapper):
    """Google DeepMind OpenSpiel environment wrapper.

    GitHub: https://github.com/google-deepmind/open_spiel

    Documentation: https://openspiel.readthedocs.io/en/latest/index.html

    Args:
        env (pyspiel.State): the game to wrap.

    Keyword Args:
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``None``.
        batch_size (torch.Size, optional): the batch size of the environment.
            Defaults to ``torch.Size([])``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``False``.
        group_map (MarlGroupMapType or Dict[str, List[str]]], optional): how to
            group agents in tensordicts for input/output. See
            :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.
            Defaults to
            :class:`~torchrl.envs.utils.MarlGroupMapType.ALL_IN_ONE_GROUP`.
        categorical_actions (bool, optional): if ``True``, categorical specs
            will be converted to the TorchRL equivalent
            (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
            will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``.
        return_state (bool, optional): if ``True``, "state" is included in the
            output of :meth:`reset` and :meth:`~step`. The state can be given
            to :meth:`reset` to reset to that state, rather than resetting to
            the initial state.
            Defaults to ``False``.

    Attributes:
        available_envs: environments available to build

    Examples:
        >>> import pyspiel
        >>> from torchrl.envs import OpenSpielWrapper
        >>> from tensordict import TensorDict
        >>> base_env = pyspiel.load_game('chess').new_initial_state()
        >>> env = OpenSpielWrapper(base_env, return_state=True)
        >>> td = env.reset()
        >>> td = env.step(env.full_action_spec.rand())
        >>> print(td)
        TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        action: Tensor(shape=torch.Size([2, 4672]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                next: TensorDict(
                    fields={
                        agents: TensorDict(
                            fields={
                                observation: Tensor(shape=torch.Size([2, 20, 8, 8]), device=cpu, dtype=torch.float32, is_shared=False),
                                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([2]),
                            device=None,
                            is_shared=False),
                        current_player: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        state: NonTensorData(data=FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1
        3009
        , batch_size=torch.Size([]), device=None),
                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)
        >>> print(env.available_envs)
        ['2048', 'add_noise', 'amazons', 'backgammon', ...]

    :meth:`reset` can restore a specific state, rather than the initial
    state, as long as ``return_state=True``.

        >>> import pyspiel
        >>> from torchrl.envs import OpenSpielWrapper
        >>> from tensordict import TensorDict
        >>> base_env = pyspiel.load_game('chess').new_initial_state()
        >>> env = OpenSpielWrapper(base_env, return_state=True)
        >>> td = env.reset()
        >>> td = env.step(env.full_action_spec.rand())
        >>> td_restore = td["next"]
        >>> td = env.step(env.full_action_spec.rand())
        >>> # Current state is not equal `td_restore`
        >>> (td["next"] == td_restore).all()
        False
        >>> td = env.reset(td_restore)
        >>> # After resetting, now the current state is equal to `td_restore`
        >>> (td == td_restore).all()
        True
    """

    git_url = "https://github.com/google-deepmind/open_spiel"
    libname = "pyspiel"
    _lib = None

    @_classproperty
    def lib(cls):
        if cls._lib is not None:
            return cls._lib

        import pyspiel

        cls._lib = pyspiel
        return pyspiel

    @_classproperty
    def available_envs(cls):
        if not _has_pyspiel:
            return []
        return _get_envs()

    def __init__(
        self,
        env=None,
        *,
        group_map: MarlGroupMapType
        | dict[str, list[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
        categorical_actions: bool = False,
        return_state: bool = False,
        **kwargs,
    ):
        if env is not None:
            kwargs["env"] = env

        self.group_map = group_map
        self.categorical_actions = categorical_actions
        self.return_state = return_state
        self._cached_game = None
        super().__init__(**kwargs)

        # `reset` allows resetting to any state, including a terminal state
        self._allow_done_after_reset = True

    def _check_kwargs(self, kwargs: dict):
        pyspiel = self.lib
        if "env" not in kwargs:
            raise TypeError("Could not find environment key 'env' in kwargs.")
        env = kwargs["env"]
        if not isinstance(env, pyspiel.State):
            raise TypeError("env is not of type 'pyspiel.State'.")

    def _build_env(self, env, requires_grad: bool = False, **kwargs):
        game = env.get_game()
        game_type = game.get_type()

        if game.max_chance_outcomes() != 0:
            raise NotImplementedError(
                f"The game '{game_type.short_name}' has chance nodes, which are not yet supported."
            )
        if game_type.dynamics == self.lib.GameType.Dynamics.MEAN_FIELD:
            # NOTE: It is unclear from the OpenSpiel documentation what exactly
            # "mean field" means exactly, and there is no documentation on the
            # several games which have it.
            raise RuntimeError(
                f"Mean field games like '{game_type.name}' are not yet " "supported."
            )
        self.parallel = game_type.dynamics == self.lib.GameType.Dynamics.SIMULTANEOUS
        self.requires_grad = requires_grad
        return env

    def _init_env(self):
        self._update_action_mask()

    def _get_game(self):
        if self._cached_game is None:
            self._cached_game = self._env.get_game()
        return self._cached_game

    def _make_group_map(self, group_map, agent_names):
        if group_map is None:
            group_map = MarlGroupMapType.ONE_GROUP_PER_AGENT.get_group_map(agent_names)
        elif isinstance(group_map, MarlGroupMapType):
            group_map = group_map.get_group_map(agent_names)
        check_marl_grouping(group_map, agent_names)
        return group_map

    def _make_group_specs(
        self,
        env,
        group: str,
    ):
        observation_specs = []
        action_specs = []
        reward_specs = []
        game = env.get_game()

        for _ in self.group_map[group]:
            observation_spec = Composite()

            if self.has_observation:
                observation_spec["observation"] = Unbounded(
                    shape=(*game.observation_tensor_shape(),),
                    device=self.device,
                    domain="continuous",
                )

            if self.has_information_state:
                observation_spec["information_state"] = Unbounded(
                    shape=(*game.information_state_tensor_shape(),),
                    device=self.device,
                    domain="continuous",
                )

            observation_specs.append(observation_spec)

            action_spec_cls = Categorical if self.categorical_actions else OneHot
            action_specs.append(
                Composite(
                    action=action_spec_cls(
                        env.num_distinct_actions(),
                        dtype=torch.int64,
                        device=self.device,
                    )
                )
            )

            reward_specs.append(
                Composite(
                    reward=Unbounded(
                        shape=(1,),
                        device=self.device,
                        domain="continuous",
                    )
                )
            )

        group_observation_spec = torch.stack(
            observation_specs, dim=0
        )  # shape = (n_agents, n_obser_per_agent)
        group_action_spec = torch.stack(
            action_specs, dim=0
        )  # shape = (n_agents, n_actions_per_agent)
        group_reward_spec = torch.stack(reward_specs, dim=0)  # shape = (n_agents, 1)

        return (
            group_observation_spec,
            group_action_spec,
            group_reward_spec,
        )

    def _make_specs(self, env: pyspiel.State) -> None:  # noqa: F821
        self.agent_names = [f"player_{index}" for index in range(env.num_players())]
        self.agent_names_to_indices_map = {
            agent_name: i for i, agent_name in enumerate(self.agent_names)
        }
        self.group_map = self._make_group_map(self.group_map, self.agent_names)
        self.done_spec = Categorical(
            n=2,
            shape=torch.Size((1,)),
            dtype=torch.bool,
            device=self.device,
        )
        game = env.get_game()
        game_type = game.get_type()
        # In OpenSpiel, a game's state may have either an "observation" tensor,
        # an "information state" tensor, or both. If the OpenSpiel game does not
        # have one of these, then its corresponding accessor functions raise an
        # error, so we must avoid calling them.
        self.has_observation = game_type.provides_observation_tensor
        self.has_information_state = game_type.provides_information_state_tensor

        observation_spec = {}
        action_spec = {}
        reward_spec = {}

        for group in self.group_map.keys():
            (
                group_observation_spec,
                group_action_spec,
                group_reward_spec,
            ) = self._make_group_specs(
                env,
                group,
            )
            observation_spec[group] = group_observation_spec
            action_spec[group] = group_action_spec
            reward_spec[group] = group_reward_spec

        if self.return_state:
            observation_spec["state"] = NonTensor([])

        observation_spec["current_player"] = Unbounded(
            shape=(),
            dtype=torch.int,
            device=self.device,
            domain="discrete",
        )

        self.observation_spec = Composite(observation_spec)
        self.action_spec = Composite(action_spec)
        self.reward_spec = Composite(reward_spec)

    def _set_seed(self, seed: int | None) -> None:
        if seed is not None:
            raise NotImplementedError("This environment has no seed.")

    def current_player(self):
        return self._env.current_player()

    def _update_action_mask(self):
        if self._env.is_terminal():
            agents_acting = []
        else:
            agents_acting = [
                self.agent_names
                if self.parallel
                else self.agent_names[self._env.current_player()]
            ]
        for group, agents in self.group_map.items():
            action_masks = []
            for agent in agents:
                agent_index = self.agent_names_to_indices_map[agent]
                if agent in agents_acting:
                    action_mask = torch.zeros(
                        self._env.num_distinct_actions(),
                        device=self.device,
                        dtype=torch.bool,
                    )
                    action_mask[self._env.legal_actions(agent_index)] = True
                else:
                    action_mask = torch.zeros(
                        self._env.num_distinct_actions(),
                        device=self.device,
                        dtype=torch.bool,
                    )
                    # In OpenSpiel parallel games, non-acting players are
                    # expected to take action 0.
                    # https://openspiel.readthedocs.io/en/latest/api_reference/state_apply_action.html
                    action_mask[0] = True
                action_masks.append(action_mask)
            self.full_action_spec[group, "action"].update_mask(
                torch.stack(action_masks, dim=0)
            )

    def _make_td_out(self, exclude_reward=False):
        done = torch.tensor(
            self._env.is_terminal(), device=self.device, dtype=torch.bool
        )
        current_player = torch.tensor(
            self.current_player(), device=self.device, dtype=torch.int
        )

        source = {
            "done": done,
            "terminated": done.clone(),
            "current_player": current_player,
        }

        if self.return_state:
            source["state"] = self._env.serialize()

        reward = self._env.returns()

        for group, agent_names in self.group_map.items():
            agent_tds = []

            for agent in agent_names:
                agent_index = self.agent_names_to_indices_map[agent]
                agent_source = {}
                if self.has_observation:
                    observation_shape = self._get_game().observation_tensor_shape()
                    agent_source["observation"] = self._to_tensor(
                        self._env.observation_tensor(agent_index)
                    ).reshape(observation_shape)

                if self.has_information_state:
                    information_state_shape = (
                        self._get_game().information_state_tensor_shape()
                    )
                    agent_source["information_state"] = self._to_tensor(
                        self._env.information_state_tensor(agent_index)
                    ).reshape(information_state_shape)

                if not exclude_reward:
                    agent_source["reward"] = self._to_tensor(reward[agent_index])

                agent_td = TensorDict(
                    source=agent_source,
                    batch_size=self.batch_size,
                    device=self.device,
                )
                agent_tds.append(agent_td)

            source[group] = torch.stack(agent_tds, dim=0)

        tensordict_out = TensorDict(
            source=source,
            batch_size=self.batch_size,
            device=self.device,
        )

        return tensordict_out

    def _get_action_from_tensor(self, tensor):
        if not self.categorical_actions:
            action = torch.argmax(tensor, dim=-1)
        else:
            action = tensor
        return action

    def _step_parallel(self, tensordict: TensorDictBase):
        actions = [0] * self._env.num_players()
        for group, agents in self.group_map.items():
            for index_in_group, agent in enumerate(agents):
                agent_index = self.agent_names_to_indices_map[agent]
                action_tensor = tensordict[group, "action"][index_in_group]
                action = self._get_action_from_tensor(action_tensor)
                actions[agent_index] = action

        self._env.apply_actions(actions)

    def _step_sequential(self, tensordict: TensorDictBase):
        agent_index = self._env.current_player()

        # If the game has ended, do nothing
        if agent_index == self.lib.PlayerId.TERMINAL:
            return

        agent = self.agent_names[agent_index]
        agent_group = None
        agent_index_in_group = None

        for group, agents in self.group_map.items():
            if agent in agents:
                agent_group = group
                agent_index_in_group = agents.index(agent)
                break

        action_tensor = tensordict[agent_group, "action"][agent_index_in_group]
        action = self._get_action_from_tensor(action_tensor)
        self._env.apply_action(action)

    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        if self.parallel:
            self._step_parallel(tensordict)
        else:
            self._step_sequential(tensordict)

        self._update_action_mask()
        return self._make_td_out()

    def _to_tensor(self, value):
        return torch.tensor(value, device=self.device, dtype=torch.float32)

    def _reset(
        self, tensordict: TensorDictBase | None = None, **kwargs
    ) -> TensorDictBase:
        game = self._get_game()

        if tensordict is not None and "state" in tensordict:
            new_env = game.deserialize_state(tensordict["state"])
        else:
            new_env = game.new_initial_state()

        self._env = new_env
        self._update_action_mask()
        return self._make_td_out(exclude_reward=True)


class OpenSpielEnv(OpenSpielWrapper):
    """Google DeepMind OpenSpiel environment wrapper built with the game string.

    GitHub: https://github.com/google-deepmind/open_spiel

    Documentation: https://openspiel.readthedocs.io/en/latest/index.html

    Args:
        game_string (str): the name of the game to wrap. Must be part of
            :attr:`~.available_envs`.

    Keyword Args:
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``None``.
        batch_size (torch.Size, optional): the batch size of the environment.
            Defaults to ``torch.Size([])``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`reset` is called.
            Defaults to ``False``.
        group_map (MarlGroupMapType or Dict[str, List[str]]], optional): how to
            group agents in tensordicts for input/output. See
            :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.
            Defaults to
            :class:`~torchrl.envs.utils.MarlGroupMapType.ALL_IN_ONE_GROUP`.
        categorical_actions (bool, optional): if ``True``, categorical specs
            will be converted to the TorchRL equivalent
            (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
            will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``.
        return_state (bool, optional): if ``True``, "state" is included in the
            output of :meth:`reset` and :meth:`~step`. The state can be given
            to :meth:`reset` to reset to that state, rather than resetting to
            the initial state.
            Defaults to ``False``.

    Attributes:
        available_envs: environments available to build

    Examples:
        >>> from torchrl.envs import OpenSpielEnv
        >>> from tensordict import TensorDict
        >>> env = OpenSpielEnv("chess", return_state=True)
        >>> td = env.reset()
        >>> td = env.step(env.full_action_spec.rand())
        >>> print(td)
        TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        action: Tensor(shape=torch.Size([2, 4672]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                next: TensorDict(
                    fields={
                        agents: TensorDict(
                            fields={
                                observation: Tensor(shape=torch.Size([2, 20, 8, 8]), device=cpu, dtype=torch.float32, is_shared=False),
                                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([2]),
                            device=None,
                            is_shared=False),
                        current_player: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        state: NonTensorData(data=FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1
        674
        , batch_size=torch.Size([]), device=None),
                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)
        >>> print(env.available_envs)
        ['2048', 'add_noise', 'amazons', 'backgammon', ...]

    :meth:`reset` can restore a specific state, rather than the initial state,
    as long as ``return_state=True``.

        >>> from torchrl.envs import OpenSpielEnv
        >>> from tensordict import TensorDict
        >>> env = OpenSpielEnv("chess", return_state=True)
        >>> td = env.reset()
        >>> td = env.step(env.full_action_spec.rand())
        >>> td_restore = td["next"]
        >>> td = env.step(env.full_action_spec.rand())
        >>> # Current state is not equal `td_restore`
        >>> (td["next"] == td_restore).all()
        False
        >>> td = env.reset(td_restore)
        >>> # After resetting, now the current state is equal to `td_restore`
        >>> (td == td_restore).all()
        True
    """

    def __init__(
        self,
        game_string,
        *,
        group_map: MarlGroupMapType
        | dict[str, list[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
        categorical_actions=False,
        return_state: bool = False,
        **kwargs,
    ):
        kwargs["game_string"] = game_string
        super().__init__(
            group_map=group_map,
            categorical_actions=categorical_actions,
            return_state=return_state,
            **kwargs,
        )

    def _build_env(
        self,
        game_string: str,
        **kwargs,
    ) -> pyspiel.State:  # noqa: F821
        if not _has_pyspiel:
            raise ImportError(
                f"open_spiel not found, unable to create {game_string}. Consider "
                f"downloading and installing open_spiel from {self.git_url}"
            )
        requires_grad = kwargs.pop("requires_grad", False)
        parameters = kwargs.pop("parameters", None)
        if kwargs:
            raise ValueError("kwargs not supported.")

        if parameters:
            game = self.lib.load_game(game_string, parameters=parameters)
        else:
            game = self.lib.load_game(game_string)

        env = game.new_initial_state()
        return super()._build_env(
            env,
            requires_grad=requires_grad,
        )

    @property
    def game_string(self):
        return self._constructor_kwargs["game_string"]

    def _check_kwargs(self, kwargs: dict):
        if "game_string" not in kwargs:
            raise TypeError("Expected 'game_string' to be part of kwargs")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(env={self.game_string}, batch_size={self.batch_size}, device={self.device})"