torchrl-0.11.0-cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,753 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import abc
|
|
9
|
+
import functools
|
|
10
|
+
import warnings
|
|
11
|
+
from collections.abc import Iterator
|
|
12
|
+
from copy import deepcopy
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
from tensordict import is_tensor_collection, TensorDict, TensorDictBase
|
|
17
|
+
from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams
|
|
18
|
+
from tensordict.utils import Buffer
|
|
19
|
+
from torch import nn
|
|
20
|
+
from torch.nn import Parameter
|
|
21
|
+
|
|
22
|
+
from torchrl._utils import rl_warnings
|
|
23
|
+
from torchrl.envs.utils import ExplorationType, set_exploration_type
|
|
24
|
+
from torchrl.modules.tensordict_module.rnn import set_recurrent_mode
|
|
25
|
+
from torchrl.objectives.utils import ValueEstimators
|
|
26
|
+
from torchrl.objectives.value import ValueEstimatorBase
|
|
27
|
+
|
|
28
|
+
try:
|
|
29
|
+
from torch.compiler import is_compiling
|
|
30
|
+
except ImportError:
|
|
31
|
+
from torch._dynamo import is_compiling
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _updater_check_forward_prehook(module, *args, **kwargs):
    """Forward pre-hook that warns when a target network has no updater attached.

    ``LossModule.__init__`` registers this hook. It emits ``module.TARGET_NET_WARNING``
    as a ``UserWarning`` only when all three of these hold:

    * at least one value in ``module._has_update_associated`` is falsy;
    * ``rl_warnings()`` is truthy;
    * the code is not being compiled (``is_compiling()`` is false).

    The hook never modifies the forward inputs (it always returns ``None``).
    """
    # Conditions are evaluated in the same order, with the same short-circuiting,
    # as the positive form "not all(...) and rl_warnings() and not is_compiling()".
    if (
        all(module._has_update_associated.values())
        or not rl_warnings()
        or is_compiling()
    ):
        return
    warnings.warn(module.TARGET_NET_WARNING, category=UserWarning)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _forward_wrapper(func):
    """Wrap a loss method so it runs under the loss's sampling and recurrent modes.

    The returned function calls ``func`` while
    ``set_exploration_type(self.deterministic_sampling_mode)`` and
    ``set_recurrent_mode(True)`` are both active. Both are restored afterwards, even
    if ``func`` (or entering either context) raises.

    Args:
        func: the method to wrap. It is called as ``func(self, *args, **kwargs)``.

    Returns:
        A function carrying ``func``'s metadata (via :func:`functools.wraps`) that
        returns whatever ``func`` returns.
    """

    @functools.wraps(func)
    def new_forward(self, *args, **kwargs):
        # The contexts are entered and exited explicitly rather than with ``with``.
        # NOTE(review): presumably kept explicit for torch.compile tracing; confirm.
        # Each ``__enter__`` gets its own ``try``/``finally`` so that:
        # * a failure while setting up the recurrent mode still restores the
        #   exploration type;
        # * the contexts are exited in reverse order of entry.
        em = set_exploration_type(self.deterministic_sampling_mode)
        em.__enter__()
        try:
            rm = set_recurrent_mode(True)
            rm.__enter__()
            try:
                return func(self, *args, **kwargs)
            finally:
                rm.__exit__(None, None, None)
        finally:
            em.__exit__(None, None, None)

    return new_forward
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class _LossMeta(abc.ABCMeta):
    """Metaclass that runs loss entry points under the loss's sampling contexts.

    When a class is created, its ``forward`` and every public callable defined in the
    class body whose name ends with ``"loss"`` are wrapped with ``_forward_wrapper``.
    """

    def __init__(cls, name, bases, attr_dict):
        super().__init__(name, bases, attr_dict)
        # ``cls.forward`` resolves through the MRO, so an inherited forward is wrapped
        # again at this level too.
        cls.forward = _forward_wrapper(cls.forward)
        # Snapshot the items because class attributes are rebound while scanning.
        for attr_name, value in list(cls.__dict__.items()):
            if attr_name.startswith("_") or not attr_name.endswith("loss"):
                continue
            # Only wrap plain callables. Wrapping a property, a data attribute or a
            # static/class method would replace it with an instance function that
            # breaks attribute access or passes ``self`` where it is not expected.
            if not callable(value) or isinstance(value, (staticmethod, classmethod)):
                continue
            setattr(cls, attr_name, _forward_wrapper(value))
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class LossModule(TensorDictModuleBase, metaclass=_LossMeta):
|
|
72
|
+
"""A parent class for RL losses.
|
|
73
|
+
|
|
74
|
+
LossModule inherits from nn.Module. It is designed to read an input
|
|
75
|
+
TensorDict and return another tensordict
|
|
76
|
+
with loss keys named ``"loss_*"``.
|
|
77
|
+
|
|
78
|
+
Splitting the loss into its components can then be used by the trainer to log
|
|
79
|
+
the various loss values throughout
|
|
80
|
+
training. Other scalars present in the output tensordict will be logged too.
|
|
81
|
+
|
|
82
|
+
:cvar default_value_estimator: The default value type of the class.
|
|
83
|
+
Losses that require a value estimation are equipped with a default value
|
|
84
|
+
pointer. This class attribute indicates which value estimator will be
|
|
85
|
+
used if none other is specified.
|
|
86
|
+
The value estimator can be changed using the :meth:`~.make_value_estimator` method.
|
|
87
|
+
|
|
88
|
+
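By default, :meth:`forward` (and every public method whose name ends with ``"loss"``)
is executed within ``set_exploration_type(self.deterministic_sampling_mode)`` and
``set_recurrent_mode(True)``; see the note below on ``deterministic_sampling_mode``.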
|
|
90
|
+
|
|
91
|
+
To configure the tensordict keys via :meth:`~.set_keys`, a subclass must define an
``_AcceptedKeys`` dataclass.
|
|
93
|
+
This dataclass should include all keys that are intended to be configurable.
|
|
94
|
+
In addition, the subclass must implement the
|
|
95
|
+
:meth:`~._forward_value_estimator_keys` method. This function is crucial for
|
|
96
|
+
forwarding any altered tensordict keys to the underlying value_estimator.
|
|
97
|
+
|
|
98
|
+
Examples:
|
|
99
|
+
>>> class MyLoss(LossModule):
|
|
100
|
+
>>> @dataclass
|
|
101
|
+
>>> class _AcceptedKeys:
|
|
102
|
+
>>> action = "action"
|
|
103
|
+
>>>
|
|
104
|
+
>>> def _forward_value_estimator_keys(self, **kwargs) -> None:
|
|
105
|
+
>>> pass
|
|
106
|
+
>>>
|
|
107
|
+
>>> loss = MyLoss()
|
|
108
|
+
>>> loss.set_keys(action="action2")
|
|
109
|
+
|
|
110
|
+
.. note:: When a policy that is wrapped or augmented with an exploration module is passed
|
|
111
|
+
to the loss, we want to deactivate the exploration through ``set_exploration_type(<exploration>)`` where
|
|
112
|
+
``<exploration>`` is either ``ExplorationType.MEAN``, ``ExplorationType.MODE`` or
|
|
113
|
+
``ExplorationType.DETERMINISTIC``. The default value is ``DETERMINISTIC`` and it is set
|
|
114
|
+
through the ``deterministic_sampling_mode`` loss attribute. If another
|
|
115
|
+
exploration mode is required (or if ``DETERMINISTIC`` is not available), one can
|
|
116
|
+
change the value of this attribute which will change the mode.
|
|
117
|
+
|
|
118
|
+
"""
|
|
119
|
+
|
|
120
|
+
@dataclass
class _AcceptedKeys:
    """Maintains default values for all configurable tensordict keys.

    Each field of this dataclass names a tensordict key that can be overridden with
    ``.set_keys(key_name=key_value)``. The field's default value is the key used
    when no override is given. The base class declares no fields; subclasses that
    want configurable keys define their own ``_AcceptedKeys`` dataclass.
    """
|
|
127
|
+
|
|
128
|
+
# Annotation only. The accessor is the ``tensor_keys`` property, which returns the
# ``_AcceptedKeys`` instance built in ``__init__``.
tensor_keys: _AcceptedKeys
# NOTE(review): presumably the ``randomness`` mode used when this loss calls
# ``vmap``; ``None`` at this level. Confirm against subclasses.
_vmap_randomness = None
# Default value estimator type; ``__init__`` copies it into ``self.value_type``.
# ``None`` for losses that do not rely on a value estimator.
default_value_estimator: ValueEstimators | None = None
# NOTE(review): presumably controls whether sample priority weights are applied
# ("auto" or an explicit bool). The semantics live outside this view; confirm.
use_prioritized_weights: str | bool = "auto"

# Exploration type that ``_forward_wrapper`` activates around ``forward`` and the
# public ``*loss`` methods.
deterministic_sampling_mode: ExplorationType = ExplorationType.DETERMINISTIC

# NOTE(review): presumably the separator used when joining nested names; confirm.
SEP = "."
# Message emitted by ``_updater_check_forward_prehook`` when not every entry of
# ``_has_update_associated`` is truthy.
TARGET_NET_WARNING = (
    "No target network updater has been associated "
    "with this loss module, but target parameters have been found. "
    "While this is supported, it is expected that the target network "
    "updates will be manually performed. You can deactivate this warning "
    "by turning the RL_WARNINGS env variable to False."
)
|
|
143
|
+
|
|
144
|
+
@property
def tensor_keys(self) -> _AcceptedKeys:
    """The tensordict key names currently used by this loss.

    This is the ``_AcceptedKeys`` instance created in ``__init__``. :meth:`set_keys`
    updates it in place.
    """
    current_keys = self._tensor_keys
    return current_keys
|
|
147
|
+
|
|
148
|
+
def __new__(cls, *args, **kwargs):
    """Allocate a new loss instance.

    Constructor arguments are accepted here but not forwarded to the parent
    ``__new__``; they are left for ``__init__``.
    """
    return super().__new__(cls)
|
|
151
|
+
|
|
152
|
+
def __init__(self):
    """Set up the bookkeeping shared by every loss module."""
    super().__init__()
    # Plain per-instance dictionaries.
    self._cache = {}
    self._param_maps = {}
    # Read by ``_updater_check_forward_prehook`` to decide whether to warn.
    self._has_update_associated = {}
    # The value estimator itself is created lazily; only its type is recorded now.
    self._value_estimator = None
    self.value_type = self.default_value_estimator
    # Fresh copy of the default key names, exposed through ``tensor_keys``.
    self._tensor_keys = self._AcceptedKeys()
    self.register_forward_pre_hook(_updater_check_forward_prehook)
|
|
161
|
+
|
|
162
|
+
@property
def functional(self):
    """Whether the module is functional.

    The base implementation always answers ``True``. Losses specifically designed
    not to be functional are expected to override this property.
    """
    is_functional = True
    return is_functional
|
|
169
|
+
|
|
170
|
+
def get_stateful_net(self, network_name: str, copy: bool | None = None):
    """Return a stateful version of the network registered as ``network_name``.

    This can be used to initialize parameters. Such networks will often not be
    callable out-of-the-box and will require a ``vmap`` call to be executable.

    Args:
        network_name (str): the network name to gather.
        copy (bool, optional): if ``True``, a deepcopy of the network is made before
            the parameters stored under ``<network_name>_params`` are written into it.
            For functional losses, ``None`` (the default) behaves like ``True``.
            For non-functional losses no copy is ever made.

    Returns:
        The network. For functional losses, ``<network_name>_params`` has been
        applied to it via ``to_module``.

    Raises:
        RuntimeError: if ``copy`` is truthy and the loss is not functional.

    .. note:: if the module is not functional, no copy is made.
    """
    stateful = getattr(self, network_name)
    if not self.functional:
        if copy:
            raise RuntimeError("Cannot copy module in non-functional mode.")
        return stateful
    # For functional losses an unspecified ``copy`` means "copy".
    if copy is None or copy:
        stateful = deepcopy(stateful)
    # With ``copy=False`` the parameters are written into the loss's own network
    # object rather than into a copy.
    getattr(self, network_name + "_params").to_module(stateful)
    return stateful
|
|
196
|
+
|
|
197
|
+
def from_stateful_net(self, network_name: str, stateful_net: nn.Module):
    """Populate the parameters of a network from a stateful version of it.

    See :meth:`~.get_stateful_net` for details on how to gather a stateful version
    of the network.

    For non-functional losses, the state dict of ``stateful_net`` is loaded into
    the network registered as ``network_name``. For functional losses, the
    parameters of ``stateful_net`` are copied in place into
    ``<network_name>_params``.

    Args:
        network_name (str): the network name to reset.
        stateful_net (nn.Module): the stateful network from which the params should
            be gathered.

    Raises:
        RuntimeError: (functional losses only) if the parameter keys of
            ``stateful_net`` do not match those of ``<network_name>_params``.
    """
    if not self.functional:
        getattr(self, network_name).load_state_dict(stateful_net.state_dict())
        return
    incoming = TensorDict.from_module(stateful_net, as_module=True)
    stored = getattr(self, network_name + "_params")
    incoming_keys = set(incoming.keys(True, True))
    stored_keys = set(stored.keys(True, True))
    if incoming_keys != stored_keys:
        raise RuntimeError(
            f"The keys of params and provided module differ: "
            f"{stored_keys - incoming_keys} are in self.params and not in the module, "
            f"{incoming_keys - stored_keys} are in the module but not in self.params."
        )
    stored.data.update_(incoming.data)
|
|
222
|
+
|
|
223
|
+
def _set_deprecated_ctor_keys(self, **kwargs) -> None:
    """Reject tensordict keys that were passed through a loss constructor.

    Keyword arguments whose value is ``None`` are ignored.

    Raises:
        RuntimeError: for the first keyword (in argument order) whose value is not
            ``None``, pointing the user to :meth:`set_keys`.
    """
    explicitly_set = [key for key, value in kwargs.items() if value is not None]
    if explicitly_set:
        key = explicitly_set[0]
        raise RuntimeError(
            f"Setting '{key}' via the constructor is deprecated, use .set_keys(<key>='some_key') instead.",
        )
|
|
229
|
+
|
|
230
|
+
def set_keys(self, **kwargs) -> None:
    """Set tensordict key names.

    Every keyword must name a field of the loss's ``_AcceptedKeys`` dataclass.

    * A non-``None`` value replaces the current key name.
    * ``None`` resets the key to the default declared in ``_AcceptedKeys``.

    The keyword arguments are then forwarded unchanged to
    ``_forward_value_estimator_keys``, so the value estimator can pick up the new keys.

    Args:
        **kwargs: mapping from accepted key names to new tensordict keys (or
            ``None`` to restore the default).

    Raises:
        ValueError: if a keyword is not an accepted key. Keys processed before the
            offending one have already been updated.
        AttributeError: if ``_forward_value_estimator_keys`` raises
            ``AttributeError`` (for instance when the subclass does not implement it).

    Examples:
        >>> from torchrl.objectives import DQNLoss
        >>> # initialize the DQN loss
        >>> actor = torch.nn.Linear(3, 4)
        >>> dqn_loss = DQNLoss(actor, action_space="one-hot")
        >>> dqn_loss.set_keys(priority_key="td_error", action_value_key="action_value")
    """
    accepted_fields = self._AcceptedKeys.__dataclass_fields__
    # Default key names, built lazily and only if some key must be reset.
    default_keys = None
    for key, value in kwargs.items():
        if key not in accepted_fields:
            raise ValueError(
                f"{key} is not an accepted tensordict key. "
                f"Accepted keys are: {list(accepted_fields)}."
            )
        if value is None:
            # Reset to the default of the dataclass field whose name is ``key``.
            if default_keys is None:
                default_keys = self._AcceptedKeys()
            value = getattr(default_keys, key)
        setattr(self.tensor_keys, key, value)

    try:
        self._forward_value_estimator_keys(**kwargs)
    except AttributeError as err:
        raise AttributeError(
            "To utilize `.set_keys(...)` for tensordict key configuration, the subclassed loss module "
            "must define an _AcceptedKeys dataclass containing all keys intended for configuration. "
            "Moreover, the subclass needs to implement `._forward_value_estimator_keys()` method to "
            "facilitate forwarding of any modified tensordict keys to the underlying value_estimator."
        ) from err
|
|
259
|
+
|
|
260
|
+
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
    """Compute the loss terms from an input tensordict.

    Subclasses must override this method. The returned loss terms should be named
    ``"loss*"``. Splitting the loss into components lets the trainer log each value
    separately; other scalars in the output tensordict are logged too.

    Args:
        tensordict: an input tensordict with the values required to compute the loss.

    Returns:
        A new tensordict with no batch dimension containing the loss scalars, named
        ``"loss*"``. The trainer reads the losses by this name before backpropagation,
        so the naming is essential.

    Raises:
        NotImplementedError: always, in this base class.
    """
    raise NotImplementedError
|
|
276
|
+
|
|
277
|
+
def convert_to_functional(
    self,
    module: TensorDictModule,
    module_name: str,
    expand_dim: int | None = None,
    create_target_params: bool = False,
    compare_against: list[Parameter] | None = None,
    **kwargs,
) -> None:
    """Converts a module to functional to be used in the loss.

    Args:
        module (TensorDictModule or compatible, or a list/tuple of them): a
            stateful tensordict module.
            Parameters from this module will be isolated in the `<module_name>_params`
            attribute and a stateless version of the module will be registered
            under the `module_name` attribute.
            If a list or tuple of modules is passed, its length must equal
            ``expand_dim`` and the parameters of all modules are gathered with
            ``TensorDict.from_modules(..., expand_identical=True)``.
        module_name (str): name where the module will be found.
            The parameters of the module will be found under ``loss_module.<module_name>_params``
            whereas the module will be found under ``loss_module.<module_name>``.
        expand_dim (int, optional): if provided, the parameters of the module
            will be expanded ``N`` times, where ``N = expand_dim`` along the
            first dimension. This option is to be used whenever a target
            network with more than one configuration is to be used.

            .. note::
              If a ``compare_against`` list of values is provided, the
              resulting parameters will simply be a detached expansion
              of the original parameters. If ``compare_against`` is not
              provided, the value of the parameters will be resampled uniformly
              between the minimum and maximum value of the parameter content.

        create_target_params (bool, optional): if ``True``, a detached
            copy of the parameter will be available to feed a target network
            under the name ``loss_module.<module_name>_target_params``.
            If ``False`` (default), this attribute will still be available
            but it will be a detached instance of the parameters, not a copy.
            In other words, any modification of the parameter value
            will directly be reflected in the target parameters.
        compare_against (iterable of parameters, optional): if provided,
            this list of parameters will be used as a comparison set for
            the parameters of the module. If the parameters are expanded
            (``expand_dim > 0``), the resulting parameters for the module
            will be a simple expansion of the original parameter. Otherwise,
            the resulting parameters will be a detached version of the
            original parameters. If ``None``, the resulting parameters
            will carry gradients as expected.

    Raises:
        TypeError: if unexpected keyword arguments are passed.
        RuntimeError: if a list/tuple of modules is passed whose length
            differs from ``expand_dim``.
        KeyError: if any component of a parameter/buffer key contains the
            ``self.SEP`` pattern.

    """
    # Warn (do not fail) when the attributes about to be set were not
    # declared as class annotations.
    for name in (
        module_name,
        module_name + "_params",
        "target_" + module_name + "_params",
    ):
        if name not in self.__class__.__annotations__.keys():
            warnings.warn(
                f"The name {name} wasn't part of the annotations ({self.__class__.__annotations__.keys()}). Make sure it is present in the definition class."
            )

    if kwargs:
        raise TypeError(f"Unrecognised keyword arguments {list(kwargs.keys())}")

    # Parameters are gathered with ``as_module=True`` and later set as an
    # attribute of the loss, so they are stored as a submodule container.
    # NOTE(review): presumably this is what lets device casting of the loss
    # reach them — confirm against tensordict's TensorDictParams semantics.
    sep = self.SEP
    if isinstance(module, (list, tuple)):
        if len(module) != expand_dim:
            raise RuntimeError(
                "The ``expand_dim`` value must match the length of the module list/tuple "
                "if a single module isn't provided."
            )
        params = TensorDict.from_modules(
            *module, as_module=True, expand_identical=True
        )
    else:
        params = TensorDict.from_module(module, as_module=True)

    for key in params.keys(True):
        # Nested keys are tuples of strings: test every component for the
        # separator substring (``sep in tuple`` would only test membership).
        key_parts = (key,) if isinstance(key, str) else key
        if any(sep in part for part in key_parts):
            raise KeyError(
                f"The key {key} contains the {sep!r} pattern which is prohibited. Consider renaming the parameter / buffer."
            )

    compare_against = set(compare_against) if compare_against is not None else set()

    if expand_dim:
        # Expands the dims of params and buffers.
        # If the param already exist in the module, we return a simple expansion of the
        # original one. Otherwise, we expand and resample it.
        # For buffers, a cloned expansion (or equivalently a repeat) is returned.

        def _compare_and_expand(param):
            if is_tensor_collection(param):
                return param._apply_nest(
                    _compare_and_expand,
                    batch_size=[expand_dim, *param.shape],
                    filter_empty=False,
                    call_on_nested=True,
                )
            if not isinstance(param, nn.Parameter):
                buffer = param.expand(expand_dim, *param.shape).clone()
                return buffer
            if param in compare_against:
                # Shared parameter: a detached expanded view, not a new leaf.
                expanded_param = param.data.expand(expand_dim, *param.shape)
                return expanded_param
            # Fresh parameter: expand, then resample uniformly within the
            # original value range so that copies are not identical.
            p_out = param.expand(expand_dim, *param.shape).clone()
            p_out = nn.Parameter(
                p_out.uniform_(
                    p_out.data.min().item(), p_out.data.max().item()
                ).requires_grad_()
            )
            return p_out

        params = TensorDictParams(
            params.apply(
                _compare_and_expand,
                batch_size=[expand_dim, *params.shape],
                filter_empty=False,
                call_on_nested=True,
            ),
            no_convert=True,
        )

    param_name = module_name + "_params"

    prev_set_params = set(self.parameters())

    # Parameters already owned by the loss AND listed in ``compare_against``
    # are stored as detached data so they are not registered a second time.
    for key, parameter in list(params.items(True, True)):
        if parameter in prev_set_params and parameter in compare_against:
            params.set(key, parameter.data)

    setattr(self, param_name, params)

    # Set the module in the __dict__ directly to avoid listing its params
    # A deepcopy with meta device could be used but that assumes that the model is copyable!
    self.__dict__[module_name] = module

    name_params_target = "target_" + module_name
    if create_target_params:
        # we create a TensorDictParams to keep the target params as Buffer instances
        target_params = TensorDictParams(
            params.apply(
                _make_target_param(clone=create_target_params), filter_empty=False
            ),
            no_convert=True,
        )
        setattr(self, name_params_target + "_params", target_params)
    self._has_update_associated[module_name] = not create_target_params
|
|
433
|
+
|
|
434
|
+
def _clear_weakrefs(self, *tds):
    """Clear tensordict weak references while running under compilation.

    Outside of ``is_compiling()`` this is a no-op.

    Args:
        *tds: tensor collections, or names of attributes of ``self`` holding
            them. Names that do not resolve, and entries that are not tensor
            collections, are ignored.
    """
    if not is_compiling():
        return
    # Waiting for weakrefs reconstruct to be supported by compile
    for entry in tds:
        candidate = getattr(self, entry, None) if isinstance(entry, str) else entry
        if is_tensor_collection(candidate):
            candidate.clear_refs_for_compile_()
|
|
443
|
+
|
|
444
|
+
def __getattr__(self, item):
    """Resolve ``target_<name>_params`` lookups; defer anything else to ``nn.Module``.

    When no dedicated target container is registered under ``item``, a
    detached view (``.data``) of the online ``<name>_params`` is returned.
    When one is registered but no updater is associated with ``<name>``, a
    ``TARGET_NET_WARNING`` is emitted (if RL warnings are enabled and not
    compiling).
    """
    prefix, suffix = "target_", "_params"
    if not (item.startswith(prefix) and item.endswith(suffix)):
        return super().__getattr__(item)

    target_params = self._modules.get(item, None)
    if target_params is None:
        # no target param, take detached data
        online_params = getattr(self, item[len(prefix):])
        return online_params.data

    owner_name = item[len(prefix):-len(suffix)]
    if (
        not self._has_update_associated[owner_name]
        and rl_warnings()
        and not is_compiling()
    ):
        # no updater associated
        warnings.warn(
            self.TARGET_NET_WARNING,
            category=UserWarning,
        )
    return target_params
|
|
463
|
+
|
|
464
|
+
def _apply(self, fn):
    """Erase cached attributes, then apply ``fn`` through ``nn.Module._apply``.

    Cached detached params would not be cast along with the module, so the
    cache is dropped before every ``_apply`` call and rebuilt afterwards.
    """
    self._erase_cache()
    applied = super()._apply(fn)
    return applied
|
|
469
|
+
|
|
470
|
+
def _erase_cache(self):
|
|
471
|
+
for key in list(self.__dict__):
|
|
472
|
+
if key.startswith("_cache"):
|
|
473
|
+
delattr(self, key)
|
|
474
|
+
|
|
475
|
+
def _networks(self) -> Iterator[nn.Module]:
|
|
476
|
+
for item in self.__dir__():
|
|
477
|
+
if isinstance(item, nn.Module):
|
|
478
|
+
yield item
|
|
479
|
+
|
|
480
|
+
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
    """Yield the tensors reported by ``self.named_parameters``, without names."""
    yield from (tensor for _name, tensor in self.named_parameters(recurse=recurse))
|
|
483
|
+
|
|
484
|
+
def named_parameters(
    self, prefix: str = "", recurse: bool = True
) -> Iterator[tuple[str, Parameter]]:
    """Like ``nn.Module.named_parameters``, minus entries whose name starts with ``"_target"``."""
    for entry in super().named_parameters(prefix=prefix, recurse=recurse):
        if entry[0].startswith("_target"):
            continue
        yield entry
|
|
490
|
+
|
|
491
|
+
def reset(self) -> None:
    """No-op reset hook; subclasses (mainly PPO with a KL target) may override it."""
    return None
|
|
494
|
+
|
|
495
|
+
def _maybe_get_priority_weight(
|
|
496
|
+
self, tensordict: TensorDictBase
|
|
497
|
+
) -> torch.Tensor | None:
|
|
498
|
+
"""Extract priority weights from tensordict if prioritized replay is enabled.
|
|
499
|
+
|
|
500
|
+
Args:
|
|
501
|
+
tensordict (TensorDictBase): The input tensordict that may contain priority weights.
|
|
502
|
+
|
|
503
|
+
Returns:
|
|
504
|
+
torch.Tensor | None: The priority weights if available and enabled, None otherwise.
|
|
505
|
+
"""
|
|
506
|
+
weights = None
|
|
507
|
+
if (
|
|
508
|
+
self.use_prioritized_weights in (True, "auto")
|
|
509
|
+
and self.tensor_keys.priority_weight in tensordict.keys()
|
|
510
|
+
):
|
|
511
|
+
weights = tensordict.get(self.tensor_keys.priority_weight)
|
|
512
|
+
return weights
|
|
513
|
+
|
|
514
|
+
def _reset_module_parameters(self, module_name, module):
|
|
515
|
+
params_name = f"{module_name}_params"
|
|
516
|
+
target_name = f"target_{module_name}_params"
|
|
517
|
+
params = self._modules.get(params_name, None)
|
|
518
|
+
target = self._modules.get(target_name, None)
|
|
519
|
+
|
|
520
|
+
if params is not None:
|
|
521
|
+
with params.to_module(module):
|
|
522
|
+
module.reset_parameters_recursive()
|
|
523
|
+
else:
|
|
524
|
+
module.reset_parameters_recursive()
|
|
525
|
+
|
|
526
|
+
if target is not None:
|
|
527
|
+
with target.to_module(module):
|
|
528
|
+
module.reset_parameters_recursive()
|
|
529
|
+
|
|
530
|
+
def reset_parameters_recursive(
    self,
):
    """Reset the parameters of every ``nn.Module`` stored in the instance ``__dict__``."""
    owned_modules = (
        (attr, obj) for attr, obj in vars(self).items() if isinstance(obj, nn.Module)
    )
    for attr, obj in owned_modules:
        self._reset_module_parameters(attr, obj)
|
|
537
|
+
|
|
538
|
+
@property
def value_estimator(self) -> ValueEstimatorBase:
    """The value function blends in the reward and value estimate(s) from upcoming state(s)/state-action pair(s) into a target value estimate for the value network."""
    if self._value_estimator is None:
        # Lazily build the default estimator on first access.
        self._default_value_estimator()
    return self._value_estimator

@value_estimator.setter
def value_estimator(self, value):
    """Replace the value estimator used by this loss."""
    self._value_estimator = value
|
|
550
|
+
|
|
551
|
+
def _default_value_estimator(self):
|
|
552
|
+
"""A value-function constructor when none is provided.
|
|
553
|
+
|
|
554
|
+
No kwarg should be present as default parameters should be retrieved
|
|
555
|
+
from :obj:`torchrl.objectives.utils.DEFAULT_VALUE_FUN_PARAMS`.
|
|
556
|
+
|
|
557
|
+
"""
|
|
558
|
+
self.make_value_estimator(
|
|
559
|
+
self.default_value_estimator, device=self._default_device
|
|
560
|
+
)
|
|
561
|
+
|
|
562
|
+
@property
|
|
563
|
+
def _default_device(self) -> torch.device | None:
|
|
564
|
+
"""A util to find the default device.
|
|
565
|
+
|
|
566
|
+
Returns ``None`` if parameters are spread across multiple devices.
|
|
567
|
+
"""
|
|
568
|
+
devices = set()
|
|
569
|
+
for p in self.parameters():
|
|
570
|
+
devices.add(p.device)
|
|
571
|
+
if len(devices) == 1:
|
|
572
|
+
return list(devices)[0]
|
|
573
|
+
return None
|
|
574
|
+
|
|
575
|
+
def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
    """Value-function constructor.

    If the non-default value function is wanted, it must be built using
    this method.

    Args:
        value_type (ValueEstimators, ValueEstimatorBase, or type): The value
            estimator to use. This can be one of the following:

            - A :class:`~torchrl.objectives.utils.ValueEstimators` enum type
              indicating which value function to use. If none is provided,
              the default stored in the ``default_value_estimator``
              attribute will be used. The base class implements none of
              these: loss subclasses override this method to support them.
            - A :class:`~torchrl.objectives.value.ValueEstimatorBase` instance,
              which will be used directly as the value estimator.
            - A :class:`~torchrl.objectives.value.ValueEstimatorBase` subclass,
              which will be instantiated with the provided ``hyperparams``.
              If ``device`` is not among ``hyperparams``, the device shared
              by all the loss parameters (if there is a single one) is used.

            The resulting value estimator class will be registered in
            ``self.value_type``, allowing future refinements.
        **hyperparams: hyperparameters to use for the value function.
            If not provided, the value indicated by
            :func:`~torchrl.objectives.utils.default_value_kwargs` will be
            used. When passing a ``ValueEstimatorBase`` subclass, these
            hyperparameters are passed directly to the class constructor.

    Returns:
        self: Returns the loss module for method chaining.

    Raises:
        NotImplementedError: if ``value_type`` is a ``ValueEstimators`` member
            (not implemented by the base class) or an unrecognised value.

    Examples:
        >>> import torch
        >>> from torchrl.objectives import DQNLoss
        >>> from torchrl.objectives.utils import ValueEstimators
        >>> # initialize the DQN loss
        >>> actor = torch.nn.Linear(3, 4)
        >>> dqn_loss = DQNLoss(actor, action_space="one-hot")
        >>> # updating the parameters of the default value estimator
        >>> dqn_loss.make_value_estimator(gamma=0.9)
        >>> dqn_loss.make_value_estimator(
        ...     ValueEstimators.TD1,
        ...     gamma=0.9)
        >>> # if we want to change the gamma value
        >>> dqn_loss.make_value_estimator(dqn_loss.value_type, gamma=0.9)

    In the two examples below, ``value_net`` stands for an existing value
    network and ``ppo_loss`` for an existing PPO loss module.

    Using a :class:`~torchrl.objectives.value.ValueEstimatorBase` subclass:

        >>> from torchrl.objectives.value import TD0Estimator
        >>> dqn_loss.make_value_estimator(TD0Estimator, gamma=0.99, value_network=value_net)

    Using a :class:`~torchrl.objectives.value.ValueEstimatorBase` instance:

        >>> from torchrl.objectives.value import GAE
        >>> gae = GAE(gamma=0.99, lmbda=0.95, value_network=value_net)
        >>> ppo_loss.make_value_estimator(gae)

    """
    if value_type is None:
        value_type = self.default_value_estimator

    if isinstance(value_type, ValueEstimatorBase):
        # A ready-made estimator is used as-is.
        self._value_estimator = value_type
        self.value_type = type(value_type)
        return self

    if isinstance(value_type, type) and issubclass(value_type, ValueEstimatorBase):
        # An estimator class: instantiate it, defaulting to the loss device.
        if "device" not in hyperparams:
            device = self._default_device
            if device is not None:
                hyperparams["device"] = device
        self._value_estimator = value_type(**hyperparams)
        self.value_type = value_type
        return self

    # Enum-based construction is left to subclasses. The base class records
    # the request, then reports that it cannot build it.
    self.value_type = value_type
    known_types = (
        ValueEstimators.TD1,
        ValueEstimators.TD0,
        ValueEstimators.GAE,
        ValueEstimators.VTrace,
        ValueEstimators.TDLambda,
    )
    if value_type in known_types:
        raise NotImplementedError(
            f"Value type {value_type} is not implemented for loss {type(self)}."
        )
    raise NotImplementedError(f"Unknown value type {value_type}")
|
|
672
|
+
|
|
673
|
+
@property
def vmap_randomness(self):
    """Vmap random mode.

    The vmap randomness mode controls what :func:`~torch.vmap` should do when dealing with
    functions with a random outcome such as :func:`~torch.randn` and :func:`~torch.rand`.
    If `"error"`, any random function will raise an exception indicating that `vmap` does not
    know how to handle the random call.

    If `"different"`, every element of the batch along which vmap is being called will
    behave differently. If `"same"`, vmaps will copy the same result across all elements.

    ``vmap_randomness`` defaults to `"error"` if no random module is detected, and to `"different"` in
    other cases. By default, only a limited number of modules are listed as random, but the list can be extended
    using the :func:`~torchrl.objectives.common.add_random_module` function.

    This property supports setting its value, either by assignment
    (``loss.vmap_randomness = "same"``) or through :meth:`set_vmap_randomness`.
    Both validate the value and rebuild the cached vmap callers.

    """
    if self._vmap_randomness is None:
        import torchrl.objectives.utils

        # Scan modules stored directly on the instance and registered children,
        # including all of their submodules.
        main_modules = list(self.__dict__.values()) + list(self.children())
        modules = (
            module
            for main_module in main_modules
            if isinstance(main_module, nn.Module)
            for module in main_module.modules()
        )
        random_types = torchrl.objectives.utils.RANDOM_MODULE_LIST
        if any(isinstance(module, random_types) for module in modules):
            self._vmap_randomness = "different"
        else:
            self._vmap_randomness = "error"

    return self._vmap_randomness

@vmap_randomness.setter
def vmap_randomness(self, value):
    # Without a setter, assigning to the property would raise AttributeError.
    # Delegate so that validation and vmap rebuilding stay in one place.
    self.set_vmap_randomness(value)
|
|
710
|
+
|
|
711
|
+
def set_vmap_randomness(self, value):
    """Set the vmap randomness mode and rebuild the cached vmap callers.

    Args:
        value (str): one of ``"error"``, ``"same"`` or ``"different"``.

    Raises:
        ValueError: for any other value.
    """
    allowed_modes = ("error", "same", "different")
    if value in allowed_modes:
        self._vmap_randomness = value
        self._make_vmap()
        return
    raise ValueError(
        "Wrong vmap randomness, should be one of 'error', 'same' or 'different'."
    )
|
|
718
|
+
|
|
719
|
+
@staticmethod
|
|
720
|
+
def _make_meta_params(param):
|
|
721
|
+
is_param = isinstance(param, nn.Parameter)
|
|
722
|
+
|
|
723
|
+
pd = param.detach().to("meta")
|
|
724
|
+
|
|
725
|
+
if is_param:
|
|
726
|
+
pd = nn.Parameter(pd, requires_grad=False)
|
|
727
|
+
return pd
|
|
728
|
+
|
|
729
|
+
def _make_vmap(self):
|
|
730
|
+
"""Caches thevmap callers to reduce the overhead at runtime."""
|
|
731
|
+
raise NotImplementedError(
|
|
732
|
+
f"_make_vmap has been called but is not implemented for loss of type {type(self).__name__}."
|
|
733
|
+
)
|
|
734
|
+
|
|
735
|
+
|
|
736
|
+
class _make_target_param:
|
|
737
|
+
def __init__(self, clone):
|
|
738
|
+
self.clone = clone
|
|
739
|
+
|
|
740
|
+
def __call__(self, x):
|
|
741
|
+
x = x.data.clone() if self.clone else x.data
|
|
742
|
+
if isinstance(x, nn.Parameter):
|
|
743
|
+
return Buffer(x)
|
|
744
|
+
return x
|
|
745
|
+
|
|
746
|
+
|
|
747
|
+
def add_random_module(module):
    """Register ``module`` as random for :meth:`~torchrl.objectives.LossModule.vmap_randomness`.

    The module type is appended to
    ``torchrl.objectives.utils.RANDOM_MODULE_LIST``. Losses that contain an
    instance of a listed type default to ``"different"`` vmap randomness.
    """
    import torchrl.objectives.utils as objective_utils

    registered = objective_utils.RANDOM_MODULE_LIST
    objective_utils.RANDOM_MODULE_LIST = registered + (module,)
|