torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff shows the contents of publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0

torchrl/objectives/utils.py (new file)

@@ -0,0 +1,782 @@

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import functools
import re
import warnings
from collections.abc import Callable, Iterable
from copy import copy
from enum import Enum
from typing import Any, TypeVar

import torch
from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key
from tensordict.nn import TensorDictModule
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.modules import dropout
from torch.utils._pytree import tree_map

try:
    from torch import vmap
except ImportError as err:
    try:
        from functorch import vmap
    except ImportError as err_ft:
        raise err_ft from err
from torchrl._utils import implement_for
from torchrl.envs.utils import step_mdp

try:
    from torch.compiler import is_dynamo_compiling
except ImportError:
    from torch._dynamo import is_compiling as is_dynamo_compiling

_GAMMA_LMBDA_DEPREC_ERROR = (
    "Passing gamma / lambda parameters through the loss constructor "
    "is a deprecated feature. To customize your value function, "
    "run `loss_module.make_value_estimator(ValueEstimators.<value_fun>, gamma=val)`."
)

RANDOM_MODULE_LIST = (dropout._DropoutNd,)


class ValueEstimators(Enum):
    """Value function enumerator for custom-built estimators.

    Allows for a flexible usage of various value functions when the loss module
    allows it.

    Examples:
        >>> dqn_loss = DQNLoss(actor)
        >>> dqn_loss.make_value_estimator(ValueEstimators.TD0, gamma=0.9)

    """

    TD0 = "Bootstrapped TD (1-step return)"
    TD1 = "TD(1) (infinity-step return)"
    TDLambda = "TD(lambda)"
    GAE = "Generalized advantage estimate"
    VTrace = "V-trace"


def default_value_kwargs(value_type: ValueEstimators):
    """Default value function keyword argument generator.

    Args:
        value_type (Enum.value): the value function type, from the
            :class:`~torchrl.objectives.utils.ValueEstimators` class.

    Examples:
        >>> default_value_kwargs(ValueEstimators.TDLambda)
        {"gamma": 0.99, "lmbda": 0.95, "differentiable": True}

    """
    if value_type == ValueEstimators.TD1:
        return {"gamma": 0.99, "differentiable": True}
    elif value_type == ValueEstimators.TD0:
        return {"gamma": 0.99, "differentiable": True}
    elif value_type == ValueEstimators.GAE:
        return {"gamma": 0.99, "lmbda": 0.95, "differentiable": True}
    elif value_type == ValueEstimators.TDLambda:
        return {"gamma": 0.99, "lmbda": 0.95, "differentiable": True}
    elif value_type == ValueEstimators.VTrace:
        return {"gamma": 0.99, "differentiable": True}
    else:
        raise NotImplementedError(f"Unknown value type {value_type}.")
```
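To make the two helpers above concrete, here is a minimal sketch of how they are typically combined with a loss module. The toy Q-network, the `"one-hot"` action space and the constructor details are illustrative assumptions, not part of this file:

```python
import torch
from torch import nn
from torchrl.modules import QValueActor
from torchrl.objectives import DQNLoss
from torchrl.objectives.utils import ValueEstimators, default_value_kwargs

# Toy Q-network: 4-dim observations, 2 discrete actions (illustrative only).
actor = QValueActor(nn.Linear(4, 2), in_keys=["observation"], action_space="one-hot")
loss = DQNLoss(actor, action_space="one-hot", delay_value=True)

# Fetch the library defaults for the chosen estimator and install it on the loss.
kwargs = default_value_kwargs(ValueEstimators.TD0)  # {"gamma": 0.99, "differentiable": True}
loss.make_value_estimator(ValueEstimators.TD0, **kwargs)
```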
```python
class _context_manager:
    def __init__(self, value=True):
        self.value = value
        self.prev = []

    def __call__(self, func):
        @functools.wraps(func)
        def decorate_context(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return decorate_context


TensorLike = TypeVar("TensorLike", Tensor, TensorDict)


def distance_loss(
    v1: TensorLike,
    v2: TensorLike,
    loss_function: str,
    strict_shape: bool = True,
) -> TensorLike:
    """Computes a distance loss between two tensors.

    Args:
        v1 (Tensor | TensorDict): a tensor or tensordict with a shape compatible with v2.
        v2 (Tensor | TensorDict): a tensor or tensordict with a shape compatible with v1.
        loss_function (str): One of "l2", "l1" or "smooth_l1" representing which loss function is to be used.
        strict_shape (bool): if False, v1 and v2 are allowed to have a different shape.
            Default is ``True``.

    Returns:
        A tensor or tensordict of the shape v1.view_as(v2) or v2.view_as(v1)
        with values equal to the distance loss between the two.

    """
    if v1.shape != v2.shape and strict_shape:
        raise RuntimeError(
            f"The input tensors or tensordicts have shapes {v1.shape} and {v2.shape} which are incompatible."
        )

    if loss_function == "l2":
        return F.mse_loss(v1, v2, reduction="none")

    if loss_function == "l1":
        return F.l1_loss(v1, v2, reduction="none")

    if loss_function == "smooth_l1":
        return F.smooth_l1_loss(v1, v2, reduction="none")

    raise NotImplementedError(f"Unknown loss {loss_function}.")
```
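Since `distance_loss` returns element-wise values (`reduction="none"`), callers reduce afterwards. A quick sketch of the three supported losses on plain tensors:

```python
import torch
from torchrl.objectives.utils import distance_loss

q_value = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.5, 2.0, 2.0])
distance_loss(q_value, target, loss_function="l2")         # tensor([0.2500, 0.0000, 1.0000])
distance_loss(q_value, target, loss_function="l1")         # tensor([0.5000, 0.0000, 1.0000])
distance_loss(q_value, target, loss_function="smooth_l1")  # tensor([0.1250, 0.0000, 0.5000])
```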
```python
class TargetNetUpdater:
    """An abstract class for target network update in Double DQN/DDPG.

    Args:
        loss_module (DQNLoss or DDPGLoss): loss module where the target network should be updated.

    """

    def __init__(
        self,
        loss_module: LossModule,  # noqa: F821
    ):
        from torchrl.objectives.common import LossModule

        if not isinstance(loss_module, LossModule):
            raise ValueError("The loss_module must be a LossModule instance.")
        _has_update_associated = getattr(loss_module, "_has_update_associated", None)
        for k in loss_module._has_update_associated.keys():
            loss_module._has_update_associated[k] = True
        try:
            _target_names = []
            for name, _ in loss_module.named_children():
                # the TensorDictParams is a nn.Module instance
                if name.startswith("target_") and name.endswith("_params"):
                    _target_names.append(name)

            if len(_target_names) == 0:
                raise RuntimeError(
                    "Did not find any target parameters or buffers in the loss module."
                )

            _source_names = ["".join(name.split("target_")) for name in _target_names]

            for _source in _source_names:
                try:
                    getattr(loss_module, _source)
                except AttributeError as err:
                    raise RuntimeError(
                        f"Incongruent target and source parameter lists: "
                        f"{_source} is not an attribute of the loss_module"
                    ) from err

            self._target_names = _target_names
            self._source_names = _source_names
            self.loss_module = loss_module
            self.initialized = False
            self.init_()
            _has_update_associated = True
        finally:
            for k in loss_module._has_update_associated.keys():
                loss_module._has_update_associated[k] = _has_update_associated

    @property
    def _targets(self):
        targets = self.__dict__.get("_targets_val", None)
        if targets is None:
            targets = self.__dict__["_targets_val"] = TensorDict(
                {name: getattr(self.loss_module, name) for name in self._target_names},
                [],
            )
        return targets

    @_targets.setter
    def _targets(self, targets):
        self.__dict__["_targets_val"] = targets

    @property
    def _sources(self):
        sources = self.__dict__.get("_sources_val", None)
        if sources is None:
            sources = self.__dict__["_sources_val"] = TensorDict(
                {name: getattr(self.loss_module, name) for name in self._source_names},
                [],
            )
        return sources

    @_sources.setter
    def _sources(self, sources):
        self.__dict__["_sources_val"] = sources

    def init_(self) -> None:
        if self.initialized:
            warnings.warn("Updater already initialized.")
        found_distinct = False
        self._distinct_and_params = {}
        for key, source in self._sources.items(True, True):
            if not isinstance(key, tuple):
                key = (key,)
            key = ("target_" + key[0], *key[1:])
            target = self._targets[key]
            # for p_source, p_target in zip(source, target):
            if target.requires_grad:
                raise RuntimeError("the target parameter is part of a graph.")
            self._distinct_and_params[key] = (
                target.is_leaf
                and source.requires_grad
                and target.data_ptr() != source.data.data_ptr()
            )
            found_distinct = found_distinct or self._distinct_and_params[key]
            target.data.copy_(source.data)
        if not found_distinct:
            raise RuntimeError(
                f"The target and source data are identical for all params. "
                "Have you created proper target parameters? "
                "If the loss has a ``delay_value`` kwarg, make sure to set it "
                "to True if it is not done by default. "
                f"If no target parameter is needed, do not use a target updater such as {type(self)}."
            )

        # filter the target_ out
        def filter_target(key):
            if isinstance(key, tuple):
                return (filter_target(key[0]), *key[1:])
            return key[7:]

        self._sources = self._sources.select(
            *[
                filter_target(key)
                for (key, val) in self._distinct_and_params.items()
                if val
            ]
        ).lock_()
        self._targets = self._targets.select(
            *(key for (key, val) in self._distinct_and_params.items() if val)
        ).lock_()

        self.initialized = True

    def step(self) -> None:
        if not self.initialized:
            raise Exception(
                f"{self.__class__.__name__} must be "
                f"initialized (`{self.__class__.__name__}.init_()`) before calling step()"
            )
        for key, param in self._sources.items():
            target = self._targets.get(f"target_{key}")
            if target.requires_grad:
                raise RuntimeError("the target parameter is part of a graph.")
            self._step(param, target)

    def _step(self, p_source: Tensor, p_target: Tensor) -> None:
        raise NotImplementedError

    def __repr__(self) -> str:
        string = (
            f"{self.__class__.__name__}(sources={self._sources}, targets="
            f"{self._targets})"
        )
        return string


class SoftUpdate(TargetNetUpdater):
    r"""A soft-update class for target network update in Double DQN/DDPG.

    This was proposed in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf

    One and only one decay factor (tau or eps) must be specified.

    Args:
        loss_module (DQNLoss or DDPGLoss): loss module where the target network should be updated.
        eps (scalar): epsilon in the update equation:
            .. math::

                \theta_t = \theta_{t-1} * \epsilon + \theta_t * (1-\epsilon)

            Exclusive with ``tau``.
        tau (scalar): Polyak tau. It is equal to ``1-eps``, and exclusive with it.
    """

    def __init__(
        self,
        loss_module: (
            DQNLoss  # noqa: F821
            | DDPGLoss  # noqa: F821
            | SACLoss  # noqa: F821
            | REDQLoss  # noqa: F821
            | TD3Loss  # noqa: F821
        ),
        *,
        eps: float | None = None,
        tau: float | None = None,
    ):
        if eps is None and tau is None:
            raise RuntimeError(
                "Neither eps nor tau was provided. This behavior is deprecated.",
            )
        if (eps is None) ^ (tau is None):
            if eps is None:
                eps = 1 - tau
        else:
            raise ValueError("One and only one argument (tau or eps) can be specified.")
        if eps < 0.5:
            warnings.warn(
                "Found an eps value < 0.5, which is unexpected. "
                "You may want to use the `tau` keyword argument instead."
            )
        if not (eps <= 1.0 and eps >= 0.0):
            raise ValueError(
                f"Got eps = {eps} when it was supposed to be between 0 and 1."
            )
        super().__init__(loss_module)
        self.eps = eps

    def _step(
        self, p_source: Tensor | TensorDictBase, p_target: Tensor | TensorDictBase
    ) -> None:
        p_target.data.lerp_(p_source.data, 1 - self.eps)
```
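The `_step` above is the Polyak average from the docstring: `lerp_(source, 1 - eps)` computes `target + tau * (source - target)` with `tau = 1 - eps`. A tiny numeric check of that identity:

```python
import torch

eps = 0.995  # i.e. tau = 0.005
target = torch.tensor([0.0])
source = torch.tensor([1.0])
target.lerp_(source, 1 - eps)
print(target)  # tensor([0.0050]): the target moved 0.5% of the way toward the source
```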
```python
class HardUpdate(TargetNetUpdater):
    """A hard-update class for target network update in Double DQN/DDPG (by contrast with soft updates).

    This was proposed in the original Double DQN paper: "Deep Reinforcement Learning with Double Q-learning",
    https://arxiv.org/abs/1509.06461.

    Args:
        loss_module (DQNLoss or DDPGLoss): loss module where the target network should be updated.

    Keyword Args:
        value_network_update_interval (scalar): how often the target network should be updated.
            default: 1000
    """

    def __init__(
        self,
        loss_module: DQNLoss | DDPGLoss | SACLoss | TD3Loss,  # noqa: F821
        *,
        value_network_update_interval: float = 1000,
    ):
        super().__init__(loss_module)
        self.value_network_update_interval = value_network_update_interval
        self.counter = 0

    def _step(self, p_source: Tensor, p_target: Tensor) -> None:
        if self.counter == self.value_network_update_interval:
            p_target.data.copy_(p_source.data)

    def step(self) -> None:
        super().step()
        if self.counter == self.value_network_update_interval:
            self.counter = 0
        else:
            self.counter += 1
```
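Both updaters are driven the same way from a training loop: call `step()` once per optimizer update. A minimal sketch, assuming `loss` is a TorchRL loss module built with target parameters (e.g. `delay_value=True`) and that `collector` and `optimizer` come from the surrounding script:

```python
from torchrl.objectives import SoftUpdate  # HardUpdate is used identically

updater = SoftUpdate(loss, eps=0.99)  # or HardUpdate(loss, value_network_update_interval=1000)
for batch in collector:
    loss_td = loss(batch)
    # TorchRL losses return a tensordict; sum the "loss_*" entries before backward.
    sum(v for k, v in loss_td.items() if k.startswith("loss_")).backward()
    optimizer.step()
    optimizer.zero_grad()
    updater.step()  # refresh the target parameters after every update
```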
```python
class hold_out_net(_context_manager):
    """Context manager to hold a network out of a computational graph."""

    def __init__(self, network: nn.Module) -> None:
        self.network = network
        for p in network.parameters():
            self.mode = p.requires_grad
            break
        else:
            self.mode = True

    def __enter__(self) -> None:
        if self.mode:
            if is_dynamo_compiling():
                self._params = TensorDict.from_module(self.network)
                self._params.data.to_module(self.network)
            else:
                self.network.requires_grad_(False)

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        if self.mode:
            if is_dynamo_compiling():
                self._params.to_module(self.network)
            else:
                self.network.requires_grad_()


class hold_out_params(_context_manager):
    """Context manager to hold a list of parameters out of a computational graph."""

    def __init__(self, params: Iterable[Tensor]) -> None:
        if isinstance(params, TensorDictBase):
            self.params = params.detach()
        else:
            self.params = tuple(p.detach() for p in params)

    def __enter__(self) -> None:
        return self.params

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        pass
```
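A short self-contained illustration of `hold_out_net`: inside the block the network's parameters stop requiring gradients, and the flag is restored on exit:

```python
import torch
from torch import nn
from torchrl.objectives.utils import hold_out_net

value_net = nn.Linear(3, 1)
with hold_out_net(value_net):
    out = value_net(torch.randn(2, 3))  # forward pass that cannot backprop into value_net
    assert not any(p.requires_grad for p in value_net.parameters())
assert all(p.requires_grad for p in value_net.parameters())  # restored on exit
```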
```python
@torch.no_grad()
def next_state_value(
    tensordict: TensorDictBase,
    operator: TensorDictModule | None = None,
    next_val_key: str = "state_action_value",
    gamma: float = 0.99,
    pred_next_val: Tensor | None = None,
    **kwargs,
) -> torch.Tensor:
    """Computes the next state value (without gradient) to compute a target value.

    The target value is usually used to compute a distance loss (e.g. MSE):
        L = Sum[ (q_value - target_value)^2 ]
    The target value is computed as
        r + gamma ** n_steps_to_next * value_next_state
    If the reward is the immediate reward, n_steps_to_next=1. If N-steps rewards are used, n_steps_to_next is gathered
    from the input tensordict.

    Args:
        tensordict (TensorDictBase): Tensordict containing a reward and done key (and a n_steps_to_next key for n-steps
            rewards).
        operator (ProbabilisticTDModule, optional): the value function operator. Should write a 'next_val_key'
            key-value in the input tensordict when called. It does not need to be provided if pred_next_val is given.
        next_val_key (str, optional): key where the next value will be written.
            Default: 'state_action_value'
        gamma (:obj:`float`, optional): return discount rate.
            default: 0.99
        pred_next_val (Tensor, optional): the next state value can be provided if it is not computed with the operator.

    Returns:
        a Tensor of the size of the input tensordict containing the predicted value state.

    """
    if "steps_to_next_obs" in tensordict.keys():
        steps_to_next_obs = tensordict.get("steps_to_next_obs").squeeze(-1)
    else:
        steps_to_next_obs = 1

    rewards = tensordict.get(("next", "reward")).squeeze(-1)
    done = tensordict.get(("next", "done")).squeeze(-1)
    if done.all() or gamma == 0:
        return rewards

    if pred_next_val is None:
        next_td = step_mdp(tensordict)  # next_observation -> observation
        next_td = next_td.select(*operator.in_keys)
        operator(next_td, **kwargs)
        pred_next_val_detach = next_td.get(next_val_key).squeeze(-1)
    else:
        pred_next_val_detach = pred_next_val.squeeze(-1)
    done = done.to(torch.float)
    target_value = (1 - done) * pred_next_val_detach
    rewards = rewards.to(torch.float)
    target_value = rewards + (gamma**steps_to_next_obs) * target_value
    return target_value
```
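Worked example of the target formula `r + gamma**n * (1 - done) * V(s')`, here with a precomputed next-state value so no operator is needed:

```python
import torch
from tensordict import TensorDict
from torchrl.objectives.utils import next_state_value

td = TensorDict(
    {
        "next": {
            "reward": torch.tensor([[1.0], [1.0]]),
            "done": torch.tensor([[False], [True]]),
        }
    },
    batch_size=[2],
)
pred = torch.tensor([[10.0], [10.0]])
next_state_value(td, pred_next_val=pred, gamma=0.99)
# tensor([10.9000, 1.0000]): 1 + 0.99 * 10 where not done, the bare reward where done
```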
```python
def _cache_values(func):
    """Caches the tensordict returned by a property."""
    name = func.__name__

    @functools.wraps(func)
    def new_func(self, netname=None):
        if is_dynamo_compiling():
            if netname is not None:
                return func(self, netname)
            else:
                return func(self)
        __dict__ = self.__dict__
        _cache = __dict__.setdefault("_cache", {})
        attr_name = name
        if netname is not None:
            attr_name += "_" + netname
        if attr_name in _cache:
            out = _cache[attr_name]
            return out
        if netname is not None:
            out = func(self, netname)
        else:
            out = func(self)
        # TODO: decide what to do with locked tds in functional calls
        # if is_tensor_collection(out):
        #     out.lock_()
        _cache[attr_name] = out
        return out

    return new_func


def _vmap_func(module, *args, func=None, pseudo_vmap: bool = False, **kwargs):
    try:

        def decorated_module(*module_args_params):
            params = module_args_params[-1]
            module_args = module_args_params[:-1]
            with params.to_module(module):
                if func is None:
                    r = module(*module_args)
                else:
                    r = getattr(module, func)(*module_args)
                return r

        if not pseudo_vmap:
            return vmap(decorated_module, *args, **kwargs)  # noqa: TOR101
        return _pseudo_vmap(decorated_module, *args, **kwargs)

    except RuntimeError as err:
        if re.match(
            r"vmap: called random operation while in randomness error mode", str(err)
        ):
            raise RuntimeError(
                "Please use <loss_module>.set_vmap_randomness('different') to handle random operations during vmap."
            ) from err


@implement_for("torch", "2.7")
def _pseudo_vmap(
    func: Callable,
    in_dims: Any = 0,
    out_dims: Any = 0,
    randomness: str | None = None,
    *,
    chunk_size=None,
):
    if randomness is not None and randomness not in ("different", "error"):
        raise ValueError(
            f"pseudo_vmap only supports 'different' or 'error' randomness modes, but got {randomness=}. If another mode is required, please "
            "submit an issue in TorchRL."
        )
    from tensordict.nn.functional_modules import _exclude_td_from_pytree

    def _unbind(d, x):
        if d is not None and hasattr(x, "unbind"):
            return x.unbind(d)
        # Generator to reproduce the value
        return (copy(x) for _ in range(1000))

    def _stack(d, x):
        if d is not None:
            x = list(x)
            return torch.stack(list(x), d)
        return x

    @functools.wraps(func)
    def new_func(*args, in_dims=in_dims, out_dims=out_dims, **kwargs):
        with _exclude_td_from_pytree():
            # Unbind inputs
            if isinstance(in_dims, int):
                in_dims = (in_dims,) * len(args)
            if isinstance(out_dims, int):
                out_dims = (out_dims,)

            vs = zip(*tuple(tree_map(_unbind, in_dims, args)))
            rs = []
            for v in vs:
                r = func(*v, **kwargs)
                if not isinstance(r, tuple):
                    r = (r,)
                rs.append(r)
            rs = tuple(zip(*rs))
            vs = tuple(tree_map(_stack, out_dims, rs))
            if len(vs) == 1:
                return vs[0]
            return vs

    return new_func


@implement_for("torch", None, "2.7")
def _pseudo_vmap(  # noqa: F811
    func: Callable,
    in_dims: Any = 0,
    out_dims: Any = 0,
    randomness: str | None = None,
    *,
    chunk_size=None,
):
    @functools.wraps(func)
    def new_func(*args, in_dims=in_dims, out_dims=out_dims, **kwargs):
        raise NotImplementedError("This implementation is not supported for torch<2.7")

    return new_func
```
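`_pseudo_vmap` emulates `torch.vmap` by unbinding the inputs along the mapped dimension, calling the function once per slice, and stacking the results; this sidesteps operations (e.g. randomness) that real `vmap` rejects. The underlying pattern, stripped of the pytree plumbing:

```python
import torch

def f(x):
    return x * 2

x = torch.arange(6.0).reshape(3, 2)
# Explicit unbind/apply/stack along dim 0, the loop-based equivalent of vmap(f)(x):
out = torch.stack([f(xi) for xi in x.unbind(0)], 0)
assert torch.equal(out, f(x))
```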
```python
def _reduce(
    tensor: torch.Tensor,
    reduction: str,
    mask: torch.Tensor | None = None,
    weights: torch.Tensor | None = None,
) -> float | torch.Tensor:
    """Reduces a tensor given the reduction method.

    Args:
        tensor (torch.Tensor): The tensor to reduce.
        reduction (str): The reduction method.
        mask (torch.Tensor, optional): A mask to apply to the tensor before reducing.
        weights (torch.Tensor, optional): Importance sampling weights for weighted reduction.
            When provided with reduction="mean", computes: (tensor * weights).sum() / weights.sum()
            When provided with reduction="sum", computes: (tensor * weights).sum()
            This is used for proper bias correction with prioritized replay buffers.

    Returns:
        float | torch.Tensor: The reduced tensor.
    """
    if reduction == "none":
        if weights is None:
            result = tensor
            if mask is not None:
                result = result[mask]
        elif mask is not None:
            masked_weight = weights[mask]
            masked_tensor = tensor[mask]
            result = masked_tensor * masked_weight
        else:
            result = tensor * weights
    elif reduction == "mean":
        if weights is not None:
            # Weighted average: (tensor * weights).sum() / weights.sum()
            if mask is not None:
                masked_weight = weights[mask]
                masked_tensor = tensor[mask]
                result = (masked_tensor * masked_weight).sum() / masked_weight.sum()
            else:
                if tensor.shape != weights.shape:
                    raise ValueError(
                        f"Tensor and weights shapes must match, but got {tensor.shape} and {weights.shape}"
                    )
                result = (tensor * weights).sum() / weights.sum()
        elif mask is not None:
            result = tensor[mask].mean()
        else:
            result = tensor.mean()
    elif reduction == "sum":
        if weights is not None:
            # Weighted sum: (tensor * weights).sum()
            if mask is not None:
                masked_weight = weights[mask]
                masked_tensor = tensor[mask]
                result = (masked_tensor * masked_weight).sum()
            else:
                if tensor.shape != weights.shape:
                    raise ValueError(
                        f"Tensor and weights shapes must match, but got {tensor.shape} and {weights.shape}"
                    )
                result = (tensor * weights).sum()
        elif mask is not None:
            result = tensor[mask].sum()
        else:
            result = tensor.sum()
    else:
        raise NotImplementedError(f"Unknown reduction method {reduction}")
    return result
```
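The weighted branch implements the importance-sampling correction used with prioritized replay: a weighted mean rather than a plain mean. The formula in isolation:

```python
import torch

loss = torch.tensor([1.0, 2.0, 3.0])
weights = torch.tensor([0.5, 1.0, 1.5])  # e.g. IS weights from a prioritized buffer
weighted_mean = (loss * weights).sum() / weights.sum()
print(weighted_mean)  # tensor(2.3333) vs. the unweighted mean tensor(2.)
```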
```python
def _clip_value_loss(
    old_state_value: torch.Tensor | TensorDict,
    state_value: torch.Tensor | TensorDict,
    clip_value: torch.Tensor | TensorDict,
    target_return: torch.Tensor | TensorDict,
    loss_value: torch.Tensor | TensorDict,
    loss_critic_type: str,
) -> tuple[torch.Tensor | TensorDict, torch.Tensor]:
    """Value clipping method for loss computation.

    This method computes a clipped state value from the old state value and the state value,
    and returns the most pessimistic value prediction between clipped and non-clipped options.
    It also computes the clip fraction.
    """
    pre_clipped = state_value - old_state_value
    clipped = pre_clipped.clamp(-clip_value, clip_value)
    with torch.no_grad():
        clip_fraction = (pre_clipped != clipped).to(state_value.dtype).mean()
    state_value_clipped = old_state_value + clipped
    loss_value_clipped = distance_loss(
        target_return,
        state_value_clipped,
        loss_function=loss_critic_type,
    )
    # Choose the most pessimistic value prediction between clipped and non-clipped
    loss_value = torch.maximum(loss_value, loss_value_clipped)
    return loss_value, clip_fraction
```
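This is the PPO-style pessimistic value clipping: the critic's prediction may only move within a band of width `clip_value` around the old prediction, and the larger of the clipped and unclipped losses is kept. The arithmetic on scalars:

```python
import torch

old_v, v, clip, ret = torch.tensor(1.0), torch.tensor(3.0), torch.tensor(0.5), torch.tensor(2.0)
v_clipped = old_v + (v - old_v).clamp(-clip, clip)  # 1.5: the move is capped at +/- 0.5
loss = torch.maximum((v - ret) ** 2, (v_clipped - ret) ** 2)
print(loss)  # tensor(1.): the unclipped loss dominates here, so it is the one kept
```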
```python
def _get_default_device(net):
    for p in net.parameters():
        return p.device
    else:
        return getattr(torch, "get_default_device", lambda: torch.device("cpu"))()


def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimizer:
    """Groups multiple optimizers into a single one.

    All optimizers are expected to have the same type.
    """
    cls = None
    params = []
    for optimizer in optimizers:
        if optimizer is None:
            continue
        if cls is None:
            cls = type(optimizer)
        if cls is not type(optimizer):
            raise ValueError("Cannot group optimizers of different type.")
        params.extend(optimizer.param_groups)
    return cls(params)
```
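`group_optimizers` is handy when an algorithm keeps separate actor and critic optimizers but the training loop wants a single `step()`/`zero_grad()` pair. Per-group settings such as learning rates are preserved:

```python
import torch
from torch import nn
from torchrl.objectives.utils import group_optimizers

actor, critic = nn.Linear(4, 2), nn.Linear(4, 1)
opt = group_optimizers(
    torch.optim.Adam(actor.parameters(), lr=3e-4),
    torch.optim.Adam(critic.parameters(), lr=1e-3),
)
opt.zero_grad()
opt.step()  # one call now updates both parameter sets
```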
```python
def _sum_td_features(data: TensorDictBase) -> torch.Tensor:
    # Sum all features and return a tensor
    return data.sum(dim="feature", reduce=True)


def _maybe_get_or_select(
    td,
    key_or_keys,
    target_shape=None,
    padding_side: str = "left",
    padding_value: int = 0,
):
    if isinstance(key_or_keys, (str, tuple)):
        return td.get(
            key_or_keys,
            as_padded_tensor=True,
            padding_side=padding_side,
            padding_value=padding_value,
        )
    result = td.select(*key_or_keys)
    if target_shape is not None and result.shape != target_shape:
        result.batch_size = target_shape
    return result


def _maybe_add_or_extend_key(
    tensor_keys: list[NestedKey],
    key_or_list_of_keys: NestedKey | list[NestedKey],
    prefix: NestedKey = None,
):
    if prefix is not None:
        if isinstance(key_or_list_of_keys, NestedKey):
            tensor_keys.append(unravel_key((prefix, key_or_list_of_keys)))
        else:
            tensor_keys.extend([unravel_key((prefix, k)) for k in key_or_list_of_keys])
        return
    if isinstance(key_or_list_of_keys, NestedKey):
        tensor_keys.append(key_or_list_of_keys)
    else:
        tensor_keys.extend(key_or_list_of_keys)
```