torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/modules/tensordict_module/common.py
@@ -0,0 +1,529 @@

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import importlib.util
import inspect
import re
import warnings
from collections.abc import Iterable

import torch
from tensordict import TensorDictBase, unravel_key_list
from tensordict.nn import dispatch, TensorDictModule, TensorDictModuleBase
from tensordict.utils import NestedKey
from torch import nn
from torch.nn import functional as F

from torchrl.data.tensor_specs import Composite, TensorSpec
from torchrl.data.utils import DEVICE_TYPING

_has_functorch = importlib.util.find_spec("functorch") is not None
if _has_functorch:
    from functorch import FunctionalModule, FunctionalModuleWithBuffers
else:
    warnings.warn(
        "failed to import functorch. TorchRL's features that do not require "
        "functional programming should work, but functionality and performance "
        "may be affected. Consider installing functorch and/or upgrading pytorch."
    )

    class FunctionalModule:  # noqa: D101
        pass

    class FunctionalModuleWithBuffers:  # noqa: D101
        pass


def _check_all_str(list_of_str, first_level=True):
    if isinstance(list_of_str, str) and first_level:
        raise RuntimeError(
            f"Expected a list of strings but got a string: {list_of_str}"
        )
    elif not isinstance(list_of_str, str):
        try:
            return [_check_all_str(item, False) for item in list_of_str]
        except Exception as err:
            raise TypeError(
                f"Expected a list of strings but got: {list_of_str}."
            ) from err


def _forward_hook_safe_action(module, tensordict_in, tensordict_out):
    try:
        spec = module.spec
        if len(module.out_keys) > 1 and not isinstance(spec, Composite):
            raise RuntimeError(
                "safe TensorDictModules with multiple out_keys require a Composite with matching keys. Got "
                f"keys {module.out_keys}."
            )
        elif not isinstance(spec, Composite):
            out_key = module.out_keys[0]
            keys = [out_key]
            values = [spec]
        else:
            # Make dynamo happy with the list creation
            keys = [key for key in spec.keys(True, True)]  # noqa: C416
            values = [spec[key] for key in keys]
        for _spec, _key in zip(values, keys):
            if _spec is None:
                continue
            item = tensordict_out.get(_key)
            if item is None:
                # this will happen when an exploration (e.g. OU) writes a key only
                # during exploration, but is missing otherwise.
                # it's fine since what we want here is to make sure that a key
                # is within bounds if it is present
                continue
            tensordict_out.set(
                _key,
                _spec.project(item),
            )
    except RuntimeError as err:
        if re.search(
            "attempting to use a Tensor in some data-dependent control flow", str(err)
        ):
            # "_is_stateless" in module.__dict__ and module._is_stateless:
            raise RuntimeError(
                "vmap cannot be used with safe=True, consider turning the safe mode off."
            ) from err
        else:
            raise err
```
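The safe-mode hook above delegates the actual bounds enforcement to `TensorSpec.project`. As a minimal sketch (not part of the diff), using a `Bounded` spec from `torchrl.data`:

```python
import torch
from torchrl.data import Bounded

# project() maps an out-of-bounds value back into the spec's domain,
# which is what the forward hook applies to every out_key with a spec.
spec = Bounded(low=-1.0, high=1.0, shape=(3,))
sample = torch.tensor([2.0, 0.5, -3.0])
projected = spec.project(sample)
print(projected)  # tensor([ 1.0000,  0.5000, -1.0000]) -- clamped into [-1, 1]
assert spec.is_in(projected)
```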
```python
class SafeModule(TensorDictModule):
    """:class:`tensordict.nn.TensorDictModule` subclass that accepts a :class:`~torchrl.data.TensorSpec` as argument to control the output domain.

    Args:
        module (nn.Module): a nn.Module used to map the input to the output
            parameter space. Can be a functional module (FunctionalModule or
            FunctionalModuleWithBuffers), in which case the :obj:`forward`
            method will expect the params (and possibly buffers) keyword
            arguments.
        in_keys (iterable of str): keys to be read from the input tensordict
            and passed to the module. If it contains more than one element,
            the values will be passed in the order given by the in_keys
            iterable.
        out_keys (iterable of str): keys to be written to the input tensordict.
            The length of out_keys must match the number of tensors returned
            by the embedded module. Using "_" as a key avoids writing the
            tensor to the output.
        spec (TensorSpec, optional): specs of the output tensor. If the module
            outputs multiple tensors, the spec characterizes the space of the
            first output tensor.
        safe (bool): if ``True``, the value of the output is checked against
            the input spec. Out-of-domain sampling can occur because of
            exploration policies or numerical under/overflow issues. If this
            value is out of bounds, it is projected back onto the desired
            space using the :obj:`TensorSpec.project` method. Default is
            ``False``.
        inplace (bool or str, optional): if ``True``, the input tensordict is
            modified in-place. If ``False``, a new empty
            :class:`~tensordict.TensorDict` instance is created. If
            ``"empty"``, ``input.empty()`` is used instead (i.e., the output
            preserves the type, device and batch-size of the input).
            Defaults to ``True``.

    Embedding a neural network in a TensorDictModule only requires specifying the
    input and output keys. The domain spec can be passed along if needed.
    TensorDictModule supports functional and regular :obj:`nn.Module` objects. In
    the functional case, the 'params' (and 'buffers') keyword argument must be
    specified:

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from torchrl.data import Unbounded
        >>> from torchrl.modules import TensorDictModule
        >>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,])
        >>> spec = Unbounded(8)
        >>> module = torch.nn.GRUCell(4, 8)
        >>> td_fmodule = TensorDictModule(
        ...    module=module,
        ...    spec=spec,
        ...    in_keys=["input", "hidden"],
        ...    out_keys=["output"],
        ...    )
        >>> params = TensorDict.from_module(td_fmodule)
        >>> with params.to_module(td_fmodule):
        ...     td_functional = td_fmodule(td.clone())
        >>> print(td_functional)
        TensorDict(
            fields={
                hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32),
                input: Tensor(torch.Size([3, 4]), dtype=torch.float32),
                output: Tensor(torch.Size([3, 8]), dtype=torch.float32)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False)

    In the stateful case:
        >>> td_module = TensorDictModule(
        ...    module=torch.nn.GRUCell(4, 8),
        ...    spec=spec,
        ...    in_keys=["input", "hidden"],
        ...    out_keys=["output"],
        ...    )
        >>> td_stateful = td_module(td.clone())
        >>> print(td_stateful)
        TensorDict(
            fields={
                hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32),
                input: Tensor(torch.Size([3, 4]), dtype=torch.float32),
                output: Tensor(torch.Size([3, 8]), dtype=torch.float32)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False)

    One can use a vmap operator to call the functional module. In this case the
    tensordict is expanded to match the batch size (i.e. the tensordict isn't
    modified in-place anymore):
        >>> # Model ensemble using vmap
        >>> from torch import vmap
        >>> params_repeat = params.expand(4, *params.shape)
        >>> td_vmap = vmap(td_fmodule, (None, 0))(td.clone(), params_repeat)
        >>> print(td_vmap)
        TensorDict(
            fields={
                hidden: Tensor(torch.Size([4, 3, 8]), dtype=torch.float32),
                input: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32),
                output: Tensor(torch.Size([4, 3, 8]), dtype=torch.float32)},
            batch_size=torch.Size([4, 3]),
            device=None,
            is_shared=False)

    """

    def __init__(
        self,
        module: (
            FunctionalModule
            | FunctionalModuleWithBuffers
            | TensorDictModule
            | nn.Module
        ),
        in_keys: Iterable[str],
        out_keys: Iterable[str],
        spec: TensorSpec | None = None,
        safe: bool = False,
        inplace: bool | str = True,
    ):
        super().__init__(module, in_keys, out_keys, inplace=inplace)
        self.register_spec(safe=safe, spec=spec)

    def register_spec(self, safe, spec):
        if spec is not None:
            spec = spec.clone()
        if spec is not None and not isinstance(spec, TensorSpec):
            raise TypeError("spec must be a TensorSpec subclass")
        elif spec is not None and not isinstance(spec, Composite):
            if len(self.out_keys) > 1:
                raise RuntimeError(
                    f"got more than one out_key for the TensorDictModule: {self.out_keys},\nbut only one spec. "
                    "Consider using a Composite object or no spec at all."
                )
            spec = Composite({self.out_keys[0]: spec})
        elif spec is not None and isinstance(spec, Composite):
            if "_" in spec.keys() and spec["_"] is not None:
                warnings.warn('got a spec with key "_": it will be ignored')
        elif spec is None:
            spec = Composite()

        # unravel_key_list(self.out_keys) can be removed once 473 is merged in tensordict
        spec_keys = set(unravel_key_list(list(spec.keys(True, True))))
        out_keys = set(unravel_key_list(self.out_keys))
        if spec_keys != out_keys:
            # then assume that all the non-indicated specs are None
            for key in out_keys:
                if key not in spec_keys:
                    spec[key] = None
            spec_keys = set(unravel_key_list(list(spec.keys(True, True))))
        if spec_keys != out_keys:
            raise RuntimeError(
                f"spec keys and out_keys do not match, got: {spec_keys} and {out_keys} respectively"
            )

        self._spec = spec
        self.safe = safe
        if safe:
            if spec is None or (
                isinstance(spec, Composite)
                and all(_spec is None for _spec in spec.values())
            ):
                raise RuntimeError(
                    "`TensorDictModule(spec=None, safe=True)` is not a valid configuration as the tensor "
                    "specs are not specified"
                )
            self.register_forward_hook(_forward_hook_safe_action)

    @property
    def spec(self) -> Composite:
        return self._spec

    @spec.setter
    def spec(self, spec: Composite) -> None:
        if not isinstance(spec, Composite):
            raise RuntimeError(
                f"Trying to set an object of type {type(spec)} as a tensorspec but expected a Composite instance."
            )
        self._spec = spec

    def random(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Samples a random element in the target space, irrespective of any input.

        If multiple output keys are present, only the first will be written in the input :obj:`tensordict`.

        Args:
            tensordict (TensorDictBase): tensordict where the output value should be written.

        Returns:
            the original tensordict with a new/updated value for the output key.

        """
        key0 = self.out_keys[0]
        tensordict.set(key0, self.spec.rand(tensordict.batch_size))
        return tensordict

    def random_sample(self, tensordict: TensorDictBase) -> TensorDictBase:
        """See :obj:`TensorDictModule.random(...)`."""
        return self.random(tensordict)

    def to(self, dest: torch.dtype | DEVICE_TYPING) -> TensorDictModule:
        if hasattr(self, "spec") and self.spec is not None:
            self.spec = self.spec.to(dest)
        out = super().to(dest)
        return out
```
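To make the safe-projection behavior concrete, here is a minimal usage sketch (not part of the diff); the key names `obs` and `action` are illustrative:

```python
import torch
from tensordict import TensorDict
from torchrl.data import Bounded
from torchrl.modules import SafeModule

# With safe=True, the forward hook projects any out-of-domain output
# back into the spec at the end of forward.
policy = SafeModule(
    torch.nn.Linear(4, 2),
    in_keys=["obs"],
    out_keys=["action"],
    spec=Bounded(low=-1.0, high=1.0, shape=(2,)),
    safe=True,
)
td = policy(TensorDict({"obs": 10.0 * torch.randn(8, 4)}, batch_size=[8]))
assert (td["action"].abs() <= 1.0).all()  # out-of-bounds outputs were projected
```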
```python
def is_tensordict_compatible(module: TensorDictModule | nn.Module):
    """Returns ``True`` if a module can be used as a TensorDictModule, and ``False`` if it can't.

    If the signature is misleading, an error is raised.

    Examples:
        >>> module = nn.Linear(3, 4)
        >>> is_tensordict_compatible(module)
        False
        >>> class CustomModule(nn.Module):
        ...     def __init__(self, module):
        ...         super().__init__()
        ...         self.linear = module
        ...         self.in_keys = ["x"]
        ...         self.out_keys = ["y"]
        ...     def forward(self, tensordict):
        ...         tensordict["y"] = self.linear(tensordict["x"])
        ...         return tensordict
        >>> tensordict_module = CustomModule(module)
        >>> is_tensordict_compatible(tensordict_module)
        True
        >>> class CustomModule(nn.Module):
        ...     def __init__(self, module):
        ...         super().__init__()
        ...         self.linear = module
        ...         self.in_keys = ["x"]
        ...         self.out_keys = ["y"]
        ...     def forward(self, tensordict, other_key):
        ...         tensordict["y"] = self.linear(tensordict["x"])
        ...         return tensordict
        >>> tensordict_module = CustomModule(module)
        >>> try:
        ...     is_tensordict_compatible(tensordict_module)
        ... except TypeError:
        ...     print("passing")
        passing
    """
    sig = inspect.signature(module.forward)

    if isinstance(module, TensorDictModule) or (
        len(sig.parameters) == 1
        and hasattr(module, "in_keys")
        and hasattr(module, "out_keys")
    ):
        # if the module is a TensorDictModule or takes a single argument and defines
        # in_keys and out_keys then we assume it can already deal with TensorDict input
        # to forward and we return True
        return True
    elif not hasattr(module, "in_keys") and not hasattr(module, "out_keys"):
        # if it's not a TensorDictModule, and in_keys and out_keys are not defined then
        # we assume no TensorDict compatibility and will try to wrap it.
        return False

    # if in_keys or out_keys were defined but module is not a TensorDictModule or
    # accepts multiple arguments then it's likely the user is trying to do something
    # that will have undetermined behavior, we raise an error
    raise TypeError(
        "Received a module that defines in_keys or out_keys and also expects multiple "
        "arguments to module.forward. If the module is compatible with TensorDict, it "
        "should take a single argument of type TensorDict to module.forward and define "
        "both in_keys and out_keys. Alternatively, module.forward can accept "
        "arbitrarily many tensor inputs and leave in_keys and out_keys undefined and "
        "TorchRL will attempt to automatically wrap the module with a TensorDictModule."
    )


def ensure_tensordict_compatible(
    module: (
        FunctionalModule | FunctionalModuleWithBuffers | TensorDictModule | nn.Module
    ),
    in_keys: list[NestedKey] | None = None,
    out_keys: list[NestedKey] | None = None,
    safe: bool = False,
    wrapper_type: type | None = TensorDictModule,
    **kwargs,
):
    """Checks that a module is compatible with TensorDictModule and, if not, wraps it."""
    in_keys = unravel_key_list(in_keys) if in_keys else in_keys
    out_keys = unravel_key_list(out_keys) if out_keys else out_keys

    if is_tensordict_compatible(module):
        if in_keys is not None and set(in_keys) != set(module.in_keys):
            raise TypeError(
                f"Arguments to module.forward ({set(module.in_keys)}) don't match "
                f"the expected TensorDict in_keys ({set(in_keys)})."
            )
        if out_keys is not None and set(module.out_keys) != set(out_keys):
            raise TypeError(
                f"Outputs of module.forward ({set(module.out_keys)}) don't match "
                f"the expected TensorDict out_keys ({set(out_keys)})."
            )
        # return module itself if it's already tensordict compatible
        return module

    if not isinstance(module, nn.Module):
        raise TypeError(
            "Argument to ensure_tensordict_compatible should be either "
            "a TensorDictModule or an nn.Module"
        )

    sig = inspect.signature(module.forward)
    if in_keys is not None and set(sig.parameters) != set(in_keys):
        raise TypeError(
            "Arguments to module.forward are incompatible with entries in "
            "env.observation_spec. If you want TorchRL to automatically "
            "wrap your module with a TensorDictModule then the arguments "
            "to module must correspond one-to-one with entries in "
            "in_keys. For more complex behavior and more control you can "
            "consider writing your own TensorDictModule."
        )

    # TODO: Check whether out_keys match (at least in number) if they are provided.
    if in_keys is not None:
        kwargs["in_keys"] = in_keys
    if out_keys is not None:
        kwargs["out_keys"] = out_keys
    return wrapper_type(module, **kwargs)
```
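A minimal sketch (not part of the diff) of the auto-wrapping path: a plain `nn.Module` whose forward arguments match `in_keys` gets wrapped in a `TensorDictModule`. The `Net` class and the key names are illustrative; the import path is the file shown in this diff:

```python
import torch
from tensordict import TensorDict
from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)

    def forward(self, x):  # the argument name must match the in_keys entry
        return self.linear(x)

# Net has neither in_keys nor out_keys, so it is not tensordict-compatible
# and gets wrapped with wrapper_type (TensorDictModule by default).
wrapped = ensure_tensordict_compatible(Net(), in_keys=["x"], out_keys=["y"])
td = wrapped(TensorDict({"x": torch.randn(5, 3)}, batch_size=[5]))
assert td["y"].shape == (5, 4)
```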
```python
class VmapModule(TensorDictModuleBase):
    """A TensorDictModule wrapper to vmap over the input.

    It is intended to be used with modules that accept data with one less batch
    dimension than the one provided. By using this wrapper, one can hide a
    batch dimension and satisfy the wrapped module.

    Args:
        module (TensorDictModuleBase): the module to vmap over.
        vmap_dim (int, optional): the vmap input and output dim.
            If none is provided, the last dimension of the tensordict is
            assumed.

    .. note::

        Since vmap requires control over the batch size of the input,
        this module does not support dispatched arguments.

    Example:
        >>> lam = TensorDictModule(lambda x: x[0], in_keys=["x"], out_keys=["y"])
        >>> sample_in = torch.ones((10, 3, 2))
        >>> sample_in_td = TensorDict({"x": sample_in}, batch_size=[10])
        >>> lam(sample_in)
        >>> vm = VmapModule(lam, 0)
        >>> vm(sample_in_td)
        >>> assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all()
    """

    def __init__(self, module: TensorDictModuleBase, vmap_dim=None, mock: bool = False):
        if not _has_functorch:
            raise ImportError("VmapModule requires torch>=2.0.")
        super().__init__()
        self.in_keys = module.in_keys
        self.out_keys = module.out_keys
        self.module = module
        self.vmap_dim = vmap_dim
        self.mock = mock
        if torch.__version__ >= "2.0":
            self._vmap = torch.vmap
        else:
            import functorch

            self._vmap = functorch.vmap

    def mock_(self, value: bool = True):
        self.mock = value

    def forward(self, tensordict):
        # TODO: there is a risk of segfault if the input is not a tensordict.
        # We should investigate (possibly prevent it c++ side?)
        vmap_dim = self.vmap_dim
        if vmap_dim is None:
            ndim = tensordict.ndim
            vmap_dim = ndim - 1
        if self.mock:
            td = torch.stack(
                [self.module(_td) for _td in tensordict.unbind(vmap_dim)], vmap_dim
            )
        else:
            td = self._vmap(self.module, (vmap_dim,), (vmap_dim,))(tensordict)
        return tensordict.update(td)
```
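A small usage sketch (not part of the diff) of the wrapper above: the inner module only ever sees tensordicts with the vmapped dimension hidden, and `mock=True` swaps the `vmap` call for the explicit unbind/stack loop, which should yield the same result. The inner lambda and key names are illustrative:

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.modules import VmapModule

inner = TensorDictModule(
    lambda x: x.sum(-1, keepdim=True), in_keys=["x"], out_keys=["y"]
)
vm = VmapModule(inner, 1)  # vmap over batch dim 1
td = TensorDict({"x": torch.randn(5, 4, 3)}, batch_size=[5, 4])
vm(td)  # inner only ever sees batch_size=[5] tensordicts
assert td["y"].shape == (5, 4, 1)

vm.mock_()  # fall back to the explicit unbind/stack loop
td_mock = TensorDict({"x": td["x"].clone()}, batch_size=[5, 4])
assert (vm(td_mock)["y"] == td["y"]).all()
```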
```python
class DistributionalDQNnet(TensorDictModuleBase):
    """Distributional Deep Q-Network softmax layer.

    This layer should be used in between a regular model that predicts the
    action values and a distribution that acts on logits.

    Args:
        in_keys (list of str or tuples of str): input keys to the log-softmax
            operation. Defaults to ``["action_value"]``.
        out_keys (list of str or tuples of str): output keys to the log-softmax
            operation. Defaults to ``["action_value"]``.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> net = DistributionalDQNnet()
        >>> td = TensorDict({"action_value": torch.randn(10, 5)}, batch_size=[10])
        >>> net(td)
        TensorDict(
            fields={
                action_value: Tensor(shape=torch.Size([10, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)

    """

    _wrong_out_feature_dims_error = (
        "DistributionalDQNnet requires dqn output to be at least "
        "2-dimensional, with dimensions *Batch x #Atoms x #Actions. Got {0} "
        "instead."
    )

    def __init__(self, *, in_keys=None, out_keys=None):
        super().__init__()
        if in_keys is None:
            in_keys = ["action_value"]
        if out_keys is None:
            out_keys = ["action_value"]
        self.in_keys = in_keys
        self.out_keys = out_keys

    @dispatch(auto_batch_size=False)
    def forward(self, tensordict):
        for in_key, out_key in zip(self.in_keys, self.out_keys):
            q_values = tensordict.get(in_key)
            if q_values.ndimension() < 2:
                raise RuntimeError(
                    self._wrong_out_feature_dims_error.format(q_values.shape)
                )
            tensordict.set(out_key, F.log_softmax(q_values, dim=-2))
        return tensordict
```