torchrl 0.11.0__cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0

torchrl/objectives/value/advantages.py
@@ -0,0 +1,1956 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import abc
import functools
import warnings
from collections.abc import Callable
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from functools import wraps

import torch
from tensordict import is_tensor_collection, TensorDictBase
from tensordict.nn import (
    composite_lp_aggregate,
    dispatch,
    ProbabilisticTensorDictModule,
    set_composite_lp_aggregate,
    set_skip_existing,
    TensorDictModule,
    TensorDictModuleBase,
)
from tensordict.nn.probabilistic import interaction_type
from tensordict.utils import NestedKey, unravel_key
from torch import Tensor

from torchrl._utils import logger, rl_warnings
from torchrl.envs.utils import step_mdp
from torchrl.objectives.utils import (
    _maybe_get_or_select,
    _pseudo_vmap,
    _vmap_func,
    hold_out_net,
)
from torchrl.objectives.value.functional import (
    generalized_advantage_estimate,
    td0_return_estimate,
    td_lambda_return_estimate,
    vec_generalized_advantage_estimate,
    vec_td1_return_estimate,
    vec_td_lambda_return_estimate,
    vtrace_advantage_estimate,
)

try:
    from torch.compiler import is_dynamo_compiling
except ImportError:
    from torch._dynamo import is_compiling as is_dynamo_compiling

try:
    from torch import vmap
except ImportError as err:
    try:
        from functorch import vmap
    except ImportError:
        raise ImportError(
            "vmap couldn't be found. Make sure you have torch>2.0 installed."
        ) from err


def _self_set_grad_enabled(fun):
    @wraps(fun)
    def new_fun(self, *args, **kwargs):
        with torch.set_grad_enabled(self.differentiable):
            return fun(self, *args, **kwargs)

    return new_fun


def _self_set_skip_existing(fun):
    @functools.wraps(fun)
    def new_func(self, *args, **kwargs):
        if self.skip_existing is not None:
            with set_skip_existing(self.skip_existing):
                return fun(self, *args, **kwargs)
        return fun(self, *args, **kwargs)

    return new_func


def _call_actor_net(
    actor_net: ProbabilisticTensorDictModule,
    data: TensorDictBase,
    params: TensorDictBase,
    log_prob_key: NestedKey,
):
    dist = actor_net.get_dist(data.select(*actor_net.in_keys, strict=False))
    s = actor_net._dist_sample(dist, interaction_type=interaction_type())
    with set_composite_lp_aggregate(True):
        return dist.log_prob(s)


class ValueEstimatorBase(TensorDictModuleBase):
    """An abstract parent class for value function modules.

    Its :meth:`ValueFunctionBase.forward` method will compute the value (given
    by the value network) and the value estimate (given by the value estimator)
    as well as the advantage and write these values in the output tensordict.

    If only the value estimate is needed, the :meth:`ValueFunctionBase.value_estimate`
    should be used instead.

    """

    @dataclass
    class _AcceptedKeys:
        """Maintains default values for all configurable tensordict keys.

        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
        default values.

        Attributes:
            advantage (NestedKey): The input tensordict key where the advantage is written to.
                Will be used for the underlying value estimator. Defaults to ``"advantage"``.
            value_target (NestedKey): The input tensordict key where the target state value is written to.
                Will be used for the underlying value estimator. Defaults to ``"value_target"``.
            value (NestedKey): The input tensordict key where the state value is expected.
                Will be used for the underlying value estimator. Defaults to ``"state_value"``.
            reward (NestedKey): The input tensordict key where the reward is written to.
                Defaults to ``"reward"``.
            done (NestedKey): The key in the input TensorDict that indicates
                whether a trajectory is done. Defaults to ``"done"``.
            terminated (NestedKey): The key in the input TensorDict that indicates
                whether a trajectory is terminated. Defaults to ``"terminated"``.
            steps_to_next_obs (NestedKey): The key in the input tensordict
                that indicates the number of steps to the next observation.
                Defaults to ``"steps_to_next_obs"``.
            sample_log_prob (NestedKey): The key in the input tensordict that
                indicates the log probability of the sampled action.
                Defaults to ``"sample_log_prob"`` when :func:`~tensordict.nn.composite_lp_aggregate` returns `True`,
                `"action_log_prob"` otherwise.
        """

        advantage: NestedKey = "advantage"
        value_target: NestedKey = "value_target"
        value: NestedKey = "state_value"
        reward: NestedKey = "reward"
        done: NestedKey = "done"
        terminated: NestedKey = "terminated"
        steps_to_next_obs: NestedKey = "steps_to_next_obs"
        sample_log_prob: NestedKey | None = None

        def __post_init__(self):
            if self.sample_log_prob is None:
                if composite_lp_aggregate(nowarn=True):
                    self.sample_log_prob = "sample_log_prob"
                else:
                    self.sample_log_prob = "action_log_prob"

    default_keys = _AcceptedKeys
    tensor_keys: _AcceptedKeys
    value_network: TensorDictModule | Callable
    _vmap_randomness = None
    deactivate_vmap: bool = False

    @property
    def advantage_key(self):
        return self.tensor_keys.advantage

    @property
    def value_key(self):
        return self.tensor_keys.value

    @property
    def value_target_key(self):
        return self.tensor_keys.value_target

    @property
    def reward_key(self):
        return self.tensor_keys.reward

    @property
    def done_key(self):
        return self.tensor_keys.done

    @property
    def terminated_key(self):
        return self.tensor_keys.terminated

    @property
    def steps_to_next_obs_key(self):
        return self.tensor_keys.steps_to_next_obs

    @property
    def sample_log_prob_key(self):
        return self.tensor_keys.sample_log_prob

    @abc.abstractmethod
    def forward(
        self,
        tensordict: TensorDictBase,
        *,
        params: TensorDictBase | None = None,
        target_params: TensorDictBase | None = None,
    ) -> TensorDictBase:
        """Computes the advantage estimate given the data in tensordict.

        If a functional module is provided, a nested TensorDict containing the parameters
        (and if relevant the target parameters) can be passed to the module.

        Args:
            tensordict (TensorDictBase): A TensorDict containing the data
                (an observation key, ``"action"``, ``("next", "reward")``,
                ``("next", "done")``, ``("next", "terminated")``,
                and ``"next"`` tensordict state as returned by the environment)
                necessary to compute the value estimates and the TDEstimate.
                The data passed to this module should be structured as
                :obj:`[*B, T, *F]` where :obj:`B` are
                the batch size, :obj:`T` the time dimension and :obj:`F` the
                feature dimension(s). The tensordict must have shape ``[*B, T]``.

        Keyword Args:
            params (TensorDictBase, optional): A nested TensorDict containing the params
                to be passed to the functional value network module.
            target_params (TensorDictBase, optional): A nested TensorDict containing the
                target params to be passed to the functional value network module.
            device (torch.device, optional): the device where the buffers will be instantiated.
                Defaults to ``torch.get_default_device()``.

        Returns:
            An updated TensorDict with an advantage and a value_error keys as defined in the constructor.
        """
        ...

    def __init__(
        self,
        *,
        value_network: TensorDictModule,
        shifted: bool = False,
        differentiable: bool = False,
        skip_existing: bool | None = None,
        advantage_key: NestedKey = None,
        value_target_key: NestedKey = None,
        value_key: NestedKey = None,
        device: torch.device | None = None,
        deactivate_vmap: bool = False,
    ):
        super().__init__()
        if device is None:
            device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
        # this is saved for tracking only and should not be used to cast anything else than buffers during
        # init.
        self._device = device
        self._tensor_keys = None
        self.differentiable = differentiable
        self.deactivate_vmap = deactivate_vmap
        self.skip_existing = skip_existing
        self.__dict__["value_network"] = value_network
        self.dep_keys = {}
        self.shifted = shifted

        if advantage_key is not None:
            raise RuntimeError(
                "Setting 'advantage_key' via constructor is deprecated, use .set_keys(advantage_key='some_key') instead.",
            )
        if value_target_key is not None:
            raise RuntimeError(
                "Setting 'value_target_key' via constructor is deprecated, use .set_keys(value_target_key='some_key') instead.",
            )
        if value_key is not None:
            raise RuntimeError(
                "Setting 'value_key' via constructor is deprecated, use .set_keys(value_key='some_key') instead.",
            )

    @property
    def tensor_keys(self) -> _AcceptedKeys:
        if self._tensor_keys is None:
            self.set_keys()
        return self._tensor_keys

    @tensor_keys.setter
    def tensor_keys(self, value):
        if not isinstance(value, type(self._AcceptedKeys)):
            raise ValueError("value must be an instance of _AcceptedKeys")
        self._keys = value

    @property
    def in_keys(self):
        try:
            in_keys = (
                self.value_network.in_keys
                + [
                    ("next", self.tensor_keys.reward),
                    ("next", self.tensor_keys.done),
                    ("next", self.tensor_keys.terminated),
                ]
                + [("next", in_key) for in_key in self.value_network.in_keys]
            )
        except AttributeError:
            # value network does not have an `in_keys` attribute
            in_keys = []
        return in_keys

    @property
    def out_keys(self):
        return [
            self.tensor_keys.advantage,
            self.tensor_keys.value_target,
        ]

    def set_keys(self, **kwargs) -> None:
        """Set tensordict key names."""
        for key, value in list(kwargs.items()):
            if isinstance(value, list):
                value = [unravel_key(k) for k in value]
            elif not isinstance(value, (str, tuple)):
                if value is None:
                    raise ValueError("tensordict keys cannot be None")
                raise ValueError(
                    f"key name must be of type NestedKey (Union[str, Tuple[str]]) but got {type(value)}"
                )
            else:
                value = unravel_key(value)

            if key not in self._AcceptedKeys.__dict__:
                raise KeyError(
                    f"{key} is not an accepted tensordict key for advantages"
                )
            if (
                key == "value"
                and hasattr(self.value_network, "out_keys")
                and (value not in self.value_network.out_keys)
            ):
                raise KeyError(
                    f"value key '{value}' not found in value network out_keys {self.value_network.out_keys}"
                )
            kwargs[key] = value
        if self._tensor_keys is None:
            conf = asdict(self.default_keys())
            conf.update(self.dep_keys)
        else:
            conf = asdict(self._tensor_keys)
        conf.update(kwargs)
        self._tensor_keys = self._AcceptedKeys(**conf)

    def value_estimate(
        self,
        tensordict,
        target_params: TensorDictBase | None = None,
        next_value: torch.Tensor | None = None,
        **kwargs,
    ):
        """Gets a value estimate, usually used as a target value for the value network.

        If the state value key is present under ``tensordict.get(("next", self.tensor_keys.value))``
        then this value will be used without recurring to the value network.

        Args:
            tensordict (TensorDictBase): the tensordict containing the data to
                read.
            target_params (TensorDictBase, optional): A nested TensorDict containing the
                target params to be passed to the functional value network module.
            next_value (torch.Tensor, optional): the value of the next state
                or state-action pair. Exclusive with ``target_params``.
            **kwargs: the keyword arguments to be passed to the value network.

        Returns: a tensor corresponding to the state value.

        """
        raise NotImplementedError

    @property
    def is_functional(self):
        # legacy
        return False

    @property
    def is_stateless(self):
        # legacy
        return False

    def _next_value(self, tensordict, target_params, kwargs):
        step_td = step_mdp(tensordict, keep_other=False)
        if self.value_network is not None:
            with hold_out_net(
                self.value_network
            ) if target_params is None else target_params.to_module(self.value_network):
                self.value_network(step_td)
        next_value = step_td.get(self.tensor_keys.value)
        return next_value

    @property
    def vmap_randomness(self):
        if self._vmap_randomness is None:
            if is_dynamo_compiling():
                self._vmap_randomness = "different"
                return "different"
            do_break = False
            for val in self.__dict__.values():
                if isinstance(val, torch.nn.Module):
                    import torchrl.objectives.utils

                    for module in val.modules():
                        if isinstance(
                            module, torchrl.objectives.utils.RANDOM_MODULE_LIST
                        ):
                            self._vmap_randomness = "different"
                            do_break = True
                            break
                if do_break:
                    # double break
                    break
            else:
                self._vmap_randomness = "error"

        return self._vmap_randomness

    def set_vmap_randomness(self, value):
        self._vmap_randomness = value

    def _get_time_dim(self, time_dim: int | None, data: TensorDictBase):
        if time_dim is not None:
            if time_dim < 0:
                time_dim = data.ndim + time_dim
            return time_dim
        time_dim_attr = getattr(self, "time_dim", None)
        if time_dim_attr is not None:
            if time_dim_attr < 0:
                time_dim_attr = data.ndim + time_dim_attr
            return time_dim_attr
        if data._has_names():
            for i, name in enumerate(data.names):
                if name == "time":
                    return i
        return data.ndim - 1

    def _call_value_nets(
        self,
        data: TensorDictBase,
        params: TensorDictBase,
        next_params: TensorDictBase,
        single_call: bool,
        value_key: NestedKey,
        detach_next: bool,
        vmap_randomness: str = "error",
        *,
        value_net: TensorDictModuleBase | None = None,
    ):
        if value_net is None:
            value_net = self.value_network
        in_keys = value_net.in_keys
        if single_call:
            # We are going to flatten the data, then interleave the last observation of each trajectory in between its
            # previous obs (from the root TD) and the first of the next trajectory. Eventually, each trajectory will
            # have T+1 elements (or, for a batch of N trajectories, we will have \Sum_{t=0}^{T-1} length_t + T
            # elements). Then, we can feed that to our RNN which will understand which trajectory is which, pad the data
            # accordingly and process each of them independently.
            try:
                ndim = list(data.names).index("time") + 1
            except ValueError:
                if rl_warnings():
                    logger.warning(
                        "Got a tensordict without a time-marked dimension, assuming time is along the last dimension. "
                        "This warning can be turned off by setting the environment variable RL_WARNINGS to False."
                    )
                ndim = data.ndim
            data_copy = data.copy()
            # we are going to modify the done so let's clone it
            done = data_copy["next", "done"].clone()
            # Mark the last step of every sequence as done. We do this because flattening would cause the trajectories
            # of different batches to be merged.
            done[(slice(None),) * (ndim - 1) + (-1,)].fill_(True)
            truncated = data_copy.get(("next", "truncated"), done)
            if truncated is not done:
                truncated[(slice(None),) * (ndim - 1) + (-1,)].fill_(True)
            data_copy["next", "done"] = done
            data_copy["next", "truncated"] = truncated
            # Reshape to -1 because we cannot guarantee that all dims have the same number of done states
            with data_copy.view(-1) as data_copy_view:
                # Interleave next data when done
                data_copy_select = data_copy_view.select(
                    *in_keys, value_key, strict=False
                )
                total_elts = (
                    data_copy_view.shape[0]
                    + data_copy_view["next", "done"].sum().item()
                )
                data_in = data_copy_select.new_zeros((total_elts,))
                # we can get the indices of non-done data by adding the shifted done cumsum to an arange
                # traj = [0, 0, 0, 1, 1, 2, 2]
                # arange = [0, 1, 2, 3, 4, 5, 6]
                # done = [0, 0, 1, 0, 1, 0, 1]
                # done_cs = [0, 0, 0, 1, 1, 2, 2]
                # indices = [0, 1, 2, 4, 5, 7, 8]
                done_view = data_copy_view["next", "done"]
                if done_view.shape[-1] == 1:
                    done_view = done_view.squeeze(-1)
                else:
                    done_view = done_view.any(-1)
                done_cs = done_view.cumsum(0)
                done_cs = torch.cat([done_cs.new_zeros((1,)), done_cs[:-1]], dim=0)
                indices = torch.arange(done_cs.shape[0], device=done_cs.device)
                indices = indices + done_cs
                data_in[indices] = data_copy_select
                # To get the indices of the extra data, we can mask indices with done_view and add 1
                indices_interleaved = indices[done_view] + 1
                # assert not set(indices_interleaved.tolist()).intersection(indices.tolist())
                data_in[indices_interleaved] = (
                    data_copy_view[done_view]
                    .get("next")
                    .select(*in_keys, value_key, strict=False)
                )
                if next_params is not None and next_params is not params:
                    raise ValueError(
                        "the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed."
                    )
                if params is not None:
                    with params.to_module(value_net):
                        value_est = value_net(data_in).get(value_key)
                else:
                    value_est = value_net(data_in).get(value_key)
                value, value_ = value_est[indices], value_est[indices + 1]
                value = value.view_as(done)
                value_ = value_.view_as(done)
        else:
            data_root = data.select(*in_keys, value_key, strict=False)
            data_next = data.get("next").select(*in_keys, value_key, strict=False)
            if "is_init" in data_root.keys():
                # We need to mark the first element of the "next" td as being an init step for RNNs
                # otherwise, consecutive elements in the sequence will be considered as part of the same
                # trajectory, even if they're not.
                data_next["is_init"] = data_next["is_init"] | data_root["is_init"]
            data_in = torch.stack(
                [data_root, data_next],
                0,
            )
            if (params is not None) ^ (next_params is not None):
                raise ValueError(
                    "params and next_params must be either both provided or not."
                )
            elif params is not None:
                params_stack = torch.stack([params, next_params], 0).contiguous()
                data_out = _vmap_func(
                    value_net,
                    (0, 0),
                    randomness=vmap_randomness,
                    pseudo_vmap=self.deactivate_vmap,
                )(data_in, params_stack)
            elif not self.deactivate_vmap:
                data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in)
            else:
                data_out = _pseudo_vmap(value_net, (0,), randomness=vmap_randomness)(
                    data_in
                )
            value_est = data_out.get(value_key)
            value, value_ = value_est[0], value_est[1]
        data.set(value_key, value)
        data.set(("next", value_key), value_)
        if detach_next:
            value_ = value_.detach()
        return value, value_


class TD0Estimator(ValueEstimatorBase):
    """Temporal Difference (TD(0)) estimate of advantage function.

    AKA bootstrapped temporal difference or 1-step return.

    Keyword Args:
        gamma (scalar): exponential mean discount.
        value_network (TensorDictModule): value operator used to retrieve
            the value estimates.
        shifted (bool, optional): if ``True``, the value and next value are
            estimated with a single call to the value network. This is faster
            but is only valid whenever (1) the ``"next"`` value is shifted by
            only one time step (which is not the case with multi-step value
            estimation, for instance) and (2) when the parameters used at time
            ``t`` and ``t+1`` are identical (which is not the case when target
            parameters are to be used). Defaults to ``False``.
        average_rewards (bool, optional): if ``True``, rewards will be standardized
            before the TD is computed.
        differentiable (bool, optional): if ``True``, gradients are propagated through
            the computation of the value function. Default is ``False``.

            .. note::
              The proper way to make the function call non-differentiable is to
              decorate it in a `torch.no_grad()` context manager/decorator or
              pass detached parameters for functional modules.

        skip_existing (bool, optional): if ``True``, the value network will skip
            modules which outputs are already present in the tensordict.
            Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()`
            is not affected.
        advantage_key (str or tuple of str, optional): [Deprecated] the key of
            the advantage entry. Defaults to ``"advantage"``.
        value_target_key (str or tuple of str, optional): [Deprecated] the key
            of the advantage entry. Defaults to ``"value_target"``.
        value_key (str or tuple of str, optional): [Deprecated] the value key to
            read from the input tensordict. Defaults to ``"state_value"``.
        device (torch.device, optional): the device where the buffers will be instantiated.
            Defaults to ``torch.get_default_device()``.
        deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
            Defaults to ``False``.

    """

    def __init__(
        self,
        *,
        gamma: float | torch.Tensor,
        value_network: TensorDictModule,
        shifted: bool = False,
        average_rewards: bool = False,
        differentiable: bool = False,
        advantage_key: NestedKey = None,
        value_target_key: NestedKey = None,
        value_key: NestedKey = None,
        skip_existing: bool | None = None,
        device: torch.device | None = None,
        deactivate_vmap: bool = False,
    ):
        super().__init__(
            value_network=value_network,
            differentiable=differentiable,
            shifted=shifted,
            advantage_key=advantage_key,
            value_target_key=value_target_key,
            value_key=value_key,
            skip_existing=skip_existing,
            device=device,
            deactivate_vmap=deactivate_vmap,
        )
        self.register_buffer("gamma", torch.tensor(gamma, device=self._device))
        self.average_rewards = average_rewards

    @_self_set_skip_existing
    @_self_set_grad_enabled
    @dispatch
    def forward(
        self,
        tensordict: TensorDictBase,
        *,
        params: TensorDictBase | None = None,
        target_params: TensorDictBase | None = None,
    ) -> TensorDictBase:
        """Computes the TD(0) advantage given the data in tensordict.

        If a functional module is provided, a nested TensorDict containing the parameters
        (and if relevant the target parameters) can be passed to the module.

        Args:
            tensordict (TensorDictBase): A TensorDict containing the data
                (an observation key, ``"action"``, ``("next", "reward")``,
                ``("next", "done")``, ``("next", "terminated")``, and ``"next"``
                tensordict state as returned by the environment) necessary to
                compute the value estimates and the TDEstimate.
                The data passed to this module should be structured as
                :obj:`[*B, T, *F]` where :obj:`B` are
                the batch size, :obj:`T` the time dimension and :obj:`F` the
                feature dimension(s). The tensordict must have shape ``[*B, T]``.

        Keyword Args:
            params (TensorDictBase, optional): A nested TensorDict containing the params
                to be passed to the functional value network module.
            target_params (TensorDictBase, optional): A nested TensorDict containing the
                target params to be passed to the functional value network module.

        Returns:
            An updated TensorDict with an advantage and a value_error keys as defined in the constructor.

        Examples:
            >>> from tensordict import TensorDict
            >>> value_net = TensorDictModule(
            ...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
            ... )
            >>> module = TDEstimate(
            ...     gamma=0.98,
            ...     value_network=value_net,
            ... )
            >>> obs, next_obs = torch.randn(2, 1, 10, 3)
            >>> reward = torch.randn(1, 10, 1)
            >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
            >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
            >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "terminated": terminated, "reward": reward}}, [1, 10])
            >>> _ = module(tensordict)
            >>> assert "advantage" in tensordict.keys()

        The module supports non-tensordict (i.e. unpacked tensordict) inputs too:

        Examples:
            >>> value_net = TensorDictModule(
            ...     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
            ... )
            >>> module = TDEstimate(
            ...     gamma=0.98,
            ...     value_network=value_net,
            ... )
            >>> obs, next_obs = torch.randn(2, 1, 10, 3)
            >>> reward = torch.randn(1, 10, 1)
            >>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
            >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
            >>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated)

        """
        if tensordict.batch_dims < 1:
            raise RuntimeError(
                "Expected input tensordict to have at least one dimensions, got"
                f"tensordict.batch_size = {tensordict.batch_size}"
            )

        if self.is_stateless and params is None:
            raise RuntimeError(
                "Expected params to be passed to advantage module but got none."
            )
        if self.value_network is not None:
            if params is not None:
                params = params.detach()
                if target_params is None:
                    target_params = params.clone(False)
            with hold_out_net(self.value_network) if (
                params is None and target_params is None
            ) else nullcontext():
                # we may still need to pass gradient, but we don't want to assign grads to
                # value net params
                value, next_value = self._call_value_nets(
                    data=tensordict,
                    params=params,
                    next_params=target_params,
                    single_call=self.shifted,
                    value_key=self.tensor_keys.value,
                    detach_next=True,
                    vmap_randomness=self.vmap_randomness,
                )
        else:
            value = tensordict.get(self.tensor_keys.value)
            next_value = tensordict.get(("next", self.tensor_keys.value))

        value_target = self.value_estimate(tensordict, next_value=next_value)
        tensordict.set(self.tensor_keys.advantage, value_target - value)
        tensordict.set(self.tensor_keys.value_target, value_target)
        return tensordict

    def value_estimate(
        self,
        tensordict,
        target_params: TensorDictBase | None = None,
        next_value: torch.Tensor | None = None,
        **kwargs,
    ):
        reward = tensordict.get(("next", self.tensor_keys.reward))
        device = reward.device

        if self.gamma.device != device:
            self.gamma = self.gamma.to(device)
        gamma = self.gamma
        steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
        if steps_to_next_obs is not None:
            gamma = gamma ** steps_to_next_obs.view_as(reward)

        if self.average_rewards:
            reward = reward - reward.mean()
            reward = reward / reward.std().clamp_min(1e-5)
            tensordict.set(
                ("next", self.tensor_keys.reward), reward
            )  # we must update the rewards if they are used later in the code
        if next_value is None:
            next_value = self._next_value(tensordict, target_params, kwargs=kwargs)

        done = tensordict.get(("next", self.tensor_keys.done))
        terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done)
        value_target = td0_return_estimate(
            gamma=gamma,
            next_state_value=next_value,
            reward=reward,
            done=done,
            terminated=terminated,
        )
        return value_target


class TD1Estimator(ValueEstimatorBase):
|
|
775
|
+
r""":math:`\infty`-Temporal Difference (TD(1)) estimate of advantage function.
|
|
776
|
+
|
|
777
|
+
Keyword Args:
|
|
778
|
+
gamma (scalar): exponential mean discount.
|
|
779
|
+
value_network (TensorDictModule): value operator used to retrieve the value estimates.
|
|
780
|
+
average_rewards (bool, optional): if ``True``, rewards will be standardized
|
|
781
|
+
before the TD is computed.
|
|
782
|
+
differentiable (bool, optional): if ``True``, gradients are propagated through
|
|
783
|
+
the computation of the value function. Default is ``False``.
|
|
784
|
+
|
|
785
|
+
.. note::
|
|
786
|
+
The proper way to make the function call non-differentiable is to
|
|
787
|
+
decorate it in a `torch.no_grad()` context manager/decorator or
|
|
788
|
+
pass detached parameters for functional modules.
|
|
789
|
+
|
|
790
|
+
skip_existing (bool, optional): if ``True``, the value network will skip
|
|
791
|
+
modules which outputs are already present in the tensordict.
|
|
792
|
+
Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()`
|
|
793
|
+
is not affected.
|
|
794
|
+
advantage_key (str or tuple of str, optional): [Deprecated] the key of
|
|
795
|
+
the advantage entry. Defaults to ``"advantage"``.
|
|
796
|
+
value_target_key (str or tuple of str, optional): [Deprecated] the key
|
|
797
|
+
of the advantage entry. Defaults to ``"value_target"``.
|
|
798
|
+
value_key (str or tuple of str, optional): [Deprecated] the value key to
|
|
799
|
+
read from the input tensordict. Defaults to ``"state_value"``.
|
|
800
|
+
shifted (bool, optional): if ``True``, the value and next value are
|
|
801
|
+
estimated with a single call to the value network. This is faster
|
|
802
|
+
but is only valid whenever (1) the ``"next"`` value is shifted by
|
|
803
|
+
only one time step (which is not the case with multi-step value
|
|
804
|
+
estimation, for instance) and (2) when the parameters used at time
|
|
805
|
+
``t`` and ``t+1`` are identical (which is not the case when target
|
|
806
|
+
parameters are to be used). Defaults to ``False``.
|
|
807
|
+
device (torch.device, optional): the device where the buffers will be instantiated.
|
|
808
|
+
Defaults to ``torch.get_default_device()``.
|
|
809
|
+
time_dim (int, optional): the dimension corresponding to the time
|
|
810
|
+
in the input tensordict. If not provided, defaults to the dimension
|
|
811
|
+
marked with the ``"time"`` name if any, and to the last dimension
|
|
812
|
+
otherwise. Can be overridden during a call to
|
|
813
|
+
:meth:`~.value_estimate`.
|
|
814
|
+
Negative dimensions are considered with respect to the input
|
|
815
|
+
tensordict.
|
|
816
|
+
deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
|
|
817
|
+
Defaults to ``False``.
|
|
818
|
+
|
|
819
|
+
"""
|
|
820
|
+
|
|
821
|
+
def __init__(
|
|
822
|
+
self,
|
|
823
|
+
*,
|
|
824
|
+
gamma: float | torch.Tensor,
|
|
825
|
+
value_network: TensorDictModule,
|
|
826
|
+
average_rewards: bool = False,
|
|
827
|
+
differentiable: bool = False,
|
|
828
|
+
skip_existing: bool | None = None,
|
|
829
|
+
advantage_key: NestedKey = None,
|
|
830
|
+
value_target_key: NestedKey = None,
|
|
831
|
+
value_key: NestedKey = None,
|
|
832
|
+
shifted: bool = False,
|
|
833
|
+
device: torch.device | None = None,
|
|
834
|
+
time_dim: int | None = None,
|
|
835
|
+
deactivate_vmap: bool = False,
|
|
836
|
+
):
|
|
837
|
+
super().__init__(
|
|
838
|
+
value_network=value_network,
|
|
839
|
+
differentiable=differentiable,
|
|
840
|
+
advantage_key=advantage_key,
|
|
841
|
+
value_target_key=value_target_key,
|
|
842
|
+
value_key=value_key,
|
|
843
|
+
shifted=shifted,
|
|
844
|
+
skip_existing=skip_existing,
|
|
845
|
+
device=device,
|
|
846
|
+
deactivate_vmap=deactivate_vmap,
|
|
847
|
+
)
|
|
848
|
+
self.register_buffer("gamma", torch.tensor(gamma, device=self._device))
|
|
849
|
+
self.average_rewards = average_rewards
|
|
850
|
+
self.time_dim = time_dim
|
|
851
|
+
|
|
852
|
+
@_self_set_skip_existing
|
|
853
|
+
@_self_set_grad_enabled
|
|
854
|
+
@dispatch
|
|
855
|
+
def forward(
|
|
856
|
+
self,
|
|
857
|
+
tensordict: TensorDictBase,
|
|
858
|
+
*,
|
|
859
|
+
params: TensorDictBase | None = None,
|
|
860
|
+
target_params: TensorDictBase | None = None,
|
|
861
|
+
) -> TensorDictBase:
|
|
862
|
+
"""Computes the TD(1) advantage given the data in tensordict.
|
|
863
|
+
|
|
864
|
+
If a functional module is provided, a nested TensorDict containing the parameters
|
|
865
|
+
(and if relevant the target parameters) can be passed to the module.
|
|
866
|
+
|
|
867
|
+
Args:
|
|
868
|
+
tensordict (TensorDictBase): A TensorDict containing the data
|
|
869
|
+
(an observation key, ``"action"``, ``("next", "reward")``,
|
|
870
|
+
``("next", "done")``, ``("next", "terminated")``,
|
|
871
|
+
and ``"next"`` tensordict state as returned by the environment)
|
|
872
|
+
necessary to compute the value estimates and the TD(1) estimate.
|
|
873
|
+
The data passed to this module should be structured as :obj:`[*B, T, *F]` where :obj:`B` are
|
|
874
|
+
the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s).
|
|
875
|
+
The tensordict must have shape ``[*B, T]``.
|
|
876
|
+
|
|
877
|
+
Keyword Args:
|
|
878
|
+
params (TensorDictBase, optional): A nested TensorDict containing the params
|
|
879
|
+
to be passed to the functional value network module.
|
|
880
|
+
target_params (TensorDictBase, optional): A nested TensorDict containing the
|
|
881
|
+
target params to be passed to the functional value network module.
|
|
882
|
+
|
|
883
|
+
Returns:
|
|
884
|
+
An updated TensorDict with the advantage and value_target entries as defined in the constructor.
|
|
885
|
+
|
|
886
|
+
Examples:
|
|
887
|
+
>>> from tensordict import TensorDict
|
|
888
|
+
>>> value_net = TensorDictModule(
|
|
889
|
+
... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
|
|
890
|
+
... )
|
|
891
|
+
>>> module = TD1Estimator(
|
|
892
|
+
... gamma=0.98,
|
|
893
|
+
... value_network=value_net,
|
|
894
|
+
... )
|
|
895
|
+
>>> obs, next_obs = torch.randn(2, 1, 10, 3)
|
|
896
|
+
>>> reward = torch.randn(1, 10, 1)
|
|
897
|
+
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
|
|
898
|
+
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
|
|
899
|
+
>>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "reward": reward, "terminated": terminated}}, [1, 10])
|
|
900
|
+
>>> _ = module(tensordict)
|
|
901
|
+
>>> assert "advantage" in tensordict.keys()
|
|
902
|
+
|
|
903
|
+
The module supports non-tensordict (i.e. unpacked tensordict) inputs too:
|
|
904
|
+
|
|
905
|
+
Examples:
|
|
906
|
+
>>> value_net = TensorDictModule(
|
|
907
|
+
... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
|
|
908
|
+
... )
|
|
909
|
+
>>> module = TD1Estimator(
|
|
910
|
+
... gamma=0.98,
|
|
911
|
+
... value_network=value_net,
|
|
912
|
+
... )
|
|
913
|
+
>>> obs, next_obs = torch.randn(2, 1, 10, 3)
|
|
914
|
+
>>> reward = torch.randn(1, 10, 1)
|
|
915
|
+
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
|
|
916
|
+
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
|
|
917
|
+
>>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated)
|
|
918
|
+
|
|
919
|
+
"""
|
|
920
|
+
if tensordict.batch_dims < 1:
|
|
921
|
+
raise RuntimeError(
|
|
922
|
+
"Expected input tensordict to have at least one dimensions, got"
|
|
923
|
+
f"tensordict.batch_size = {tensordict.batch_size}"
|
|
924
|
+
)
|
|
925
|
+
|
|
926
|
+
if self.is_stateless and params is None:
|
|
927
|
+
raise RuntimeError(
|
|
928
|
+
"Expected params to be passed to advantage module but got none."
|
|
929
|
+
)
|
|
930
|
+
if self.value_network is not None:
|
|
931
|
+
if params is not None:
|
|
932
|
+
params = params.detach()
|
|
933
|
+
if target_params is None:
|
|
934
|
+
target_params = params.clone(False)
|
|
935
|
+
with hold_out_net(self.value_network) if (
|
|
936
|
+
params is None and target_params is None
|
|
937
|
+
) else nullcontext():
|
|
938
|
+
# we may still need to pass gradient, but we don't want to assign grads to
|
|
939
|
+
# value net params
|
|
940
|
+
value, next_value = self._call_value_nets(
|
|
941
|
+
data=tensordict,
|
|
942
|
+
params=params,
|
|
943
|
+
next_params=target_params,
|
|
944
|
+
single_call=self.shifted,
|
|
945
|
+
value_key=self.tensor_keys.value,
|
|
946
|
+
detach_next=True,
|
|
947
|
+
vmap_randomness=self.vmap_randomness,
|
|
948
|
+
)
|
|
949
|
+
else:
|
|
950
|
+
value = tensordict.get(self.tensor_keys.value)
|
|
951
|
+
next_value = tensordict.get(("next", self.tensor_keys.value))
|
|
952
|
+
|
|
953
|
+
value_target = self.value_estimate(tensordict, next_value=next_value)
|
|
954
|
+
|
|
955
|
+
tensordict.set(self.tensor_keys.advantage, value_target - value)
|
|
956
|
+
tensordict.set(self.tensor_keys.value_target, value_target)
|
|
957
|
+
return tensordict
|
|
958
|
+
|
|
959
|
+
def value_estimate(
|
|
960
|
+
self,
|
|
961
|
+
tensordict,
|
|
962
|
+
target_params: TensorDictBase | None = None,
|
|
963
|
+
next_value: torch.Tensor | None = None,
|
|
964
|
+
time_dim: int | None = None,
|
|
965
|
+
**kwargs,
|
|
966
|
+
):
|
|
967
|
+
reward = tensordict.get(("next", self.tensor_keys.reward))
|
|
968
|
+
device = reward.device
|
|
969
|
+
if self.gamma.device != device:
|
|
970
|
+
self.gamma = self.gamma.to(device)
|
|
971
|
+
gamma = self.gamma
|
|
972
|
+
steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
|
|
973
|
+
if steps_to_next_obs is not None:
|
|
974
|
+
gamma = gamma ** steps_to_next_obs.view_as(reward)
|
|
975
|
+
|
|
976
|
+
if self.average_rewards:
|
|
977
|
+
reward = reward - reward.mean()
|
|
978
|
+
reward = reward / reward.std().clamp_min(1e-5)
|
|
979
|
+
tensordict.set(
|
|
980
|
+
("next", self.tensor_keys.reward), reward
|
|
981
|
+
) # we must update the rewards if they are used later in the code
|
|
982
|
+
if next_value is None:
|
|
983
|
+
next_value = self._next_value(tensordict, target_params, kwargs=kwargs)
|
|
984
|
+
|
|
985
|
+
done = tensordict.get(("next", self.tensor_keys.done))
|
|
986
|
+
terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done)
|
|
987
|
+
time_dim = self._get_time_dim(time_dim, tensordict)
|
|
988
|
+
value_target = vec_td1_return_estimate(
|
|
989
|
+
gamma,
|
|
990
|
+
next_value,
|
|
991
|
+
reward,
|
|
992
|
+
done=done,
|
|
993
|
+
terminated=terminated,
|
|
994
|
+
time_dim=time_dim,
|
|
995
|
+
)
|
|
996
|
+
return value_target
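# A minimal, non-vectorized sketch of the recursion that
# ``vec_td1_return_estimate`` evaluates above, for readers who want the
# arithmetic spelled out. It assumes tensors shaped ``[T, 1]``, a scalar
# ``gamma`` and boolean ``done``/``terminated`` flags; it is an illustration,
# not the functional used by this module.
def _td1_returns_sketch(gamma, next_value, reward, done, terminated):
    # Backward pass: G_t = r_t + gamma * G_{t+1} while the trajectory
    # continues, and G_t = r_t + gamma * (1 - terminated_t) * V(s_{t+1})
    # where it stops (truncated segments bootstrap with the value estimate).
    import torch  # local import keeps the sketch self-contained

    T = reward.shape[0]
    returns = torch.empty_like(next_value)
    for t in reversed(range(T)):
        bootstrap = next_value[t] * (~terminated[t]).float()
        if t == T - 1 or bool(done[t]):
            nxt = bootstrap
        else:
            nxt = returns[t + 1]
        returns[t] = reward[t] + gamma * nxt
    return returns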
|
|
997
|
+
|
|
998
|
+
|
|
999
|
+
class TDLambdaEstimator(ValueEstimatorBase):
|
|
1000
|
+
r"""TD(:math:`\lambda`) estimate of advantage function.
|
|
1001
|
+
|
|
1002
|
+
Args:
|
|
1003
|
+
gamma (scalar): exponential mean discount.
|
|
1004
|
+
lmbda (scalar): trajectory discount.
|
|
1005
|
+
value_network (TensorDictModule): value operator used to retrieve the value estimates.
|
|
1006
|
+
average_rewards (bool, optional): if ``True``, rewards will be standardized
|
|
1007
|
+
before the TD is computed.
|
|
1008
|
+
differentiable (bool, optional): if ``True``, gradients are propagated through
|
|
1009
|
+
the computation of the value function. Default is ``False``.
|
|
1010
|
+
|
|
1011
|
+
.. note::
|
|
1012
|
+
The proper way to make the function call non-differentiable is to
|
|
1013
|
+
decorate it in a `torch.no_grad()` context manager/decorator or
|
|
1014
|
+
pass detached parameters for functional modules.
|
|
1015
|
+
|
|
1016
|
+
vectorized (bool, optional): whether to use the vectorized version of the
|
|
1017
|
+
lambda return. Default is `True`.
|
|
1018
|
+
skip_existing (bool, optional): if ``True``, the value network will skip
|
|
1019
|
+
modules whose outputs are already present in the tensordict.
|
|
1020
|
+
Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()`
|
|
1021
|
+
is not affected.
|
|
1022
|
+
advantage_key (str or tuple of str, optional): [Deprecated] the key of
|
|
1023
|
+
the advantage entry. Defaults to ``"advantage"``.
|
|
1024
|
+
value_target_key (str or tuple of str, optional): [Deprecated] the key
|
|
1025
|
+
of the value target entry. Defaults to ``"value_target"``.
|
|
1026
|
+
value_key (str or tuple of str, optional): [Deprecated] the value key to
|
|
1027
|
+
read from the input tensordict. Defaults to ``"state_value"``.
|
|
1028
|
+
shifted (bool, optional): if ``True``, the value and next value are
|
|
1029
|
+
estimated with a single call to the value network. This is faster
|
|
1030
|
+
but is only valid whenever (1) the ``"next"`` value is shifted by
|
|
1031
|
+
only one time step (which is not the case with multi-step value
|
|
1032
|
+
estimation, for instance) and (2) when the parameters used at time
|
|
1033
|
+
``t`` and ``t+1`` are identical (which is not the case when target
|
|
1034
|
+
parameters are to be used). Defaults to ``False``.
|
|
1035
|
+
device (torch.device, optional): the device where the buffers will be instantiated.
|
|
1036
|
+
Defaults to ``torch.get_default_device()``.
|
|
1037
|
+
time_dim (int, optional): the dimension corresponding to the time
|
|
1038
|
+
in the input tensordict. If not provided, defaults to the dimension
|
|
1039
|
+
marked with the ``"time"`` name if any, and to the last dimension
|
|
1040
|
+
otherwise. Can be overridden during a call to
|
|
1041
|
+
:meth:`~.value_estimate`.
|
|
1042
|
+
Negative dimensions are considered with respect to the input
|
|
1043
|
+
tensordict.
|
|
1044
|
+
deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
|
|
1045
|
+
Defaults to ``False``.
|
|
1046
|
+
|
|
1047
|
+
"""
|
|
1048
|
+
|
|
1049
|
+
def __init__(
|
|
1050
|
+
self,
|
|
1051
|
+
*,
|
|
1052
|
+
gamma: float | torch.Tensor,
|
|
1053
|
+
lmbda: float | torch.Tensor,
|
|
1054
|
+
value_network: TensorDictModule,
|
|
1055
|
+
average_rewards: bool = False,
|
|
1056
|
+
differentiable: bool = False,
|
|
1057
|
+
vectorized: bool = True,
|
|
1058
|
+
skip_existing: bool | None = None,
|
|
1059
|
+
advantage_key: NestedKey = None,
|
|
1060
|
+
value_target_key: NestedKey = None,
|
|
1061
|
+
value_key: NestedKey = None,
|
|
1062
|
+
shifted: bool = False,
|
|
1063
|
+
device: torch.device | None = None,
|
|
1064
|
+
time_dim: int | None = None,
|
|
1065
|
+
deactivate_vmap: bool = False,
|
|
1066
|
+
):
|
|
1067
|
+
super().__init__(
|
|
1068
|
+
value_network=value_network,
|
|
1069
|
+
differentiable=differentiable,
|
|
1070
|
+
advantage_key=advantage_key,
|
|
1071
|
+
value_target_key=value_target_key,
|
|
1072
|
+
value_key=value_key,
|
|
1073
|
+
skip_existing=skip_existing,
|
|
1074
|
+
shifted=shifted,
|
|
1075
|
+
device=device,
|
|
1076
|
+
deactivate_vmap=deactivate_vmap,
|
|
1077
|
+
)
|
|
1078
|
+
self.register_buffer("gamma", torch.tensor(gamma, device=self._device))
|
|
1079
|
+
self.register_buffer("lmbda", torch.tensor(lmbda, device=self._device))
|
|
1080
|
+
self.average_rewards = average_rewards
|
|
1081
|
+
self.vectorized = vectorized
|
|
1082
|
+
self.time_dim = time_dim
|
|
1083
|
+
|
|
1084
|
+
@property
|
|
1085
|
+
def vectorized(self):
|
|
1086
|
+
if is_dynamo_compiling():
|
|
1087
|
+
return False
|
|
1088
|
+
return self._vectorized
|
|
1089
|
+
|
|
1090
|
+
@vectorized.setter
|
|
1091
|
+
def vectorized(self, value):
|
|
1092
|
+
self._vectorized = value
|
|
1093
|
+
|
|
1094
|
+
@_self_set_skip_existing
|
|
1095
|
+
@_self_set_grad_enabled
|
|
1096
|
+
@dispatch
|
|
1097
|
+
def forward(
|
|
1098
|
+
self,
|
|
1099
|
+
tensordict: TensorDictBase,
|
|
1100
|
+
*,
|
|
1101
|
+
params: list[Tensor] | None = None,
|
|
1102
|
+
target_params: list[Tensor] | None = None,
|
|
1103
|
+
) -> TensorDictBase:
|
|
1104
|
+
r"""Computes the TD(:math:`\lambda`) advantage given the data in tensordict.
|
|
1105
|
+
|
|
1106
|
+
If a functional module is provided, a nested TensorDict containing the parameters
|
|
1107
|
+
(and if relevant the target parameters) can be passed to the module.
|
|
1108
|
+
|
|
1109
|
+
Args:
|
|
1110
|
+
tensordict (TensorDictBase): A TensorDict containing the data
|
|
1111
|
+
(an observation key, ``"action"``, ``("next", "reward")``,
|
|
1112
|
+
``("next", "done")``, ``("next", "terminated")``,
|
|
1113
|
+
and ``"next"`` tensordict state as returned by the environment)
|
|
1114
|
+
necessary to compute the value estimates and the TD(:math:`\lambda`) estimate.
|
|
1115
|
+
The data passed to this module should be structured as :obj:`[*B, T, *F]` where :obj:`B` are
|
|
1116
|
+
the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s).
|
|
1117
|
+
The tensordict must have shape ``[*B, T]``.
|
|
1118
|
+
|
|
1119
|
+
Keyword Args:
|
|
1120
|
+
params (TensorDictBase, optional): A nested TensorDict containing the params
|
|
1121
|
+
to be passed to the functional value network module.
|
|
1122
|
+
target_params (TensorDictBase, optional): A nested TensorDict containing the
|
|
1123
|
+
target params to be passed to the functional value network module.
|
|
1124
|
+
|
|
1125
|
+
Returns:
|
|
1126
|
+
An updated TensorDict with the advantage and value_target entries as defined in the constructor.
|
|
1127
|
+
|
|
1128
|
+
Examples:
|
|
1129
|
+
>>> from tensordict import TensorDict
|
|
1130
|
+
>>> value_net = TensorDictModule(
|
|
1131
|
+
... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
|
|
1132
|
+
... )
|
|
1133
|
+
>>> module = TDLambdaEstimator(
|
|
1134
|
+
... gamma=0.98,
|
|
1135
|
+
... lmbda=0.94,
|
|
1136
|
+
... value_network=value_net,
|
|
1137
|
+
... )
|
|
1138
|
+
>>> obs, next_obs = torch.randn(2, 1, 10, 3)
|
|
1139
|
+
>>> reward = torch.randn(1, 10, 1)
|
|
1140
|
+
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
|
|
1141
|
+
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
|
|
1142
|
+
>>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "reward": reward, "terminated": terminated}}, [1, 10])
|
|
1143
|
+
>>> _ = module(tensordict)
|
|
1144
|
+
>>> assert "advantage" in tensordict.keys()
|
|
1145
|
+
|
|
1146
|
+
The module supports non-tensordict (i.e. unpacked tensordict) inputs too:
|
|
1147
|
+
|
|
1148
|
+
Examples:
|
|
1149
|
+
>>> value_net = TensorDictModule(
|
|
1150
|
+
... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
|
|
1151
|
+
... )
|
|
1152
|
+
>>> module = TDLambdaEstimator(
|
|
1153
|
+
... gamma=0.98,
|
|
1154
|
+
... lmbda=0.94,
|
|
1155
|
+
... value_network=value_net,
|
|
1156
|
+
... )
|
|
1157
|
+
>>> obs, next_obs = torch.randn(2, 1, 10, 3)
|
|
1158
|
+
>>> reward = torch.randn(1, 10, 1)
|
|
1159
|
+
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
|
|
1160
|
+
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
|
|
1161
|
+
>>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated)
|
|
1162
|
+
|
|
1163
|
+
"""
|
|
1164
|
+
if tensordict.batch_dims < 1:
|
|
1165
|
+
raise RuntimeError(
|
|
1166
|
+
"Expected input tensordict to have at least one dimensions, got"
|
|
1167
|
+
f"tensordict.batch_size = {tensordict.batch_size}"
|
|
1168
|
+
)
|
|
1169
|
+
if self.is_stateless and params is None:
|
|
1170
|
+
raise RuntimeError(
|
|
1171
|
+
"Expected params to be passed to advantage module but got none."
|
|
1172
|
+
)
|
|
1173
|
+
if self.value_network is not None:
|
|
1174
|
+
if params is not None:
|
|
1175
|
+
params = params.detach()
|
|
1176
|
+
if target_params is None:
|
|
1177
|
+
target_params = params.clone(False)
|
|
1178
|
+
with hold_out_net(self.value_network) if (
|
|
1179
|
+
params is None and target_params is None
|
|
1180
|
+
) else nullcontext():
|
|
1181
|
+
# we may still need to pass gradient, but we don't want to assign grads to
|
|
1182
|
+
# value net params
|
|
1183
|
+
value, next_value = self._call_value_nets(
|
|
1184
|
+
data=tensordict,
|
|
1185
|
+
params=params,
|
|
1186
|
+
next_params=target_params,
|
|
1187
|
+
single_call=self.shifted,
|
|
1188
|
+
value_key=self.tensor_keys.value,
|
|
1189
|
+
detach_next=True,
|
|
1190
|
+
vmap_randomness=self.vmap_randomness,
|
|
1191
|
+
)
|
|
1192
|
+
else:
|
|
1193
|
+
value = tensordict.get(self.tensor_keys.value)
|
|
1194
|
+
next_value = tensordict.get(("next", self.tensor_keys.value))
|
|
1195
|
+
value_target = self.value_estimate(tensordict, next_value=next_value)
|
|
1196
|
+
|
|
1197
|
+
tensordict.set(self.tensor_keys.advantage, value_target - value)
|
|
1198
|
+
tensordict.set(self.tensor_keys.value_target, value_target)
|
|
1199
|
+
return tensordict
|
|
1200
|
+
|
|
1201
|
+
def value_estimate(
|
|
1202
|
+
self,
|
|
1203
|
+
tensordict,
|
|
1204
|
+
target_params: TensorDictBase | None = None,
|
|
1205
|
+
next_value: torch.Tensor | None = None,
|
|
1206
|
+
time_dim: int | None = None,
|
|
1207
|
+
**kwargs,
|
|
1208
|
+
):
|
|
1209
|
+
reward = tensordict.get(("next", self.tensor_keys.reward))
|
|
1210
|
+
device = reward.device
|
|
1211
|
+
|
|
1212
|
+
if self.gamma.device != device:
|
|
1213
|
+
self.gamma = self.gamma.to(device)
|
|
1214
|
+
gamma = self.gamma
|
|
1215
|
+
steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
|
|
1216
|
+
if steps_to_next_obs is not None:
|
|
1217
|
+
gamma = gamma ** steps_to_next_obs.view_as(reward)
|
|
1218
|
+
|
|
1219
|
+
if self.lmbda.device != device:
|
|
1220
|
+
self.lmbda = self.lmbda.to(device)
|
|
1221
|
+
lmbda = self.lmbda
|
|
1222
|
+
if self.average_rewards:
|
|
1223
|
+
reward = reward - reward.mean()
|
|
1224
|
+
reward = reward / reward.std().clamp_min(1e-4)
|
|
1225
|
+
tensordict.set(
|
|
1226
|
+
("next", self.tensor_keys.steps_to_next_obs), reward
|
|
1227
|
+
) # we must update the rewards if they are used later in the code
|
|
1228
|
+
|
|
1229
|
+
if next_value is None:
|
|
1230
|
+
next_value = self._next_value(tensordict, target_params, kwargs=kwargs)
|
|
1231
|
+
|
|
1232
|
+
done = tensordict.get(("next", self.tensor_keys.done))
|
|
1233
|
+
terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done)
|
|
1234
|
+
time_dim = self._get_time_dim(time_dim, tensordict)
|
|
1235
|
+
if self.vectorized:
|
|
1236
|
+
val = vec_td_lambda_return_estimate(
|
|
1237
|
+
gamma,
|
|
1238
|
+
lmbda,
|
|
1239
|
+
next_value,
|
|
1240
|
+
reward,
|
|
1241
|
+
done=done,
|
|
1242
|
+
terminated=terminated,
|
|
1243
|
+
time_dim=time_dim,
|
|
1244
|
+
)
|
|
1245
|
+
else:
|
|
1246
|
+
val = td_lambda_return_estimate(
|
|
1247
|
+
gamma,
|
|
1248
|
+
lmbda,
|
|
1249
|
+
next_value,
|
|
1250
|
+
reward,
|
|
1251
|
+
done=done,
|
|
1252
|
+
terminated=terminated,
|
|
1253
|
+
time_dim=time_dim,
|
|
1254
|
+
)
|
|
1255
|
+
return val
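# A minimal, non-vectorized sketch of the TD(lambda) recursion behind
# ``(vec_)td_lambda_return_estimate``: the lambda-return interpolates between
# the one-step bootstrap and the longer rollout tail. Same shape assumptions
# as the other sketches in this file ([T, 1] tensors, scalar gamma/lmbda);
# an illustration, not the functional called above.
def _td_lambda_returns_sketch(gamma, lmbda, next_value, reward, done, terminated):
    # G_t = r_t + gamma * ((1 - lmbda) * V(s_{t+1}) + lmbda * G_{t+1}),
    # falling back to (1 - terminated_t) * V(s_{t+1}) where the trajectory stops.
    import torch  # local import keeps the sketch self-contained

    T = reward.shape[0]
    returns = torch.empty_like(next_value)
    for t in reversed(range(T)):
        bootstrap = next_value[t] * (~terminated[t]).float()
        if t == T - 1 or bool(done[t]):
            nxt = bootstrap
        else:
            nxt = (1 - lmbda) * next_value[t] + lmbda * returns[t + 1]
        returns[t] = reward[t] + gamma * nxt
    return returns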
|
|
1256
|
+
|
|
1257
|
+
|
|
1258
|
+
class GAE(ValueEstimatorBase):
|
|
1259
|
+
"""A class wrapper around the generalized advantage estimate functional.
|
|
1260
|
+
|
|
1261
|
+
Refer to "HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION"
|
|
1262
|
+
https://arxiv.org/pdf/1506.02438.pdf for more context.
|
|
1263
|
+
|
|
1264
|
+
Args:
|
|
1265
|
+
gamma (scalar): exponential mean discount.
|
|
1266
|
+
lmbda (scalar): trajectory discount.
|
|
1267
|
+
value_network (TensorDictModule, optional): value operator used to retrieve the value estimates.
|
|
1268
|
+
If ``None``, this module will expect the ``"state_value"`` keys to be already filled, and
|
|
1269
|
+
will not call the value network to produce it.
|
|
1270
|
+
average_gae (bool): if ``True``, the resulting GAE values will be standardized.
|
|
1271
|
+
Default is ``False``.
|
|
1272
|
+
differentiable (bool, optional): if ``True``, gradients are propagated through
|
|
1273
|
+
the computation of the value function. Default is ``False``.
|
|
1274
|
+
|
|
1275
|
+
.. note::
|
|
1276
|
+
The proper way to make the function call non-differentiable is to
|
|
1277
|
+
decorate it in a `torch.no_grad()` context manager/decorator or
|
|
1278
|
+
pass detached parameters for functional modules.
|
|
1279
|
+
|
|
1280
|
+
vectorized (bool, optional): whether to use the vectorized version of the
|
|
1281
|
+
lambda return. Default is `True` if not compiling.
|
|
1282
|
+
skip_existing (bool, optional): if ``True``, the value network will skip
|
|
1283
|
+
modules whose outputs are already present in the tensordict.
|
|
1284
|
+
Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()`
|
|
1285
|
+
is not affected.
|
|
1286
|
+
Defaults to "state_value".
|
|
1287
|
+
advantage_key (str or tuple of str, optional): [Deprecated] the key of
|
|
1288
|
+
the advantage entry. Defaults to ``"advantage"``.
|
|
1289
|
+
value_target_key (str or tuple of str, optional): [Deprecated] the key
|
|
1290
|
+
of the value target entry. Defaults to ``"value_target"``.
|
|
1291
|
+
value_key (str or tuple of str, optional): [Deprecated] the value key to
|
|
1292
|
+
read from the input tensordict. Defaults to ``"state_value"``.
|
|
1293
|
+
shifted (bool, optional): if ``True``, the value and next value are
|
|
1294
|
+
estimated with a single call to the value network. This is faster
|
|
1295
|
+
but is only valid whenever (1) the ``"next"`` value is shifted by
|
|
1296
|
+
only one time step (which is not the case with multi-step value
|
|
1297
|
+
estimation, for instance) and (2) when the parameters used at time
|
|
1298
|
+
``t`` and ``t+1`` are identical (which is not the case when target
|
|
1299
|
+
parameters are to be used). Defaults to ``False``.
|
|
1300
|
+
device (torch.device, optional): the device where the buffers will be instantiated.
|
|
1301
|
+
Defaults to ``torch.get_default_device()``.
|
|
1302
|
+
time_dim (int, optional): the dimension corresponding to the time
|
|
1303
|
+
in the input tensordict. If not provided, defaults to the dimension
|
|
1304
|
+
marked with the ``"time"`` name if any, and to the last dimension
|
|
1305
|
+
otherwise. Can be overridden during a call to
|
|
1306
|
+
:meth:`~.value_estimate`.
|
|
1307
|
+
Negative dimensions are considered with respect to the input
|
|
1308
|
+
tensordict.
|
|
1309
|
+
auto_reset_env (bool, optional): if ``True``, the last ``"next"`` state
|
|
1310
|
+
of the episode isn't valid, so the GAE calculation will use the ``value``
|
|
1311
|
+
instead of ``next_value`` to bootstrap truncated episodes.
|
|
1312
|
+
deactivate_vmap (bool, optional): if ``True``, no vmap call will be used, and
|
|
1313
|
+
vectorized maps will be replaced with simple for loops. Defaults to ``False``.
|
|
1314
|
+
|
|
1315
|
+
GAE will return an :obj:`"advantage"` entry containing the advantage value. It will also
|
|
1316
|
+
return a :obj:`"value_target"` entry with the return value that is to be used
|
|
1317
|
+
to train the value network. Finally, if :obj:`differentiable` is ``True``,
|
|
1318
|
+
an additional and differentiable :obj:`"value_error"` entry will be returned,
|
|
1319
|
+
which simply represents the difference between the return and the value network
|
|
1320
|
+
output (i.e. an additional distance loss should be applied to that signed value).
|
|
1321
|
+
|
|
1322
|
+
.. note::
|
|
1323
|
+
As other advantage functions do, if the ``value_key`` is already present
|
|
1324
|
+
in the input tensordict, the GAE module will ignore the calls to the value
|
|
1325
|
+
network (if any) and use the provided value instead.
|
|
1326
|
+
|
|
1327
|
+
.. note:: GAE can be used with value networks that rely on recurrent neural networks, provided that the
|
|
1328
|
+
init markers (`"is_init"`) and terminated / truncated markers are properly set.
|
|
1329
|
+
If `shifted=True`, the trajectory batch will be flattened and the last step of each trajectory will
|
|
1330
|
+
be placed within the flat tensordict after the last step from the root, such that each trajectory has
|
|
1331
|
+
`T+1` elements. If `shifted=False`, the root and `"next"` trajectories will be stacked and the value
|
|
1332
|
+
network will be called with `vmap` over the stack of trajectories. Because RNNs require a fair amount of
|
|
1333
|
+
control flow, they are currently not compatible with `torch.vmap` and, as such, the `deactivate_vmap` option
|
|
1334
|
+
must be turned on in these cases.
|
|
1335
|
+
Similarly, if `shifted=False`, the `"is_init"` entry of the root tensordict will be copied onto the
|
|
1336
|
+
`"is_init"` of the `"next"` entry, such that trajectories are well separated both for root and `"next"` data.
|
|
1337
|
+
"""
|
|
1338
|
+
|
|
1339
|
+
value_network: TensorDictModule | None
|
|
1340
|
+
|
|
1341
|
+
def __init__(
|
|
1342
|
+
self,
|
|
1343
|
+
*,
|
|
1344
|
+
gamma: float | torch.Tensor,
|
|
1345
|
+
lmbda: float | torch.Tensor,
|
|
1346
|
+
value_network: TensorDictModule | None,
|
|
1347
|
+
average_gae: bool = False,
|
|
1348
|
+
differentiable: bool = False,
|
|
1349
|
+
vectorized: bool | None = None,
|
|
1350
|
+
skip_existing: bool | None = None,
|
|
1351
|
+
advantage_key: NestedKey = None,
|
|
1352
|
+
value_target_key: NestedKey = None,
|
|
1353
|
+
value_key: NestedKey = None,
|
|
1354
|
+
shifted: bool = False,
|
|
1355
|
+
device: torch.device | None = None,
|
|
1356
|
+
time_dim: int | None = None,
|
|
1357
|
+
auto_reset_env: bool = False,
|
|
1358
|
+
deactivate_vmap: bool = False,
|
|
1359
|
+
):
|
|
1360
|
+
super().__init__(
|
|
1361
|
+
shifted=shifted,
|
|
1362
|
+
value_network=value_network,
|
|
1363
|
+
differentiable=differentiable,
|
|
1364
|
+
advantage_key=advantage_key,
|
|
1365
|
+
value_target_key=value_target_key,
|
|
1366
|
+
value_key=value_key,
|
|
1367
|
+
skip_existing=skip_existing,
|
|
1368
|
+
device=device,
|
|
1369
|
+
)
|
|
1370
|
+
self.register_buffer(
|
|
1371
|
+
"gamma",
|
|
1372
|
+
gamma.to(self._device)
|
|
1373
|
+
if isinstance(gamma, Tensor)
|
|
1374
|
+
else torch.tensor(gamma, device=self._device),
|
|
1375
|
+
)
|
|
1376
|
+
self.register_buffer(
|
|
1377
|
+
"lmbda",
|
|
1378
|
+
lmbda.to(self._device)
|
|
1379
|
+
if isinstance(lmbda, Tensor)
|
|
1380
|
+
else torch.tensor(lmbda, device=self._device),
|
|
1381
|
+
)
|
|
1382
|
+
self.average_gae = average_gae
|
|
1383
|
+
self.vectorized = vectorized
|
|
1384
|
+
self.time_dim = time_dim
|
|
1385
|
+
self.auto_reset_env = auto_reset_env
|
|
1386
|
+
self.deactivate_vmap = deactivate_vmap
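# A hedged configuration sketch for the RNN note in the class docstring: when
# the value network carries recurrent state, ``vmap`` cannot be used, so
# ``deactivate_vmap=True`` (usually together with ``shifted=True``) is the
# intended combination. ``make_rnn_value_net`` and ``rollout_data`` are
# placeholders for user-defined objects, not torchrl helpers.
#
#     value_net = make_rnn_value_net()  # TensorDictModule writing "state_value"
#     gae = GAE(
#         gamma=0.99,
#         lmbda=0.95,
#         value_network=value_net,
#         shifted=True,
#         deactivate_vmap=True,
#     )
#     data = gae(rollout_data)  # rollout_data: TensorDict of shape [*B, T]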
|
|
1387
|
+
|
|
1388
|
+
@property
|
|
1389
|
+
def vectorized(self):
|
|
1390
|
+
if is_dynamo_compiling():
|
|
1391
|
+
return False
|
|
1392
|
+
return self._vectorized
|
|
1393
|
+
|
|
1394
|
+
@vectorized.setter
|
|
1395
|
+
def vectorized(self, value):
|
|
1396
|
+
self._vectorized = value
|
|
1397
|
+
|
|
1398
|
+
@_self_set_skip_existing
|
|
1399
|
+
@_self_set_grad_enabled
|
|
1400
|
+
@dispatch
|
|
1401
|
+
def forward(
|
|
1402
|
+
self,
|
|
1403
|
+
tensordict: TensorDictBase,
|
|
1404
|
+
*,
|
|
1405
|
+
params: list[Tensor] | None = None,
|
|
1406
|
+
target_params: list[Tensor] | None = None,
|
|
1407
|
+
time_dim: int | None = None,
|
|
1408
|
+
) -> TensorDictBase:
|
|
1409
|
+
"""Computes the GAE given the data in tensordict.
|
|
1410
|
+
|
|
1411
|
+
If a functional module is provided, a nested TensorDict containing the parameters
|
|
1412
|
+
(and if relevant the target parameters) can be passed to the module.
|
|
1413
|
+
|
|
1414
|
+
Args:
|
|
1415
|
+
tensordict (TensorDictBase): A TensorDict containing the data
|
|
1416
|
+
(an observation key, ``"action"``, ``("next", "reward")``,
|
|
1417
|
+
``("next", "done")``, ``("next", "terminated")``,
|
|
1418
|
+
and ``"next"`` tensordict state as returned by the environment)
|
|
1419
|
+
necessary to compute the value estimates and the GAE.
|
|
1420
|
+
The data passed to this module should be structured as :obj:`[*B, T, *F]` where :obj:`B` are
|
|
1421
|
+
the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s).
|
|
1422
|
+
The tensordict must have shape ``[*B, T]``.
|
|
1423
|
+
|
|
1424
|
+
Keyword Args:
|
|
1425
|
+
params (TensorDictBase, optional): A nested TensorDict containing the params
|
|
1426
|
+
to be passed to the functional value network module.
|
|
1427
|
+
target_params (TensorDictBase, optional): A nested TensorDict containing the
|
|
1428
|
+
target params to be passed to the functional value network module.
|
|
1429
|
+
time_dim (int, optional): the dimension corresponding to the time
|
|
1430
|
+
in the input tensordict. If not provided, defaults to the dimension
|
|
1431
|
+
marked with the ``"time"`` name if any, and to the last dimension
|
|
1432
|
+
otherwise.
|
|
1433
|
+
Negative dimensions are considered with respect to the input
|
|
1434
|
+
tensordict.
|
|
1435
|
+
|
|
1436
|
+
Returns:
|
|
1437
|
+
An updated TensorDict with the advantage and value_target entries as defined in the constructor.
|
|
1438
|
+
|
|
1439
|
+
Examples:
|
|
1440
|
+
>>> from tensordict import TensorDict
|
|
1441
|
+
>>> value_net = TensorDictModule(
|
|
1442
|
+
... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
|
|
1443
|
+
... )
|
|
1444
|
+
>>> module = GAE(
|
|
1445
|
+
... gamma=0.98,
|
|
1446
|
+
... lmbda=0.94,
|
|
1447
|
+
... value_network=value_net,
|
|
1448
|
+
... differentiable=False,
|
|
1449
|
+
... )
|
|
1450
|
+
>>> obs, next_obs = torch.randn(2, 1, 10, 3)
|
|
1451
|
+
>>> reward = torch.randn(1, 10, 1)
|
|
1452
|
+
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
|
|
1453
|
+
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
|
|
1454
|
+
>>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs}, "done": done, "reward": reward, "terminated": terminated}, [1, 10])
|
|
1455
|
+
>>> _ = module(tensordict)
|
|
1456
|
+
>>> assert "advantage" in tensordict.keys()
|
|
1457
|
+
|
|
1458
|
+
The module supports non-tensordict (i.e. unpacked tensordict) inputs too:
|
|
1459
|
+
|
|
1460
|
+
Examples:
|
|
1461
|
+
>>> value_net = TensorDictModule(
|
|
1462
|
+
... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
|
|
1463
|
+
... )
|
|
1464
|
+
>>> module = GAE(
|
|
1465
|
+
... gamma=0.98,
|
|
1466
|
+
... lmbda=0.94,
|
|
1467
|
+
... value_network=value_net,
|
|
1468
|
+
... differentiable=False,
|
|
1469
|
+
... )
|
|
1470
|
+
>>> obs, next_obs = torch.randn(2, 1, 10, 3)
|
|
1471
|
+
>>> reward = torch.randn(1, 10, 1)
|
|
1472
|
+
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
|
|
1473
|
+
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
|
|
1474
|
+
>>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated)
|
|
1475
|
+
|
|
1476
|
+
"""
|
|
1477
|
+
if tensordict.batch_dims < 1:
|
|
1478
|
+
raise RuntimeError(
|
|
1479
|
+
"Expected input tensordict to have at least one dimension, got "
|
|
1480
|
+
f"tensordict.batch_size = {tensordict.batch_size}"
|
|
1481
|
+
)
|
|
1482
|
+
reward = tensordict.get(("next", self.tensor_keys.reward))
|
|
1483
|
+
device = reward.device
|
|
1484
|
+
if self.gamma.device != device:
|
|
1485
|
+
self.gamma = self.gamma.to(device)
|
|
1486
|
+
gamma = self.gamma
|
|
1487
|
+
if self.lmbda.device != device:
|
|
1488
|
+
self.lmbda = self.lmbda.to(device)
|
|
1489
|
+
lmbda = self.lmbda
|
|
1490
|
+
steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
|
|
1491
|
+
if steps_to_next_obs is not None:
|
|
1492
|
+
gamma = gamma ** steps_to_next_obs.view_as(reward)
|
|
1493
|
+
|
|
1494
|
+
if self.value_network is not None:
|
|
1495
|
+
if params is not None:
|
|
1496
|
+
params = params.detach()
|
|
1497
|
+
if target_params is None:
|
|
1498
|
+
target_params = params.clone(False)
|
|
1499
|
+
with hold_out_net(self.value_network) if (
|
|
1500
|
+
params is None and target_params is None
|
|
1501
|
+
) else nullcontext():
|
|
1502
|
+
# with torch.no_grad():
|
|
1503
|
+
# we may still need to pass gradient, but we don't want to assign grads to
|
|
1504
|
+
# value net params
|
|
1505
|
+
value, next_value = self._call_value_nets(
|
|
1506
|
+
data=tensordict,
|
|
1507
|
+
params=params,
|
|
1508
|
+
next_params=target_params,
|
|
1509
|
+
single_call=self.shifted,
|
|
1510
|
+
value_key=self.tensor_keys.value,
|
|
1511
|
+
detach_next=True,
|
|
1512
|
+
vmap_randomness=self.vmap_randomness,
|
|
1513
|
+
)
|
|
1514
|
+
else:
|
|
1515
|
+
value = tensordict.get(self.tensor_keys.value)
|
|
1516
|
+
next_value = tensordict.get(("next", self.tensor_keys.value))
|
|
1517
|
+
|
|
1518
|
+
if value is None:
|
|
1519
|
+
raise ValueError(
|
|
1520
|
+
f"The tensor with key {self.tensor_keys.value} is missing, and no value network was provided."
|
|
1521
|
+
)
|
|
1522
|
+
if next_value is None:
|
|
1523
|
+
raise ValueError(
|
|
1524
|
+
f"The tensor with key {('next', self.tensor_keys.value)} is missing, and no value network was provided."
|
|
1525
|
+
)
|
|
1526
|
+
|
|
1527
|
+
done = tensordict.get(("next", self.tensor_keys.done))
|
|
1528
|
+
terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done)
|
|
1529
|
+
time_dim = self._get_time_dim(time_dim, tensordict)
|
|
1530
|
+
|
|
1531
|
+
if self.auto_reset_env:
|
|
1532
|
+
truncated = tensordict.get(("next", "truncated"))
|
|
1533
|
+
if truncated.any():
|
|
1534
|
+
reward += gamma * value * truncated
|
|
1535
|
+
|
|
1536
|
+
if self.vectorized:
|
|
1537
|
+
adv, value_target = vec_generalized_advantage_estimate(
|
|
1538
|
+
gamma,
|
|
1539
|
+
lmbda,
|
|
1540
|
+
value,
|
|
1541
|
+
next_value,
|
|
1542
|
+
reward,
|
|
1543
|
+
done=done,
|
|
1544
|
+
terminated=terminated if not self.auto_reset_env else done,
|
|
1545
|
+
time_dim=time_dim,
|
|
1546
|
+
)
|
|
1547
|
+
else:
|
|
1548
|
+
adv, value_target = generalized_advantage_estimate(
|
|
1549
|
+
gamma,
|
|
1550
|
+
lmbda,
|
|
1551
|
+
value,
|
|
1552
|
+
next_value,
|
|
1553
|
+
reward,
|
|
1554
|
+
done=done,
|
|
1555
|
+
terminated=terminated if not self.auto_reset_env else done,
|
|
1556
|
+
time_dim=time_dim,
|
|
1557
|
+
)
|
|
1558
|
+
|
|
1559
|
+
if self.average_gae:
|
|
1560
|
+
loc = adv.mean()
|
|
1561
|
+
scale = adv.std().clamp_min(1e-4)
|
|
1562
|
+
adv = adv - loc
|
|
1563
|
+
adv = adv / scale
|
|
1564
|
+
|
|
1565
|
+
tensordict.set(self.tensor_keys.advantage, adv)
|
|
1566
|
+
tensordict.set(self.tensor_keys.value_target, value_target)
|
|
1567
|
+
|
|
1568
|
+
return tensordict
|
|
1569
|
+
|
|
1570
|
+
def value_estimate(
|
|
1571
|
+
self,
|
|
1572
|
+
tensordict,
|
|
1573
|
+
params: TensorDictBase | None = None,
|
|
1574
|
+
target_params: TensorDictBase | None = None,
|
|
1575
|
+
time_dim: int | None = None,
|
|
1576
|
+
**kwargs,
|
|
1577
|
+
):
|
|
1578
|
+
if tensordict.batch_dims < 1:
|
|
1579
|
+
raise RuntimeError(
|
|
1580
|
+
"Expected input tensordict to have at least one dimensions, got"
|
|
1581
|
+
f"tensordict.batch_size = {tensordict.batch_size}"
|
|
1582
|
+
)
|
|
1583
|
+
reward = tensordict.get(("next", self.tensor_keys.reward))
|
|
1584
|
+
device = reward.device
|
|
1585
|
+
if self.gamma.device != device:
|
|
1586
|
+
self.gamma = self.gamma.to(device)
|
|
1587
|
+
gamma = self.gamma
|
|
1588
|
+
if self.lmbda.device != device:
|
|
1589
|
+
self.lmbda = self.lmbda.to(device)
|
|
1590
|
+
lmbda = self.lmbda
|
|
1591
|
+
steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
|
|
1592
|
+
if steps_to_next_obs is not None:
|
|
1593
|
+
gamma = gamma ** steps_to_next_obs.view_as(reward)
|
|
1594
|
+
|
|
1595
|
+
time_dim = self._get_time_dim(time_dim, tensordict)
|
|
1596
|
+
|
|
1597
|
+
if self.is_stateless and params is None:
|
|
1598
|
+
raise RuntimeError(
|
|
1599
|
+
"Expected params to be passed to advantage module but got none."
|
|
1600
|
+
)
|
|
1601
|
+
if self.value_network is not None:
|
|
1602
|
+
if params is not None:
|
|
1603
|
+
params = params.detach()
|
|
1604
|
+
if target_params is None:
|
|
1605
|
+
target_params = params.clone(False)
|
|
1606
|
+
with hold_out_net(self.value_network) if (
|
|
1607
|
+
params is None and target_params is None
|
|
1608
|
+
) else nullcontext():
|
|
1609
|
+
# we may still need to pass gradient, but we don't want to assign grads to
|
|
1610
|
+
# value net params
|
|
1611
|
+
value, next_value = self._call_value_nets(
|
|
1612
|
+
data=tensordict,
|
|
1613
|
+
params=params,
|
|
1614
|
+
next_params=target_params,
|
|
1615
|
+
single_call=self.shifted,
|
|
1616
|
+
value_key=self.tensor_keys.value,
|
|
1617
|
+
detach_next=True,
|
|
1618
|
+
vmap_randomness=self.vmap_randomness,
|
|
1619
|
+
)
|
|
1620
|
+
else:
|
|
1621
|
+
value = tensordict.get(self.tensor_keys.value)
|
|
1622
|
+
next_value = tensordict.get(("next", self.tensor_keys.value))
|
|
1623
|
+
done = tensordict.get(("next", self.tensor_keys.done))
|
|
1624
|
+
terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done)
|
|
1625
|
+
_, value_target = vec_generalized_advantage_estimate(
|
|
1626
|
+
gamma,
|
|
1627
|
+
lmbda,
|
|
1628
|
+
value,
|
|
1629
|
+
next_value,
|
|
1630
|
+
reward,
|
|
1631
|
+
done=done,
|
|
1632
|
+
terminated=terminated,
|
|
1633
|
+
time_dim=time_dim,
|
|
1634
|
+
)
|
|
1635
|
+
return value_target
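# A minimal, non-vectorized sketch of the GAE recursion evaluated by
# ``(vec_)generalized_advantage_estimate``: the advantage is a discounted sum
# of TD errors and the value target is the advantage plus the value. Same
# shape assumptions as the other sketches ([T, 1] tensors, scalar
# gamma/lmbda); an illustration, not the functional used above.
def _gae_sketch(gamma, lmbda, value, next_value, reward, done, terminated):
    # delta_t = r_t + gamma * (1 - terminated_t) * V(s_{t+1}) - V(s_t)
    # A_t     = delta_t + gamma * lmbda * (1 - done_t) * A_{t+1}
    import torch  # local import keeps the sketch self-contained

    T = reward.shape[0]
    advantage = torch.empty_like(value)
    carry = torch.zeros_like(value[-1])
    for t in reversed(range(T)):
        delta = reward[t] + gamma * next_value[t] * (~terminated[t]).float() - value[t]
        carry = delta + gamma * lmbda * carry * (~done[t]).float()
        advantage[t] = carry
    return advantage, advantage + value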
|
|
1636
|
+
|
|
1637
|
+
|
|
1638
|
+
class VTrace(ValueEstimatorBase):
|
|
1639
|
+
"""A class wrapper around V-Trace estimate functional.
|
|
1640
|
+
|
|
1641
|
+
Refer to "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures"
|
|
1642
|
+
`here <https://arxiv.org/abs/1802.01561>`_ for more context.
|
|
1643
|
+
|
|
1644
|
+
Keyword Args:
|
|
1645
|
+
gamma (scalar): exponential mean discount.
|
|
1646
|
+
value_network (TensorDictModule): value operator used to retrieve the value estimates.
|
|
1647
|
+
actor_network (TensorDictModule): actor operator used to retrieve the log prob.
|
|
1648
|
+
rho_thresh (Union[float, Tensor]): rho clipping parameter for importance weights.
|
|
1649
|
+
Defaults to ``1.0``.
|
|
1650
|
+
c_thresh (Union[float, Tensor]): c clipping parameter for importance weights.
|
|
1651
|
+
Defaults to ``1.0``.
|
|
1652
|
+
average_adv (bool): if ``True``, the resulting advantage values will be standardized.
|
|
1653
|
+
Default is ``False``.
|
|
1654
|
+
differentiable (bool, optional): if ``True``, gradients are propagated through
|
|
1655
|
+
the computation of the value function. Default is ``False``.
|
|
1656
|
+
|
|
1657
|
+
.. note::
|
|
1658
|
+
The proper way to make the function call non-differentiable is to
|
|
1659
|
+
decorate it in a `torch.no_grad()` context manager/decorator or
|
|
1660
|
+
pass detached parameters for functional modules.
|
|
1661
|
+
skip_existing (bool, optional): if ``True``, the value network will skip
|
|
1662
|
+
modules whose outputs are already present in the tensordict.
|
|
1663
|
+
Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()`
|
|
1664
|
+
is not affected.
|
|
1665
|
+
Defaults to "state_value".
|
|
1666
|
+
advantage_key (str or tuple of str, optional): [Deprecated] the key of
|
|
1667
|
+
the advantage entry. Defaults to ``"advantage"``.
|
|
1668
|
+
value_target_key (str or tuple of str, optional): [Deprecated] the key
|
|
1669
|
+
of the value target entry. Defaults to ``"value_target"``.
|
|
1670
|
+
value_key (str or tuple of str, optional): [Deprecated] the value key to
|
|
1671
|
+
read from the input tensordict. Defaults to ``"state_value"``.
|
|
1672
|
+
shifted (bool, optional): if ``True``, the value and next value are
|
|
1673
|
+
estimated with a single call to the value network. This is faster
|
|
1674
|
+
but is only valid whenever (1) the ``"next"`` value is shifted by
|
|
1675
|
+
only one time step (which is not the case with multi-step value
|
|
1676
|
+
estimation, for instance) and (2) when the parameters used at time
|
|
1677
|
+
``t`` and ``t+1`` are identical (which is not the case when target
|
|
1678
|
+
parameters are to be used). Defaults to ``False``.
|
|
1679
|
+
device (torch.device, optional): the device where the buffers will be instantiated.
|
|
1680
|
+
Defaults to ``torch.get_default_device()``.
|
|
1681
|
+
time_dim (int, optional): the dimension corresponding to the time
|
|
1682
|
+
in the input tensordict. If not provided, defaults to the dimension
|
|
1683
|
+
marked with the ``"time"`` name if any, and to the last dimension
|
|
1684
|
+
otherwise. Can be overridden during a call to
|
|
1685
|
+
:meth:`~.value_estimate`.
|
|
1686
|
+
Negative dimensions are considered with respect to the input
|
|
1687
|
+
tensordict.
|
|
1688
|
+
|
|
1689
|
+
VTrace will return an :obj:`"advantage"` entry containing the advantage value. It will also
|
|
1690
|
+
return a :obj:`"value_target"` entry with the V-Trace target value.
|
|
1691
|
+
|
|
1692
|
+
.. note::
|
|
1693
|
+
As other advantage functions do, if the ``value_key`` is already present
|
|
1694
|
+
in the input tensordict, the VTrace module will ignore the calls to the value
|
|
1695
|
+
network (if any) and use the provided value instead.
|
|
1696
|
+
|
|
1697
|
+
"""
|
|
1698
|
+
|
|
1699
|
+
def __init__(
|
|
1700
|
+
self,
|
|
1701
|
+
*,
|
|
1702
|
+
gamma: float | torch.Tensor,
|
|
1703
|
+
actor_network: TensorDictModule,
|
|
1704
|
+
value_network: TensorDictModule,
|
|
1705
|
+
rho_thresh: float | torch.Tensor = 1.0,
|
|
1706
|
+
c_thresh: float | torch.Tensor = 1.0,
|
|
1707
|
+
average_adv: bool = False,
|
|
1708
|
+
differentiable: bool = False,
|
|
1709
|
+
skip_existing: bool | None = None,
|
|
1710
|
+
advantage_key: NestedKey | None = None,
|
|
1711
|
+
value_target_key: NestedKey | None = None,
|
|
1712
|
+
value_key: NestedKey | None = None,
|
|
1713
|
+
shifted: bool = False,
|
|
1714
|
+
device: torch.device | None = None,
|
|
1715
|
+
time_dim: int | None = None,
|
|
1716
|
+
):
|
|
1717
|
+
super().__init__(
|
|
1718
|
+
shifted=shifted,
|
|
1719
|
+
value_network=value_network,
|
|
1720
|
+
differentiable=differentiable,
|
|
1721
|
+
advantage_key=advantage_key,
|
|
1722
|
+
value_target_key=value_target_key,
|
|
1723
|
+
value_key=value_key,
|
|
1724
|
+
skip_existing=skip_existing,
|
|
1725
|
+
device=device,
|
|
1726
|
+
)
|
|
1727
|
+
if not isinstance(gamma, torch.Tensor):
|
|
1728
|
+
gamma = torch.tensor(gamma, device=self._device)
|
|
1729
|
+
if not isinstance(rho_thresh, torch.Tensor):
|
|
1730
|
+
rho_thresh = torch.tensor(rho_thresh, device=self._device)
|
|
1731
|
+
if not isinstance(c_thresh, torch.Tensor):
|
|
1732
|
+
c_thresh = torch.tensor(c_thresh, device=self._device)
|
|
1733
|
+
|
|
1734
|
+
self.register_buffer("gamma", gamma)
|
|
1735
|
+
self.register_buffer("rho_thresh", rho_thresh)
|
|
1736
|
+
self.register_buffer("c_thresh", c_thresh)
|
|
1737
|
+
self.average_adv = average_adv
|
|
1738
|
+
self.actor_network = actor_network
|
|
1739
|
+
self.time_dim = time_dim
|
|
1740
|
+
|
|
1741
|
+
if isinstance(gamma, torch.Tensor) and gamma.shape != ():
|
|
1742
|
+
raise NotImplementedError(
|
|
1743
|
+
"Per-value gamma is not supported yet. Gamma must be a scalar."
|
|
1744
|
+
)
|
|
1745
|
+
|
|
1746
|
+
@property
|
|
1747
|
+
def in_keys(self):
|
|
1748
|
+
parent_in_keys = super().in_keys
|
|
1749
|
+
extended_in_keys = parent_in_keys + [self.tensor_keys.sample_log_prob]
|
|
1750
|
+
return extended_in_keys
|
|
1751
|
+
|
|
1752
|
+
@_self_set_skip_existing
|
|
1753
|
+
@_self_set_grad_enabled
|
|
1754
|
+
@dispatch
|
|
1755
|
+
def forward(
|
|
1756
|
+
self,
|
|
1757
|
+
tensordict: TensorDictBase,
|
|
1758
|
+
*,
|
|
1759
|
+
params: list[Tensor] | None = None,
|
|
1760
|
+
target_params: list[Tensor] | None = None,
|
|
1761
|
+
time_dim: int | None = None,
|
|
1762
|
+
) -> TensorDictBase:
|
|
1763
|
+
"""Computes the V-Trace correction given the data in tensordict.
|
|
1764
|
+
|
|
1765
|
+
If a functional module is provided, a nested TensorDict containing the parameters
|
|
1766
|
+
(and if relevant the target parameters) can be passed to the module.
|
|
1767
|
+
|
|
1768
|
+
Args:
|
|
1769
|
+
tensordict (TensorDictBase): A TensorDict containing the data
|
|
1770
|
+
(an observation key, "action", "reward", "done" and "next" tensordict state
|
|
1771
|
+
as returned by the environment) necessary to compute the value estimates and the V-Trace correction.
|
|
1772
|
+
The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` are
|
|
1773
|
+
the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s).
|
|
1774
|
+
|
|
1775
|
+
Keyword Args:
|
|
1776
|
+
params (TensorDictBase, optional): A nested TensorDict containing the params
|
|
1777
|
+
to be passed to the functional value network module.
|
|
1778
|
+
target_params (TensorDictBase, optional): A nested TensorDict containing the
|
|
1779
|
+
target params to be passed to the functional value network module.
|
|
1780
|
+
time_dim (int, optional): the dimension corresponding to the time
|
|
1781
|
+
in the input tensordict. If not provided, defaults to the dimension
|
|
1782
|
+
marked with the ``"time"`` name if any, and to the last dimension
|
|
1783
|
+
otherwise.
|
|
1784
|
+
Negative dimensions are considered with respect to the input
|
|
1785
|
+
tensordict.
|
|
1786
|
+
|
|
1787
|
+
Returns:
|
|
1788
|
+
An updated TensorDict with the advantage and value_target entries as defined in the constructor.
|
|
1789
|
+
|
|
1790
|
+
Examples:
|
|
1791
|
+
>>> value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"])
|
|
1792
|
+
>>> actor_net = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"])
|
|
1793
|
+
>>> actor_net = ProbabilisticActor(
|
|
1794
|
+
... module=actor_net,
|
|
1795
|
+
... in_keys=["logits"],
|
|
1796
|
+
... out_keys=["action"],
|
|
1797
|
+
... distribution_class=OneHotCategorical,
|
|
1798
|
+
... return_log_prob=True,
|
|
1799
|
+
... )
|
|
1800
|
+
>>> module = VTrace(
|
|
1801
|
+
... gamma=0.98,
|
|
1802
|
+
... value_network=value_net,
|
|
1803
|
+
... actor_network=actor_net,
|
|
1804
|
+
... differentiable=False,
|
|
1805
|
+
... )
|
|
1806
|
+
>>> obs, next_obs = torch.randn(2, 1, 10, 3)
|
|
1807
|
+
>>> reward = torch.randn(1, 10, 1)
|
|
1808
|
+
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
|
|
1809
|
+
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
|
|
1810
|
+
>>> sample_log_prob = torch.randn(1, 10, 1)
|
|
1811
|
+
>>> tensordict = TensorDict({
|
|
1812
|
+
... "obs": obs,
|
|
1813
|
+
... "done": done,
|
|
1814
|
+
... "terminated": terminated,
|
|
1815
|
+
... "sample_log_prob": sample_log_prob,
|
|
1816
|
+
... "next": {"obs": next_obs, "reward": reward, "done": done, "terminated": terminated},
|
|
1817
|
+
... }, batch_size=[1, 10])
|
|
1818
|
+
>>> _ = module(tensordict)
|
|
1819
|
+
>>> assert "advantage" in tensordict.keys()
|
|
1820
|
+
|
|
1821
|
+
The module supports non-tensordict (i.e. unpacked tensordict) inputs too:
|
|
1822
|
+
|
|
1823
|
+
Examples:
|
|
1824
|
+
>>> value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"])
|
|
1825
|
+
>>> actor_net = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"])
|
|
1826
|
+
>>> actor_net = ProbabilisticActor(
|
|
1827
|
+
... module=actor_net,
|
|
1828
|
+
... in_keys=["logits"],
|
|
1829
|
+
... out_keys=["action"],
|
|
1830
|
+
... distribution_class=OneHotCategorical,
|
|
1831
|
+
... return_log_prob=True,
|
|
1832
|
+
... )
|
|
1833
|
+
>>> module = VTrace(
|
|
1834
|
+
... gamma=0.98,
|
|
1835
|
+
... value_network=value_net,
|
|
1836
|
+
... actor_network=actor_net,
|
|
1837
|
+
... differentiable=False,
|
|
1838
|
+
... )
|
|
1839
|
+
>>> obs, next_obs = torch.randn(2, 1, 10, 3)
|
|
1840
|
+
>>> reward = torch.randn(1, 10, 1)
|
|
1841
|
+
>>> done = torch.zeros(1, 10, 1, dtype=torch.bool)
|
|
1842
|
+
>>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
|
|
1843
|
+
>>> sample_log_prob = torch.randn(1, 10, 1)
|
|
1844
|
+
>>> tensordict = TensorDict({
|
|
1845
|
+
... "obs": obs,
|
|
1846
|
+
... "done": done,
|
|
1847
|
+
... "terminated": terminated,
|
|
1848
|
+
... "sample_log_prob": sample_log_prob,
|
|
1849
|
+
... "next": {"obs": next_obs, "reward": reward, "done": done, "terminated": terminated},
|
|
1850
|
+
... }, batch_size=[1, 10])
|
|
1851
|
+
>>> advantage, value_target = module(
|
|
1852
|
+
... obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated, sample_log_prob=sample_log_prob
|
|
1853
|
+
... )
|
|
1854
|
+
|
|
1855
|
+
"""
|
|
1856
|
+
if tensordict.batch_dims < 1:
|
|
1857
|
+
raise RuntimeError(
|
|
1858
|
+
"Expected input tensordict to have at least one dimensions, got "
|
|
1859
|
+
f"tensordict.batch_size = {tensordict.batch_size}"
|
|
1860
|
+
)
|
|
1861
|
+
reward = tensordict.get(("next", self.tensor_keys.reward))
|
|
1862
|
+
device = reward.device
|
|
1863
|
+
|
|
1864
|
+
if self.gamma.device != device:
|
|
1865
|
+
self.gamma = self.gamma.to(device)
|
|
1866
|
+
gamma = self.gamma
|
|
1867
|
+
steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
|
|
1868
|
+
if steps_to_next_obs is not None:
|
|
1869
|
+
gamma = gamma ** steps_to_next_obs.view_as(reward)
|
|
1870
|
+
|
|
1871
|
+
# Make sure we have the value and next value
|
|
1872
|
+
if self.value_network is not None:
|
|
1873
|
+
if params is not None:
|
|
1874
|
+
params = params.detach()
|
|
1875
|
+
if target_params is None:
|
|
1876
|
+
target_params = params.clone(False)
|
|
1877
|
+
with hold_out_net(self.value_network):
|
|
1878
|
+
# we may still need to pass gradient, but we don't want to assign grads to
|
|
1879
|
+
# value net params
|
|
1880
|
+
value, next_value = self._call_value_nets(
|
|
1881
|
+
data=tensordict,
|
|
1882
|
+
params=params,
|
|
1883
|
+
next_params=target_params,
|
|
1884
|
+
single_call=self.shifted,
|
|
1885
|
+
value_key=self.tensor_keys.value,
|
|
1886
|
+
detach_next=True,
|
|
1887
|
+
vmap_randomness=self.vmap_randomness,
|
|
1888
|
+
)
|
|
1889
|
+
else:
|
|
1890
|
+
value = tensordict.get(self.tensor_keys.value)
|
|
1891
|
+
next_value = tensordict.get(("next", self.tensor_keys.value))
|
|
1892
|
+
|
|
1893
|
+
lp = _maybe_get_or_select(tensordict, self.tensor_keys.sample_log_prob)
|
|
1894
|
+
if is_tensor_collection(lp):
|
|
1895
|
+
# Sum all values to match the batch size
|
|
1896
|
+
lp = lp.sum(dim="feature", reduce=True)
|
|
1897
|
+
log_mu = lp.view_as(value)
|
|
1898
|
+
|
|
1899
|
+
# Compute log prob with current policy
|
|
1900
|
+
with hold_out_net(self.actor_network):
|
|
1901
|
+
log_pi = _call_actor_net(
|
|
1902
|
+
actor_net=self.actor_network,
|
|
1903
|
+
data=tensordict,
|
|
1904
|
+
params=None,
|
|
1905
|
+
log_prob_key=self.tensor_keys.sample_log_prob,
|
|
1906
|
+
)
|
|
1907
|
+
log_pi = log_pi.view_as(value)
|
|
1908
|
+
|
|
1909
|
+
# Compute the V-Trace correction
|
|
1910
|
+
done = tensordict.get(("next", self.tensor_keys.done))
|
|
1911
|
+
terminated = tensordict.get(("next", self.tensor_keys.terminated))
|
|
1912
|
+
|
|
1913
|
+
time_dim = self._get_time_dim(time_dim, tensordict)
|
|
1914
|
+
adv, value_target = vtrace_advantage_estimate(
|
|
1915
|
+
gamma,
|
|
1916
|
+
log_pi,
|
|
1917
|
+
log_mu,
|
|
1918
|
+
value,
|
|
1919
|
+
next_value,
|
|
1920
|
+
reward,
|
|
1921
|
+
done,
|
|
1922
|
+
terminated,
|
|
1923
|
+
rho_thresh=self.rho_thresh,
|
|
1924
|
+
c_thresh=self.c_thresh,
|
|
1925
|
+
time_dim=time_dim,
|
|
1926
|
+
)
|
|
1927
|
+
|
|
1928
|
+
if self.average_adv:
|
|
1929
|
+
loc = adv.mean()
|
|
1930
|
+
scale = adv.std().clamp_min(1e-5)
|
|
1931
|
+
adv = adv - loc
|
|
1932
|
+
adv = adv / scale
|
|
1933
|
+
|
|
1934
|
+
tensordict.set(self.tensor_keys.advantage, adv)
|
|
1935
|
+
tensordict.set(self.tensor_keys.value_target, value_target)
|
|
1936
|
+
|
|
1937
|
+
return tensordict
|
|
1938
|
+
|
|
1939
|
+
|
|
1940
|
+
def _deprecate_class(cls, new_cls):
|
|
1941
|
+
old_init = cls.__init__
@wraps(old_init)
|
|
1942
|
+
def new_init(self, *args, **kwargs):
|
|
1943
|
+
warnings.warn(f"class {cls} is deprecated, please use {new_cls} instead.")
|
|
1944
|
+
old_init(self, *args, **kwargs)
|
|
1945
|
+
|
|
1946
|
+
cls.__init__ = new_init
|
|
1947
|
+
|
|
1948
|
+
|
|
1949
|
+
TD0Estimate = type("TD0Estimate", TD0Estimator.__bases__, dict(TD0Estimator.__dict__))
|
|
1950
|
+
_deprecate_class(TD0Estimate, TD0Estimator)
|
|
1951
|
+
TD1Estimate = type("TD1Estimate", TD1Estimator.__bases__, dict(TD1Estimator.__dict__))
|
|
1952
|
+
_deprecate_class(TD1Estimate, TD1Estimator)
|
|
1953
|
+
TDLambdaEstimate = type(
|
|
1954
|
+
"TDLambdaEstimate", TDLambdaEstimator.__bases__, dict(TDLambdaEstimator.__dict__)
|
|
1955
|
+
)
|
|
1956
|
+
_deprecate_class(TDLambdaEstimate, TDLambdaEstimator)
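# A hedged illustration of what the shims above are meant to provide: the
# legacy names construct the new estimators while emitting a deprecation
# warning. ``value_net`` is a placeholder TensorDictModule as in the examples
# earlier in this file.
#
#     import warnings
#     with warnings.catch_warnings(record=True) as caught:
#         warnings.simplefilter("always")
#         estimator = TD1Estimate(gamma=0.99, value_network=value_net)
#     assert any("deprecated" in str(w.message) for w in caught)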
|