torchrl-0.11.0-cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import contextlib
import os

import torch


@contextlib.contextmanager
def _cuda_visible_devices(devices: list[torch.device | int]):
    devices = [torch.device(d).index if not isinstance(d, int) else d for d in devices]
    CUDA_VISIBLE_DEVICES = os.getenv("CUDA_VISIBLE_DEVICES")
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, devices))
    yield
    if CUDA_VISIBLE_DEVICES:
        os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES
    else:
        os.unsetenv("CUDA_VISIBLE_DEVICES")
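The context manager above temporarily rewrites `CUDA_VISIBLE_DEVICES` for the duration of a `with` block and then restores the previous value, or unsets the variable if it was not set before. The sketch below is not part of the packaged source; it assumes the helper defined in the hunk above is importable, and the device indices are illustrative only. Note that a process which has already initialized CUDA will not re-read the variable, so a helper like this is mainly useful before spawning workers or before CUDA initialization.

```python
import os

import torch

# Sketch only: `_cuda_visible_devices` is the helper shown in the hunk above.
# Restricting visibility to GPUs 1 and 3 affects subprocesses and CUDA
# contexts that are initialized while the block is active.
with _cuda_visible_devices([torch.device("cuda:1"), 3]):
    print(os.environ["CUDA_VISIBLE_DEVICES"])  # -> "1,3"
# On exit, the previous CUDA_VISIBLE_DEVICES value is restored, or the
# variable is unset if it did not exist before entering the block.
```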
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .scores import (
    EXP3Score,
    MCTSScore,
    MCTSScores,
    PUCTScore,
    UCB1TunedScore,
    UCBScore,
)

__all__ = [
    "EXP3Score",
    "MCTSScore",
    "MCTSScores",
    "PUCTScore",
    "UCB1TunedScore",
    "UCBScore",
]
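Given the matching line count in the listing above (`torchrl/modules/mcts/__init__.py +21 -0`), this hunk is presumably the MCTS package `__init__.py`: it re-exports the score classes so they can be imported from the package root instead of the `scores` submodule. A small sketch under that assumption:

```python
# Both import paths resolve to the same class once the re-export above is in place.
from torchrl.modules.mcts import UCBScore
from torchrl.modules.mcts.scores import UCBScore as UCBScoreDirect

assert UCBScore is UCBScoreDirect
```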
@@ -0,0 +1,579 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import functools
import math
import warnings
from abc import abstractmethod
from enum import Enum

import torch

from tensordict import NestedKey, TensorDictBase
from tensordict.nn import TensorDictModuleBase


class MCTSScore(TensorDictModuleBase):
    """Abstract base class for MCTS score computation modules."""

    @abstractmethod
    def forward(self, node: TensorDictBase) -> TensorDictBase:
        pass


class PUCTScore(MCTSScore):
    """Computes the PUCT (Polynomial Upper Confidence Trees) score for MCTS.

    PUCT is a widely used score in MCTS algorithms, notably in AlphaGo and AlphaZero,
    to balance exploration and exploitation. It incorporates prior probabilities from a
    policy network, encouraging exploration of actions deemed promising by the policy,
    while also considering visit counts and accumulated rewards.

    The formula used is:
    `score = (win_count / visits) + c * prior_prob * sqrt(total_visits) / (1 + visits)`

    Where:
    - `win_count`: Sum of rewards (or win counts) for the action.
    - `visits`: Visit count for the action.
    - `total_visits`: Visit count of the parent node (N).
    - `prior_prob`: Prior probability of selecting the action (e.g., from a policy network).
    - `c`: The exploration constant, controlling the trade-off between exploitation
      (first term) and exploration (second term).

    Args:
        c (float): The exploration constant.
        win_count_key (NestedKey, optional): Key for the tensor in the input `TensorDictBase`
            containing the sum of rewards (or win counts) for each action.
            Defaults to "win_count".
        visits_key (NestedKey, optional): Key for the tensor containing the visit
            count for each action. Defaults to "visits".
        total_visits_key (NestedKey, optional): Key for the tensor (or scalar)
            representing the visit count of the parent node (N). Defaults to "total_visits".
        prior_prob_key (NestedKey, optional): Key for the tensor containing the
            prior probabilities for each action. Defaults to "prior_prob".
        score_key (NestedKey, optional): Key where the calculated PUCT scores
            will be stored in the output `TensorDictBase`. Defaults to "score".

    Input Keys:
        - `win_count_key` (torch.Tensor): Tensor of shape (..., num_actions)
          or matching `visits_key`.
        - `visits_key` (torch.Tensor): Tensor of shape (..., num_actions). If an action
          has zero visits, its exploitation term (win_count / visits) will result in NaN
          if win_count is also zero, or +/-inf if win_count is non-zero. The exploration
          term will still be valid due to `(1 + visits)`.
        - `total_visits_key` (torch.Tensor): Scalar or tensor broadcastable to other inputs,
          representing the parent node's visit count.
        - `prior_prob_key` (torch.Tensor): Tensor of shape (..., num_actions) containing
          prior probabilities.

    Output Keys:
        - `score_key` (torch.Tensor): Tensor of the same shape as `visits_key`, containing
          the calculated PUCT scores.

    Example:
        ```python
        from tensordict import TensorDict
        from torchrl.modules.mcts.scores import PUCTScore

        # Create a PUCTScore instance
        puct = PUCTScore(c=1.5)

        # Define a TensorDict with required keys
        node = TensorDict(
            {
                "win_count": torch.tensor([10.0, 20.0]),
                "visits": torch.tensor([5.0, 10.0]),
                "total_visits": torch.tensor(50.0),
                "prior_prob": torch.tensor([0.6, 0.4]),
            },
            batch_size=[],
        )

        # Compute the PUCT scores
        result = puct(node)
        print(result["score"])  # Output: Tensor with PUCT scores
        ```
    """

    c: float

    def __init__(
        self,
        *,
        c: float,
        win_count_key: NestedKey = "win_count",
        visits_key: NestedKey = "visits",
        total_visits_key: NestedKey = "total_visits",
        prior_prob_key: NestedKey = "prior_prob",
        score_key: NestedKey = "score",
    ):
        super().__init__()
        self.c = c
        self.win_count_key = win_count_key
        self.visits_key = visits_key
        self.total_visits_key = total_visits_key
        self.prior_prob_key = prior_prob_key
        self.score_key = score_key
        self.in_keys = [
            self.win_count_key,
            self.prior_prob_key,
            self.total_visits_key,
            self.visits_key,
        ]
        self.out_keys = [self.score_key]

    def forward(self, node: TensorDictBase) -> TensorDictBase:
        win_count = node.get(self.win_count_key)
        visits = node.get(self.visits_key)
        n_total = node.get(self.total_visits_key)
        prior_prob = node.get(self.prior_prob_key)
        # Handle broadcasting for batched inputs
        if n_total.ndim > 0 and n_total.ndim < visits.ndim:
            n_total = n_total.unsqueeze(-1)
        node.set(
            self.score_key,
            (win_count / visits) + self.c * prior_prob * n_total.sqrt() / (1 + visits),
        )
        return node


class UCBScore(MCTSScore):
    """Computes the UCB (Upper Confidence Bound) score, specifically UCB1, for MCTS.

    UCB1 is a classic algorithm for the multi-armed bandit problem that balances
    exploration and exploitation. In MCTS, it's used to select which action to
    explore from a given node. The score encourages trying actions with high
    empirical rewards and actions that have been visited less frequently.

    The formula used is:
    `score = (win_count / visits) + c * sqrt(total_visits) / (1 + visits)`

    Args:
        c (float): The exploration constant. A common value is `sqrt(2)`.
        win_count_key (NestedKey, optional): Key for the tensor in the input `TensorDictBase`
            containing the sum of rewards (or win counts) for each action.
            Defaults to "win_count".
        visits_key (NestedKey, optional): Key for the tensor containing the visit
            count for each action. Defaults to "visits".
        total_visits_key (NestedKey, optional): Key for the tensor (or scalar)
            representing the visit count of the parent node (N). This is used in the
            exploration term. Defaults to "total_visits".
        score_key (NestedKey, optional): Key where the calculated UCB scores
            will be stored in the output `TensorDictBase`. Defaults to "score".

    Input Keys:
        - `win_count_key` (torch.Tensor): Tensor of shape (..., num_actions).
        - `visits_key` (torch.Tensor): Tensor of shape (..., num_actions).
        - `total_visits_key` (torch.Tensor): Scalar or tensor broadcastable to other inputs.

    Output Keys:
        - `score_key` (torch.Tensor): Tensor of the same shape as `visits_key`, containing
          the calculated UCB scores.

    Example:
        ```python
        from tensordict import TensorDict
        from torchrl.modules.mcts.scores import UCBScore

        # Create a UCBScore instance
        ucb = UCBScore(c=1.414)

        # Define a TensorDict with required keys
        node = TensorDict(
            {
                "win_count": torch.tensor([15.0, 25.0]),
                "visits": torch.tensor([10.0, 20.0]),
                "total_visits": torch.tensor(100.0),
            },
            batch_size=[],
        )

        # Compute the UCB scores
        result = ucb(node)
        print(result["score"])  # Output: Tensor with UCB scores
        ```
    """

    c: float

    def __init__(
        self,
        *,
        c: float,
        win_count_key: NestedKey = "win_count",
        visits_key: NestedKey = "visits",
        total_visits_key: NestedKey = "total_visits",
        score_key: NestedKey = "score",
    ):
        super().__init__()
        self.c = c
        self.win_count_key = win_count_key
        self.visits_key = visits_key
        self.total_visits_key = total_visits_key
        self.score_key = score_key
        self.in_keys = [self.win_count_key, self.total_visits_key, self.visits_key]
        self.out_keys = [self.score_key]

    def forward(self, node: TensorDictBase) -> TensorDictBase:
        win_count = node.get(self.win_count_key)
        visits = node.get(self.visits_key)
        n_total = node.get(self.total_visits_key)
        # Handle broadcasting for batched inputs
        if n_total.ndim > 0 and n_total.ndim < visits.ndim:
            n_total = n_total.unsqueeze(-1)
        node.set(
            self.score_key,
            (win_count / visits) + self.c * n_total.sqrt() / (1 + visits),
        )
        return node


class EXP3Score(MCTSScore):
    """Computes action selection probabilities for the EXP3 algorithm in MCTS.

    EXP3 (Exponential-weight algorithm for Exploration and Exploitation) is a bandit
    algorithm that performs well in adversarial or non-stationary environments.
    It maintains weights for each action and adjusts them based on received rewards.

    Args:
        gamma (float, optional): Exploration factor, balancing uniform exploration
            and exploitation of current weights. Must be in [0, 1]. Defaults to 0.1.
        weights_key (NestedKey, optional): Key in the input `TensorDictBase` for
            the tensor containing current action weights. Defaults to "weights".
        action_prob_key (NestedKey, optional): Key to store the calculated action
            probabilities. Defaults to "action_prob".
        score_key (NestedKey, optional): Key where the calculated action probabilities
            will be stored. Defaults to "score".
        num_actions_key (NestedKey, optional): Key for the number of available
            actions (K). Defaults to "num_actions".

    Input Keys:
        - `weights_key` (torch.Tensor): Tensor of shape (..., num_actions).
        - `num_actions_key` (int or torch.Tensor): Scalar representing K, the number of actions.

    Output Keys:
        - `score_key` (torch.Tensor): Tensor of shape (..., num_actions) containing
          the calculated action probabilities.

    Example:
        ```python
        from tensordict import TensorDict
        from torchrl.modules.mcts.scores import EXP3Score

        # Create an EXP3Score instance
        exp3 = EXP3Score(gamma=0.1)

        # Define a TensorDict with required keys
        node = TensorDict(
            {
                "weights": torch.tensor([1.0, 1.0]),
                "num_actions": torch.tensor(2),
            },
            batch_size=[],
        )

        # Compute the action probabilities
        result = exp3(node)
        print(result["score"])  # Output: Tensor with action probabilities
        ```
    """

    def __init__(
        self,
        *,
        gamma: float = 0.1,
        weights_key: NestedKey = "weights",
        action_prob_key: NestedKey = "action_prob",
        reward_key: NestedKey = "reward",
        score_key: NestedKey = "score",
        num_actions_key: NestedKey = "num_actions",
    ):
        super().__init__()
        if not 0 <= gamma <= 1:
            raise ValueError(f"gamma must be between 0 and 1, got {gamma}")
        self.gamma = gamma
        self.weights_key = weights_key
        self.action_prob_key = action_prob_key
        self.reward_key = reward_key
        self.score_key = score_key
        self.num_actions_key = num_actions_key

        self.in_keys = [self.weights_key, self.num_actions_key]
        self.out_keys = [self.score_key]

    def forward(self, node: TensorDictBase) -> TensorDictBase:
        num_actions = node.get(self.num_actions_key)

        # Extract scalar value from num_actions (handles batched tensors too)
        if isinstance(num_actions, torch.Tensor):
            # For batched tensors, take the first element (all should be same)
            k = int(num_actions.flatten()[0].item())
        elif isinstance(num_actions, int):
            k = num_actions
        else:
            raise ValueError(
                f"'{self.num_actions_key}' ('num_actions') must be an integer or a tensor."
            )

        if self.weights_key not in node.keys(include_nested=True):
            batch_size = node.batch_size
            weights_shape = (*batch_size, k)
            weights = torch.ones(weights_shape, device=node.device)
            node.set(self.weights_key, weights)
        else:
            weights = node.get(self.weights_key)

        k_from_weights = weights.shape[-1]
        if k_from_weights != k:
            raise ValueError(
                f"Shape of weights {weights.shape} implies {k_from_weights} actions, "
                f"but num_actions is {k}."
            )

        sum_weights = torch.sum(weights, dim=-1, keepdim=True)
        sum_weights = torch.where(
            sum_weights == 0, torch.ones_like(sum_weights), sum_weights
        )

        p_i = (1 - self.gamma) * (weights / sum_weights) + (self.gamma / k)
        node.set(self.score_key, p_i)
        if self.action_prob_key != self.score_key:
            node.set(self.action_prob_key, p_i)
        return node

    def update_weights(
        self, node: TensorDictBase, action_idx: int, reward: float
    ) -> None:
        """Updates the weight of the chosen action based on the reward.

        This method updates the weight of the selected action using the EXP3 algorithm.
        The weight update formula is:
        `w_i(t+1) = w_i(t) * exp((gamma / K) * (reward / p_i(t)))`

        Args:
            node (TensorDictBase): The node containing the current weights and probabilities.
                Must include the keys specified by `weights_key` and `score_key`.
            action_idx (int): The index of the action that was selected.
            reward (float): The reward received for the selected action. Must be in the range [0, 1].

        Raises:
            ValueError: If the reward is not in the range [0, 1].
            ValueError: If the probability of the selected action is less than or equal to 0.

        Example:
            ```python
            from tensordict import TensorDict
            from torchrl.modules.mcts.scores import EXP3Score

            # Create an EXP3Score instance
            exp3 = EXP3Score(gamma=0.1)

            # Define a TensorDict with required keys
            node = TensorDict(
                {
                    "weights": torch.tensor([1.0, 1.0]),
                    "num_actions": torch.tensor(2),
                },
                batch_size=[],
            )

            # Compute the action probabilities
            result = exp3(node)
            print(result["score"])  # Output: Tensor with action probabilities

            # Update the weights based on the reward for action 0
            exp3.update_weights(node, action_idx=0, reward=0.8)
            print(node["weights"])  # Updated weights
            ```
        """
        if not (0 <= reward <= 1):
            warnings.warn(
                f"Reward {reward} is outside the expected [0,1] range for EXP3.",
                UserWarning,
            )

        weights = node.get(self.weights_key)
        action_probs = node.get(self.score_key)
        k = weights.shape[-1]

        if weights.ndim == 1:
            current_weight = weights[action_idx]
            prob_i = action_probs[action_idx]
        elif weights.ndim > 1:
            current_weight = weights[..., action_idx]
            prob_i = action_probs[..., action_idx]
        else:
            raise ValueError(f"Invalid weights dimensions: {weights.ndim}")

        if torch.any(prob_i <= 0):
            prob_i_val = prob_i.item() if prob_i.numel() == 1 else prob_i
            warnings.warn(
                f"Probability p_i(t) for action {action_idx} is {prob_i_val}. "
                "Weight will not be updated for zero probability actions.",
                UserWarning,
            )
            # Don't update weights for zero probability - just return
            return

        reward_tensor = torch.as_tensor(
            reward, device=current_weight.device, dtype=current_weight.dtype
        )
        exponent = (self.gamma / k) * (reward_tensor / prob_i)
        new_weight = current_weight * torch.exp(exponent)

        if weights.ndim == 1:
            weights[action_idx] = new_weight
        else:
            weights[..., action_idx] = new_weight
        node.set(self.weights_key, weights)


class UCB1TunedScore(MCTSScore):
    """Computes the UCB1-Tuned score for MCTS, using variance estimation.

    UCB1-Tuned is an enhancement of the UCB1 algorithm that incorporates an estimate
    of the variance of rewards for each action. This allows for a more refined
    balance between exploration and exploitation, potentially leading to better
    performance, especially when reward variances differ significantly across actions.

    The score for an action `i` is calculated as:
    `score_i = avg_reward_i + sqrt(log(N) / N_i * min(0.25, V_i))`

    The variance estimate `V_i` for action `i` is calculated as:
    `V_i = (sum_squared_rewards_i / N_i) - avg_reward_i^2 + sqrt(exploration_constant * log(N) / N_i)`

    Where:
    - `avg_reward_i`: Average reward obtained from action `i`.
    - `N_i`: Number of times action `i` has been visited.
    - `N`: Total number of times the parent node has been visited.
    - `sum_squared_rewards_i`: Sum of the squares of rewards received from action `i`.
    - `exploration_constant`: A constant used in the bias correction term of `V_i`.
      Auer et al. (2002) suggest a value of 2.0 for rewards in the range [0,1].
    - The term `min(0.25, V_i)` implies that rewards are scaled to `[0, 1]`, as 0.25 is
      the maximum variance for a distribution in this range (e.g., Bernoulli(0.5)).

    Reference: "Finite-time Analysis of the Multiarmed Bandit Problem"
    (Auer, Cesa-Bianchi, Fischer, 2002).

    Args:
        exploration_constant (float, optional): The constant `C` used in the bias
            correction term for the variance estimate `V_i`. Defaults to `2.0`,
            as suggested for rewards in `[0,1]`.
        win_count_key (NestedKey, optional): Key for the tensor in the input `TensorDictBase`
            containing the sum of rewards for each action (Q_i * N_i). Defaults to "win_count".
        visits_key (NestedKey, optional): Key for the tensor containing the visit
            count for each action (N_i). Defaults to "visits".
        total_visits_key (NestedKey, optional): Key for the tensor (or scalar)
            representing the visit count of the parent node (N). Defaults to "total_visits".
        sum_squared_rewards_key (NestedKey, optional): Key for the tensor containing
            the sum of squared rewards received for each action. This is crucial for
            calculating the empirical variance. Defaults to "sum_squared_rewards".
        score_key (NestedKey, optional): Key where the calculated UCB1-Tuned scores
            will be stored in the output `TensorDictBase`. Defaults to "score".

    Input Keys:
        - `win_count_key` (torch.Tensor): Sum of rewards for each action.
        - `visits_key` (torch.Tensor): Visit counts for each action (N_i).
        - `total_visits_key` (torch.Tensor): Parent node's visit count (N).
        - `sum_squared_rewards_key` (torch.Tensor): Sum of squared rewards for each action.

    Output Keys:
        - `score_key` (torch.Tensor): Calculated UCB1-Tuned scores for each action.

    Important Notes:
        - **Unvisited Nodes**: Actions with zero visits (`visits_key` is 0) are assigned a
          very large positive score to ensure they are selected for exploration.
        - **Reward Range**: The `min(0.25, V_i)` term is theoretically most sound when
          rewards are normalized to the range `[0, 1]`.
        - **Logarithm of N**: `log(N)` (log of parent visits) is calculated using `torch.log(torch.clamp(N, min=1.0))`
          to prevent issues with `N=0` or `N` between 0 and 1.
    """

    def __init__(
        self,
        *,
        win_count_key: NestedKey = "win_count",
        visits_key: NestedKey = "visits",
        total_visits_key: NestedKey = "total_visits",
        sum_squared_rewards_key: NestedKey = "sum_squared_rewards",
        score_key: NestedKey = "score",
        exploration_constant: float = 2.0,
    ):
        super().__init__()
        self.win_count_key = win_count_key
        self.visits_key = visits_key
        self.total_visits_key = total_visits_key
        self.sum_squared_rewards_key = sum_squared_rewards_key
        self.score_key = score_key
        self.exploration_constant = exploration_constant

        self.in_keys = [
            self.win_count_key,
            self.visits_key,
            self.total_visits_key,
            self.sum_squared_rewards_key,
        ]
        self.out_keys = [self.score_key]

    def forward(self, node: TensorDictBase) -> TensorDictBase:
        q_sum_i = node.get(self.win_count_key)
        n_i = node.get(self.visits_key)
        n_parent = node.get(self.total_visits_key)
        sum_sq_rewards_i = node.get(self.sum_squared_rewards_key)

        if n_parent.ndim > 0 and n_parent.ndim < q_sum_i.ndim:
            n_parent_expanded = n_parent.unsqueeze(-1)
        else:
            n_parent_expanded = n_parent

        safe_n_parent_for_log = torch.clamp(n_parent_expanded, min=1.0)
        log_n_parent = torch.log(safe_n_parent_for_log)

        scores = torch.zeros_like(q_sum_i, device=q_sum_i.device)

        visited_mask = n_i > 0

        if torch.any(visited_mask):
            q_sum_i_v = q_sum_i[visited_mask]
            n_i_v = n_i[visited_mask]
            sum_sq_rewards_i_v = sum_sq_rewards_i[visited_mask]

            log_n_parent_v = log_n_parent.expand_as(n_i)[visited_mask]

            avg_reward_i_v = q_sum_i_v / n_i_v

            empirical_variance_v = (sum_sq_rewards_i_v / n_i_v) - avg_reward_i_v.pow(2)
            bias_correction_v = (
                self.exploration_constant * log_n_parent_v / n_i_v
            ).sqrt()

            v_i_v = empirical_variance_v + bias_correction_v
            v_i_v = v_i_v.clamp(min=0)

            min_variance_term_v = torch.min(torch.full_like(v_i_v, 0.25), v_i_v)
            exploration_component_v = (
                log_n_parent_v / n_i_v * min_variance_term_v
            ).sqrt()

            scores[visited_mask] = avg_reward_i_v + exploration_component_v

        unvisited_mask = ~visited_mask
        if torch.any(unvisited_mask):
            scores[unvisited_mask] = torch.finfo(scores.dtype).max / 10.0

        node.set(self.score_key, scores)
        return node


class MCTSScores(Enum):
    """Enum providing factory functions for common MCTS score configurations."""

    PUCT = functools.partial(PUCTScore, c=5)  # AlphaGo default value
    UCB = functools.partial(UCBScore, c=math.sqrt(2))  # default from Auer et al. 2002
    UCB1_TUNED = functools.partial(
        UCB1TunedScore, exploration_constant=2.0
    )  # Auer et al. (2002) C=2 for rewards in [0,1]
    EXP3 = functools.partial(EXP3Score, gamma=0.1)