torchrl-0.11.0-cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0

torchrl/envs/custom/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .chess import ChessEnv
+from .llm import LLMHashingEnv
+from .pendulum import PendulumEnv
+from .tictactoeenv import TicTacToeEnv
+
+__all__ = ["ChessEnv", "LLMHashingEnv", "PendulumEnv", "TicTacToeEnv"]

torchrl/envs/custom/chess.py
@@ -0,0 +1,617 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import importlib.util
+import io
+import pathlib
+
+import torch
+from tensordict import TensorDict, TensorDictBase
+from torchrl.data.tensor_specs import (
+    Binary,
+    Bounded,
+    Categorical,
+    Composite,
+    NonTensor,
+    Unbounded,
+)
+from torchrl.envs import EnvBase
+from torchrl.envs.common import _EnvPostInit
+from torchrl.envs.utils import _classproperty
+
+
+class _ChessMeta(_EnvPostInit):
+    def __call__(cls, *args, **kwargs):
+        instance = super().__call__(*args, **kwargs)
+        include_hash = kwargs.get("include_hash")
+        include_hash_inv = kwargs.get("include_hash_inv")
+        if include_hash:
+            from torchrl.envs import Hash
+
+            in_keys = []
+            out_keys = []
+            in_keys_inv = [] if include_hash_inv else None
+            out_keys_inv = [] if include_hash_inv else None
+
+            def maybe_add_keys(condition, in_key, out_key):
+                if condition:
+                    in_keys.append(in_key)
+                    out_keys.append(out_key)
+                    if include_hash_inv:
+                        in_keys_inv.append(in_key)
+                        out_keys_inv.append(out_key)
+
+            maybe_add_keys(instance.include_san, "san", "san_hash")
+            maybe_add_keys(instance.include_fen, "fen", "fen_hash")
+            maybe_add_keys(instance.include_pgn, "pgn", "pgn_hash")
+
+            instance = instance.append_transform(
+                Hash(in_keys, out_keys, in_keys_inv, out_keys_inv)
+            )
+        elif include_hash_inv:
+            raise ValueError(
+                "'include_hash_inv=True' can only be set if"
+                f"'include_hash=True', but got 'include_hash={include_hash}'."
+            )
+        if kwargs.get("mask_actions", True):
+            from torchrl.envs import ActionMask
+
+            instance = instance.append_transform(ActionMask())
+        return instance
+
+
+class ChessEnv(EnvBase, metaclass=_ChessMeta):
+    r"""A chess environment that follows the TorchRL API.
+
+    This environment simulates a chess game using the `chess` library. It supports various state representations
+    and can be configured to include different types of observations such as SAN, FEN, PGN, and legal moves.
+
+    Requires: the `chess` library. More info `here <https://python-chess.readthedocs.io/en/latest/>`__.
+
+    Args:
+        stateful (bool): Whether to keep track of the internal state of the board.
+            If False, the state will be stored in the observation and passed back
+            to the environment on each call. Default: ``True``.
+        include_san (bool): Whether to include SAN (Standard Algebraic Notation) in the observations. Default: ``False``.
+            The ``"san"`` entry corresponding to ``rollout["action"]`` will be found in ``rollout["next", "san"]``,
+            whereas the value at the root ``rollout["san"]`` will correspond to the value of the san preceding the
+            same index action.
+        include_fen (bool): Whether to include FEN (Forsyth-Edwards Notation) in the observations. Default: ``False``.
+        include_pgn (bool): Whether to include PGN (Portable Game Notation) in the observations. Default: ``False``.
+        include_legal_moves (bool): Whether to include legal moves in the observations. Default: ``False``.
+        include_hash (bool): Whether to include hash transformations in the environment. Default: ``False``.
+        mask_actions (bool): if ``True``, a :class:`~torchrl.envs.ActionMask` transform will be appended
+            to the env to make sure that the actions are properly masked. Default: ``True``.
+        pixels (bool): Whether to include pixel-based observations of the board. Default: ``False``.
+
+    .. note::
+        The action spec is a :class:`~torchrl.data.Categorical` with a number of actions equal to the number of possible SAN moves.
+        The action space is structured as a categorical distribution over all possible SAN moves, with the legal moves
+        being a subset of this space. The environment uses a mask to ensure only legal moves are selected.
+
+    Examples:
+        >>> import torch
+        >>> from torchrl.envs import ChessEnv
+        >>> _ = torch.manual_seed(0)
+        >>> env = ChessEnv(include_fen=True, include_san=True, include_pgn=True, include_legal_moves=True)
+        >>> print(env)
+        TransformedEnv(
+            env=ChessEnv(),
+            transform=ActionMask(keys=['action', 'action_mask']))
+        >>> r = env.reset()
+        >>> print(env.rand_step(r))
+        TensorDict(
+            fields={
+                action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
+                action_mask: Tensor(shape=torch.Size([29275]), device=cpu, dtype=torch.bool, is_shared=False),
+                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1, batch_size=torch.Size([]), device=None),
+                legal_moves: Tensor(shape=torch.Size([219]), device=cpu, dtype=torch.int64, is_shared=False),
+                next: TensorDict(
+                    fields={
+                        action_mask: Tensor(shape=torch.Size([29275]), device=cpu, dtype=torch.bool, is_shared=False),
+                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                        fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/5P2/8/PPPPP1PP/RNBQKBNR b KQkq - 0 1, batch_size=torch.Size([]), device=None),
+                        legal_moves: Tensor(shape=torch.Size([219]), device=cpu, dtype=torch.int64, is_shared=False),
+                        pgn: NonTensorData(data=[Event "?"]
+        [Site "?"]
+        [Date "????.??.??"]
+        [Round "?"]
+        [White "?"]
+        [Black "?"]
+        [Result "*"]
+
+        1. f4 *, batch_size=torch.Size([]), device=None),
+                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                        san: NonTensorData(data=f4, batch_size=torch.Size([]), device=None),
+                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                        turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
+                    batch_size=torch.Size([]),
+                    device=None,
+                    is_shared=False),
+                pgn: NonTensorData(data=[Event "?"]
+        [Site "?"]
+        [Date "????.??.??"]
+        [Round "?"]
+        [White "?"]
+        [Black "?"]
+        [Result "*"]
+
+        *, batch_size=torch.Size([]), device=None),
+                san: NonTensorData(data=<start>, batch_size=torch.Size([]), device=None),
+                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
+                turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
+            batch_size=torch.Size([]),
+            device=None,
+            is_shared=False)
+        >>> print(env.rollout(1000))
+        TensorDict(
+            fields={
+                action: Tensor(shape=torch.Size([96]), device=cpu, dtype=torch.int64, is_shared=False),
+                action_mask: Tensor(shape=torch.Size([96, 29275]), device=cpu, dtype=torch.bool, is_shared=False),
+                done: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                fen: NonTensorStack(
+                    ['rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQ...,
+                    batch_size=torch.Size([96]),
+                    device=None),
+                legal_moves: Tensor(shape=torch.Size([96, 219]), device=cpu, dtype=torch.int64, is_shared=False),
+                next: TensorDict(
+                    fields={
+                        action_mask: Tensor(shape=torch.Size([96, 29275]), device=cpu, dtype=torch.bool, is_shared=False),
+                        done: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                        fen: NonTensorStack(
+                            ['rnbqkbnr/pppppppp/8/8/8/5N2/PPPPPPPP/RNBQKB1R b ...,
+                            batch_size=torch.Size([96]),
+                            device=None),
+                        legal_moves: Tensor(shape=torch.Size([96, 219]), device=cpu, dtype=torch.int64, is_shared=False),
+                        pgn: NonTensorStack(
+                            ['[Event "?"]\n[Site "?"]\n[Date "????.??.??"]\n[R...,
+                            batch_size=torch.Size([96]),
+                            device=None),
+                        reward: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                        san: NonTensorStack(
+                            ['Nf3', 'Na6', 'c4', 'f6', 'h4', 'Rb8', 'Na3', 'Ra...,
+                            batch_size=torch.Size([96]),
+                            device=None),
+                        terminated: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                        turn: Tensor(shape=torch.Size([96]), device=cpu, dtype=torch.bool, is_shared=False)},
+                    batch_size=torch.Size([96]),
+                    device=None,
+                    is_shared=False),
+                pgn: NonTensorStack(
+                    ['[Event "?"]\n[Site "?"]\n[Date "????.??.??"]\n[R...,
+                    batch_size=torch.Size([96]),
+                    device=None),
+                san: NonTensorStack(
+                    ['<start>', 'Nf3', 'Na6', 'c4', 'f6', 'h4', 'Rb8',...,
+                    batch_size=torch.Size([96]),
+                    device=None),
+                terminated: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                turn: Tensor(shape=torch.Size([96]), device=cpu, dtype=torch.bool, is_shared=False)},
+            batch_size=torch.Size([96]),
+            device=None,
+            is_shared=False)
+    """  # noqa: D301
+
+    _hash_table: dict[int, str] = {}
+    _PGN_RESTART = """[Event "?"]
+[Site "?"]
+[Date "????.??.??"]
+[Round "?"]
+[White "?"]
+[Black "?"]
+[Result "*"]
+
+*"""
+
+    @_classproperty
+    def lib(cls):
+        try:
+            import chess
+            import chess.pgn
+        except ImportError:
+            raise ImportError(
+                "The `chess` library could not be found. Make sure you installed it through `pip install chess`."
+            )
+        return chess
+
+    _san_moves = []
+
+    @_classproperty
+    def san_moves(cls):
+        if not cls._san_moves:
+            with open(pathlib.Path(__file__).parent / "san_moves.txt", "r+") as f:
+                cls._san_moves.extend(f.read().split("\n"))
+        return cls._san_moves
+
+    def _legal_moves_to_index(
+        self,
+        tensordict: TensorDictBase | None = None,
+        board: chess.Board | None = None,  # noqa: F821
+        return_mask: bool = False,
+        pad: bool = False,
+    ) -> torch.Tensor:
+        if not self.stateful:
+            if tensordict is None:
+                # trust the board
+                pass
+            elif self.include_fen:
+                fen = tensordict.get("fen", None)
+                fen = fen.data
+                self.board.set_fen(fen)
+                board = self.board
+            elif self.include_pgn:
+                pgn = tensordict.get("pgn")
+                pgn = pgn.data
+                board = self._pgn_to_board(pgn, self.board)
+
+        if board is None:
+            board = self.board
+
+        indices = torch.tensor(
+            [self._san_moves.index(board.san(m)) for m in board.legal_moves],
+            dtype=torch.int64,
+        )
+        mask = None
+        if return_mask:
+            mask = self._move_index_to_mask(indices)
+        if pad:
+            indices = torch.nn.functional.pad(
+                indices, [0, 218 - indices.numel() + 1], value=len(self.san_moves)
+            )
+        if return_mask:
+            return indices, mask
+        return indices
+
+    @classmethod
+    def _move_index_to_mask(cls, indices: torch.Tensor) -> torch.Tensor:
+        return torch.zeros(len(cls.san_moves), dtype=torch.bool).index_fill_(
+            0, indices, True
+        )
+
+    def __init__(
+        self,
+        *,
+        stateful: bool = True,
+        include_san: bool = False,
+        include_fen: bool = False,
+        include_pgn: bool = False,
+        include_legal_moves: bool = False,
+        include_hash: bool = False,
+        include_hash_inv: bool = False,
+        mask_actions: bool = True,
+        pixels: bool = False,
+    ):
+        chess = self.lib
+        super().__init__()
+        self.full_observation_spec = Composite(
+            turn=Categorical(n=2, dtype=torch.bool, shape=()),
+        )
+        self.include_san = include_san
+        self.include_fen = include_fen
+        self.include_pgn = include_pgn
+        self.mask_actions = mask_actions
+        self.include_legal_moves = include_legal_moves
+        if include_legal_moves:
+            # 218 max possible legal moves per chess board position
+            # https://www.stmintz.com/ccc/index.php?id=424966
+            # len(self.san_moves)+1 is the padding value
+            self.full_observation_spec["legal_moves"] = Bounded(
+                0, 1 + len(self.san_moves), shape=(218,), dtype=torch.int64
+            )
+        if include_san:
+            self.full_observation_spec["san"] = NonTensor(shape=(), example_data="Nc6")
+        if include_pgn:
+            self.full_observation_spec["pgn"] = NonTensor(
+                shape=(), example_data=self._PGN_RESTART
+            )
+        if include_fen:
+            self.full_observation_spec["fen"] = NonTensor(shape=(), example_data="any")
+        if not stateful and not (include_pgn or include_fen):
+            raise RuntimeError(
+                "At least one state representation (pgn or fen) must be enabled when stateful "
+                f"is {stateful}."
+            )
+
+        self.stateful = stateful
+
+        # state_spec is loosely defined as such - it's not really an issue that extra keys
+        # can go missing but it allows us to reset the env using fen passed to the reset
+        # method.
+        self.full_state_spec = self.full_observation_spec.clone()
+
+        self.pixels = pixels
+        if pixels:
+            if importlib.util.find_spec("cairosvg") is None:
+                raise ImportError(
+                    "Please install cairosvg to use this environment with pixel rendering."
+                )
+            if importlib.util.find_spec("torchvision") is None:
+                raise ImportError(
+                    "Please install torchvision to use this environment with pixel rendering."
+                )
+            self.full_observation_spec["pixels"] = Unbounded(
+                shape=(3, 390, 390), dtype=torch.uint8
+            )
+
+        self.full_action_spec = Composite(
+            action=Categorical(n=len(self.san_moves), shape=(), dtype=torch.int64)
+        )
+        self.full_reward_spec = Composite(
+            reward=Unbounded(shape=(1,), dtype=torch.float32)
+        )
+        if self.mask_actions:
+            self.full_observation_spec["action_mask"] = Binary(
+                n=len(self.san_moves), dtype=torch.bool
+            )
+
+        # done spec generated automatically
+        self.board = chess.Board()
+        if self.stateful:
+            self.action_spec.set_provisional_n(len(list(self.board.legal_moves)))
+
+    def _is_done(self, board):
+        return board.is_game_over() | board.is_fifty_moves()
+
+    def all_actions(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
+        if not self.mask_actions:
+            raise RuntimeError(
+                "Cannot generate legal actions since 'mask_actions=False' was "
+                "set. If you really want to generate all actions, not just "
+                "legal ones, call 'env.full_action_spec.enumerate()'."
+            )
+        return super().all_actions(tensordict)
+
+    def _reset(self, tensordict=None):
+        fen = None
+        pgn = None
+        if tensordict is not None:
+            dest = tensordict.empty()
+            if self.include_fen:
+                fen = tensordict.get("fen", None)
+                if fen is not None:
+                    fen = fen.data
+            elif self.include_pgn:
+                pgn = tensordict.get("pgn", None)
+                if pgn is not None:
+                    pgn = pgn.data
+        else:
+            dest = TensorDict()
+
+        if fen is None and pgn is None:
+            self.board.reset()
+        elif fen is not None:
+            self.board.set_fen(fen)
+            if self._is_done(self.board):
+                raise ValueError(
+                    "Cannot reset to a fen that is a gameover state." f" fen: {fen}"
+                )
+        elif pgn is not None:
+            self.board = self._pgn_to_board(pgn)
+
+        if self.include_fen and fen is None:
+            fen = self.board.fen()
+        if self.include_pgn and pgn is None:
+            pgn = self._board_to_pgn(self.board)
+
+        turn = self.board.turn
+        if self.include_san:
+            if self.board.move_stack:
+                move = self.board.peek()
+            else:
+                move = None
+            if move is None:
+                dest.set("san", "<start>")
+            else:
+                dest.set("san", self.board.san(move))
+        if self.include_fen:
+            dest.set("fen", fen)
+        if self.include_pgn:
+            dest.set("pgn", pgn)
+        dest.set("turn", turn)
+        if self.include_legal_moves:
+            moves_idx = self._legal_moves_to_index(
+                board=self.board, pad=True, return_mask=self.mask_actions
+            )
+            if self.mask_actions:
+                moves_idx, mask = moves_idx
+                dest.set("action_mask", mask)
+            dest.set("legal_moves", moves_idx)
+        elif self.mask_actions:
+            dest.set(
+                "action_mask",
+                self._legal_moves_to_index(
+                    board=self.board, pad=True, return_mask=True
+                )[1],
+            )
+
+        if self.pixels:
+            dest.set("pixels", self._get_tensor_image(board=self.board))
+        return dest
+
+    _cairosvg_lib = None
+
+    @_classproperty
+    def _cairosvg(cls):
+        csvg = cls._cairosvg_lib
+        if csvg is None:
+            import cairosvg
+
+            csvg = cls._cairosvg_lib = cairosvg
+        return csvg
+
+    _torchvision_lib = None
+
+    @_classproperty
+    def _torchvision(cls):
+        tv = cls._torchvision_lib
+        if tv is None:
+            import torchvision
+
+            tv = cls._torchvision_lib = torchvision
+        return tv
+
+    @classmethod
+    def _get_tensor_image(cls, board):
+        try:
+            from PIL import Image
+
+            svg = board._repr_svg_()
+            # Convert SVG to PNG using cairosvg
+            png_data = io.BytesIO()
+            cls._cairosvg.svg2png(bytestring=svg.encode("utf-8"), write_to=png_data)
+            png_data.seek(0)
+            # Open the PNG image using Pillow
+            img = Image.open(png_data)
+            img = cls._torchvision.transforms.functional.pil_to_tensor(img)
+        except ImportError:
+            raise ImportError(
+                "Chess rendering requires cairosvg, PIL and torchvision to be installed."
+            )
+        return img
+
+    @classmethod
+    def _pgn_to_board(
+        cls, pgn_string: str, board: chess.Board | None = None  # noqa: F821
+    ) -> chess.Board:  # noqa: F821
+        pgn_io = io.StringIO(pgn_string)
+        game = cls.lib.pgn.read_game(pgn_io)
+        if board is None:
+            board = cls.lib.Board()
+        else:
+            board.reset()
+        for move in game.mainline_moves():
+            board.push(move)
+        return board
+
+    @classmethod
+    def _add_move_to_pgn(cls, pgn_string: str, move: chess.Move) -> str:  # noqa: F821
+        pgn_io = io.StringIO(pgn_string)
+        game = cls.lib.pgn.read_game(pgn_io)
+        if game is None:
+            raise ValueError("Invalid PGN string")
+        game.end().add_variation(move)
+        return str(game)
+
+    @classmethod
+    def _board_to_pgn(cls, board: chess.Board) -> str:  # noqa: F821
+        game = cls.lib.pgn.Game.from_board(board)
+        pgn_string = str(game)
+        return pgn_string
+
+    def get_legal_moves(self, tensordict=None, uci=False):
+        """List the legal moves in a position.
+
+        To choose one of the actions, the "action" key can be set to the index
+        of the move in this list.
+
+        Args:
+            tensordict (TensorDict, optional): Tensordict containing the fen
+                string of a position. Required if not stateful. If stateful,
+                this argument is ignored and the current state of the env is
+                used instead.
+
+            uci (bool, optional): If ``False``, moves are given in SAN format.
+                If ``True``, moves are given in UCI format. Default is
+                ``False``.
+
+        """
+        board = self.board
+        if not self.stateful:
+            if tensordict is None:
+                raise ValueError(
+                    "tensordict must be given since this env is not stateful"
+                )
+            fen = tensordict.get("fen").data
+            board.set_fen(fen)
+        moves = board.legal_moves
+
+        if uci:
+            return [board.uci(move) for move in moves]
+        else:
+            return [board.san(move) for move in moves]
+
+    def _step(self, tensordict):
+        # action
+        action = tensordict.get("action")
+        board = self.board
+
+        pgn = None
+        fen = None
+        if not self.stateful:
+            if self.include_fen:
+                fen = tensordict.get("fen").data
+                board.set_fen(fen)
+            elif self.include_pgn:
+                pgn = tensordict.get("pgn").data
+                board = self._pgn_to_board(pgn, board)
+            else:
+                raise RuntimeError(
+                    "Not enough information to deduce the board. If stateful=False, include_pgn or include_fen must be True."
+                )
+
+        san = self.san_moves[action]
+        board.push_san(san)
+
+        dest = tensordict.empty()
+
+        # Collect data
+        if self.include_fen:
+            fen = board.fen()
+            dest.set("fen", fen)
+
+        if self.include_pgn:
+            if pgn is not None:
+                pgn = self._add_move_to_pgn(pgn, board.move_stack[-1])
+            else:
+                pgn = self._board_to_pgn(board)
+            dest.set("pgn", pgn)
+
+        if self.include_san:
+            dest.set("san", san)
+
+        if self.include_legal_moves:
+            moves_idx = self._legal_moves_to_index(
+                board=board, pad=True, return_mask=self.mask_actions
+            )
+            if self.mask_actions:
+                moves_idx, mask = moves_idx
+                dest.set("action_mask", mask)
+            dest.set("legal_moves", moves_idx)
+        elif self.mask_actions:
+            dest.set(
+                "action_mask",
+                self._legal_moves_to_index(
+                    board=self.board, pad=True, return_mask=True
+                )[1],
+            )
+
+        turn = torch.tensor(board.turn)
+        done = self._is_done(board)
+        if board.is_checkmate():
+            # turn flips after every move, even if the game is over
+            # winner = not turn
+            reward_val = 1  # if winner == self.lib.WHITE else 0
+        elif done:
+            reward_val = 0.5
+        else:
+            reward_val = 0.0
+
+        reward = torch.tensor([reward_val], dtype=torch.float32)
+        dest.set("reward", reward)
+        dest.set("turn", turn)
+        dest.set("done", torch.tensor([done]))
+        dest.set("terminated", torch.tensor([done]))
+        if self.pixels:
+            dest.set("pixels", self._get_tensor_image(board=self.board))
+        return dest
+
+    def _set_seed(self, *args, **kwargs) -> None:
+        ...
+
+    def cardinality(self, tensordict: TensorDictBase | None = None) -> int:
+        self._set_action_space(tensordict)
+        return self.action_spec.cardinality()