torchrl 0.11.0__cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/modules/tensordict_module/sequence.py
@@ -0,0 +1,132 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+from tensordict.nn import TensorDictModule, TensorDictSequential
+from torch import nn
+
+from torchrl.data.tensor_specs import Composite
+from torchrl.modules.tensordict_module.common import SafeModule
+
+
+class SafeSequential(TensorDictSequential, SafeModule):
+    """A safe sequence of TensorDictModules.
+
+    Similarly to :obj:`nn.Sequential`, which passes a tensor through a chain of mappings that read and write a single tensor
+    each, this module will read and write over a tensordict by querying each of the input modules.
+    When calling a :obj:`TensorDictSequential` instance with a functional module, it is expected that the parameter lists (and
+    buffers) will be concatenated in a single list.
+
+    Args:
+        modules (iterable of TensorDictModules): ordered sequence of TensorDictModule instances to be run sequentially.
+        partial_tolerant (bool, optional): if ``True``, the input tensordict can miss some of the input keys.
+            If so, the only modules that will be executed are those which can be executed given the keys that
+            are present.
+            Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is ``True`` AND if the
+            stack does not have the required keys, then SafeSequential will scan through the sub-tensordicts
+            looking for those that have the required keys, if any.
+        inplace (bool or str, optional): if `True`, the input tensordict is modified in-place. If `False`, a new empty
+            :class:`~tensordict.TensorDict` instance is created. If `"empty"`, `input.empty()` is used instead (i.e., the
+            output preserves type, device and batch-size). Defaults to `None` (relies on sub-modules).
+
+    TensorDictSequential supports functional, modular and vmap coding:
+    Examples:
+        >>> import torch
+        >>> from tensordict import TensorDict
+        >>> from torchrl.data import Composite, Unbounded
+        >>> from torchrl.modules import TanhNormal, SafeSequential, TensorDictModule, NormalParamExtractor
+        >>> from torchrl.modules.tensordict_module import SafeProbabilisticModule
+        >>> td = TensorDict({"input": torch.randn(3, 4)}, [3,])
+        >>> spec1 = Composite(hidden=Unbounded(4), loc=None, scale=None)
+        >>> net1 = nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor())
+        >>> module1 = TensorDictModule(net1, in_keys=["input"], out_keys=["loc", "scale"])
+        >>> td_module1 = SafeProbabilisticModule(
+        ...     module=module1,
+        ...     spec=spec1,
+        ...     in_keys=["loc", "scale"],
+        ...     out_keys=["hidden"],
+        ...     distribution_class=TanhNormal,
+        ...     return_log_prob=True,
+        ...     )
+        >>> spec2 = Unbounded(8)
+        >>> module2 = torch.nn.Linear(4, 8)
+        >>> td_module2 = TensorDictModule(
+        ...     module=module2,
+        ...     spec=spec2,
+        ...     in_keys=["hidden"],
+        ...     out_keys=["output"],
+        ...     )
+        >>> td_module = SafeSequential(td_module1, td_module2)
+        >>> params = TensorDict.from_module(td_module)
+        >>> with params.to_module(td_module):
+        ...     td_module(td)
+        >>> print(td)
+        TensorDict(
+            fields={
+                hidden: Tensor(torch.Size([3, 4]), dtype=torch.float32),
+                input: Tensor(torch.Size([3, 4]), dtype=torch.float32),
+                loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),
+                output: Tensor(torch.Size([3, 8]), dtype=torch.float32),
+                sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32),
+                scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)},
+            batch_size=torch.Size([3]),
+            device=None,
+            is_shared=False)
+        >>> # The module spec aggregates all the input specs:
+        >>> print(td_module.spec)
+        Composite(
+            hidden: UnboundedContinuous(
+                shape=torch.Size([4]), space=None, device=cpu, dtype=torch.float32, domain=continuous),
+            loc: None,
+            scale: None,
+            output: UnboundedContinuous(
+                shape=torch.Size([8]), space=None, device=cpu, dtype=torch.float32, domain=continuous))
+
+    In the vmap case:
+        >>> from torch import vmap
+        >>> params = params.expand(4, *params.shape)
+        >>> td_vmap = vmap(td_module, (None, 0))(td, params)
+        >>> print(td_vmap)
+        TensorDict(
+            fields={
+                hidden: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32),
+                input: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32),
+                loc: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32),
+                output: Tensor(torch.Size([4, 3, 8]), dtype=torch.float32),
+                sample_log_prob: Tensor(torch.Size([4, 3, 1]), dtype=torch.float32),
+                scale: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32)},
+            batch_size=torch.Size([4, 3]),
+            device=None,
+            is_shared=False)
+
+    """
+
+    module: nn.ModuleList
+
+    def __init__(
+        self,
+        *modules: TensorDictModule,
+        partial_tolerant: bool = False,
+        inplace: bool | str | None = None,
+    ):
+        self.partial_tolerant = partial_tolerant
+
+        in_keys, out_keys = self._compute_in_and_out_keys(modules)
+
+        spec = Composite()
+        for module in modules:
+            try:
+                spec.update(module.spec)
+            except AttributeError:
+                spec.update(Composite({key: None for key in module.out_keys}))
+
+        super(TensorDictSequential, self).__init__(
+            spec=spec,
+            module=nn.ModuleList(list(modules)),
+            in_keys=in_keys,
+            out_keys=out_keys,
+            inplace=inplace,
+        )
torchrl/modules/tensordict_module/world_models.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from tensordict.nn import TensorDictModule, TensorDictSequential
+
+
+class WorldModelWrapper(TensorDictSequential):
+    """World model wrapper.
+
+    This module wraps together a transition model and a reward model.
+    The transition model is used to predict an imaginary world state.
+    The reward model is used to predict the reward of the imagined transition.
+
+    Args:
+        transition_model (TensorDictModule): a transition model that generates a new world state.
+        reward_model (TensorDictModule): a reward model that reads the world state and returns a reward.
+
+    """
+
+    def __init__(
+        self, transition_model: TensorDictModule, reward_model: TensorDictModule
+    ):
+        super().__init__(transition_model, reward_model)
+
+    def get_transition_model_operator(self) -> TensorDictModule:
+        """Returns a transition operator that maps either an observation to a world state or a world state to the next world state."""
+        return self.module[0]
+
+    def get_reward_operator(self) -> TensorDictModule:
+        """Returns a reward operator that maps a world state to a reward."""
+        return self.module[1]
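The wrapper above is simply a TensorDictSequential over the two sub-models, so calling it runs the transition model and then the reward model on the same tensordict. A minimal usage sketch follows (not part of the diff; the "state"/"reward" keys, the linear networks and the batch size are illustrative assumptions):

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.modules.tensordict_module.world_models import WorldModelWrapper

# Illustrative 4-dimensional latent state; both sub-models are plain linear layers.
transition_model = TensorDictModule(
    nn.Linear(4, 4), in_keys=["state"], out_keys=["next_state"]
)
reward_model = TensorDictModule(
    nn.Linear(4, 1), in_keys=["next_state"], out_keys=["reward"]
)
world_model = WorldModelWrapper(transition_model, reward_model)

td = TensorDict({"state": torch.randn(8, 4)}, batch_size=[8])
world_model(td)  # writes "next_state", then "reward", into the tensordict
assert td["reward"].shape == (8, 1)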
torchrl/modules/utils/__init__.py
@@ -0,0 +1,38 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import OrderedDict
+
+import torch
+from packaging import version
+
+
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
+    from torch.nn.parameter import _ParameterMeta
+else:
+    pass
+
+    # Metaclass to combine _TensorMeta and the instance check override for Parameter.
+    class _ParameterMeta(torch._C._TensorMeta):
+        # Make `isinstance(t, Parameter)` return True for custom tensor instances that have the _is_param flag.
+        def __instancecheck__(self, instance):
+            return super().__instancecheck__(instance) or (
+                isinstance(instance, torch.Tensor)
+                and getattr(instance, "_is_param", False)
+            )
+
+
+from .mappings import biased_softplus, inv_softplus, mappings
+from .utils import get_primers_from_module
+
+__all__ = [
+    "OrderedDict",
+    "torch",
+    "version",
+    "biased_softplus",
+    "inv_softplus",
+    "mappings",
+    "get_primers_from_module",
+]
torchrl/modules/utils/mappings.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from tensordict.nn.utils import biased_softplus, expln, inv_softplus, mappings
+
+__all__ = ["biased_softplus", "expln", "inv_softplus", "mappings"]
torchrl/modules/utils/utils.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import warnings
+
+import torch
+from tensordict.utils import expand_as_right
+
+
+def get_primers_from_module(module):
+    """Get all tensordict primers from all submodules of a module.
+
+    This method is useful for retrieving primers from modules that are contained within a
+    parent module.
+
+    Args:
+        module (torch.nn.Module): The parent module.
+
+    Returns:
+        TensorDictPrimer: A TensorDictPrimer Transform.
+
+    Example:
+        >>> from torchrl.modules.utils import get_primers_from_module
+        >>> from torchrl.modules import GRUModule, MLP
+        >>> from tensordict.nn import TensorDictModule, TensorDictSequential
+        >>> # Define a GRU module
+        >>> gru_module = GRUModule(
+        ...     input_size=10,
+        ...     hidden_size=10,
+        ...     num_layers=1,
+        ...     in_keys=["input", "recurrent_state", "is_init"],
+        ...     out_keys=["features", ("next", "recurrent_state")],
+        ... )
+        >>> # Define a head module
+        >>> head = TensorDictModule(
+        ...     MLP(
+        ...         in_features=10,
+        ...         out_features=10,
+        ...         num_cells=[],
+        ...     ),
+        ...     in_keys=["features"],
+        ...     out_keys=["output"],
+        ... )
+        >>> # Create a sequential model
+        >>> model = TensorDictSequential(gru_module, head)
+        >>> # Retrieve primers from the model
+        >>> primers = get_primers_from_module(model)
+        >>> print(primers)
+
+        TensorDictPrimer(primers=Composite(
+            recurrent_state: UnboundedContinuous(
+                shape=torch.Size([1, 10]),
+                space=None,
+                device=cpu,
+                dtype=torch.float32,
+                domain=continuous), device=None, shape=torch.Size([])), default_value={'recurrent_state': 0.0}, random=None)
+
+    """
+    primers = []
+
+    def make_primers(submodule):
+        if hasattr(submodule, "make_tensordict_primer"):
+            primers.append(submodule.make_tensordict_primer())
+
+    module.apply(make_primers)
+    if not primers:
+        warnings.warn("No primers found in the module.")
+        return
+    elif len(primers) == 1:
+        return primers[0]
+    else:
+        from torchrl.envs.transforms import Compose
+
+        return Compose(*primers)
+
+
+def _unpad_tensors(tensors, mask, as_nested: bool = True) -> torch.Tensor:
+    shape = tensors.shape[2:]
+    mask = expand_as_right(mask.bool(), tensors)
+    nelts = mask.sum(-1)
+    while nelts.dim() > 1:
+        nelts = nelts.sum(-1)
+    vals = [t.view(-1, *shape) for t in tensors[mask].split(nelts.tolist(), dim=0)]
+    if as_nested:
+        return torch.nested.as_nested_tensor(vals)
+    return vals
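The private _unpad_tensors helper above carries no docstring; as a hedged illustration (not from the diff), it keeps the mask-selected steps of a padded [batch, time, ...] tensor and regroups them per batch element, as a nested tensor or, with as_nested=False, as a plain list:

import torch
from torchrl.modules.utils.utils import _unpad_tensors  # private helper shown above

padded = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)  # [batch=2, time=3, feature=4]
mask = torch.tensor([[True, True, False],
                     [True, False, False]])  # valid steps per batch element: 2 and 1

unpadded = _unpad_tensors(padded, mask, as_nested=False)
print([t.shape for t in unpadded])  # [torch.Size([2, 4]), torch.Size([1, 4])]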
torchrl/objectives/__init__.py
@@ -0,0 +1,78 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchrl.objectives.a2c import A2CLoss
+from torchrl.objectives.common import add_random_module, LossModule
+from torchrl.objectives.cql import CQLLoss, DiscreteCQLLoss
+from torchrl.objectives.crossq import CrossQLoss
+from torchrl.objectives.ddpg import DDPGLoss
+from torchrl.objectives.decision_transformer import DTLoss, OnlineDTLoss
+from torchrl.objectives.dqn import DistributionalDQNLoss, DQNLoss
+from torchrl.objectives.dreamer import (
+    DreamerActorLoss,
+    DreamerModelLoss,
+    DreamerValueLoss,
+)
+from torchrl.objectives.gail import GAILLoss
+from torchrl.objectives.iql import DiscreteIQLLoss, IQLLoss
+from torchrl.objectives.multiagent import QMixerLoss
+from torchrl.objectives.ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss
+from torchrl.objectives.redq import REDQLoss
+from torchrl.objectives.reinforce import ReinforceLoss
+from torchrl.objectives.sac import DiscreteSACLoss, SACLoss
+from torchrl.objectives.td3 import TD3Loss
+from torchrl.objectives.td3_bc import TD3BCLoss
+from torchrl.objectives.utils import (
+    default_value_kwargs,
+    distance_loss,
+    group_optimizers,
+    HardUpdate,
+    hold_out_net,
+    hold_out_params,
+    next_state_value,
+    SoftUpdate,
+    TargetNetUpdater,
+    ValueEstimators,
+)
+
+__all__ = [
+    "A2CLoss",
+    "CQLLoss",
+    "ClipPPOLoss",
+    "CrossQLoss",
+    "DDPGLoss",
+    "DQNLoss",
+    "DTLoss",
+    "DiscreteCQLLoss",
+    "DiscreteIQLLoss",
+    "DiscreteSACLoss",
+    "DistributionalDQNLoss",
+    "DreamerActorLoss",
+    "DreamerModelLoss",
+    "DreamerValueLoss",
+    "GAILLoss",
+    "HardUpdate",
+    "IQLLoss",
+    "KLPENPPOLoss",
+    "LossModule",
+    "OnlineDTLoss",
+    "PPOLoss",
+    "QMixerLoss",
+    "REDQLoss",
+    "ReinforceLoss",
+    "SACLoss",
+    "SoftUpdate",
+    "TD3BCLoss",
+    "TD3Loss",
+    "TargetNetUpdater",
+    "ValueEstimators",
+    "add_random_module",
+    "default_value_kwargs",
+    "distance_loss",
+    "group_optimizers",
+    "hold_out_net",
+    "hold_out_params",
+    "next_state_value",
+]
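For orientation, the sketch below pairs one of the losses exported here with a target-network updater, which is the typical way these objects are combined. It is not part of the diff: the network sizes, tensordict keys and the fake transition batch are illustrative assumptions, and the constructor arguments should be double-checked against the torchrl 0.11 documentation.

import torch
from tensordict import TensorDict
from torchrl.data import OneHot
from torchrl.modules import MLP, QValueActor
from torchrl.objectives import DQNLoss, SoftUpdate

# A tiny Q-network over a 4-dimensional observation with 2 discrete actions (assumed shapes).
action_spec = OneHot(2)
qnet = MLP(in_features=4, out_features=2, num_cells=[32])
actor = QValueActor(qnet, in_keys=["observation"], spec=action_spec)

loss_module = DQNLoss(actor, action_space="one_hot")
target_updater = SoftUpdate(loss_module, eps=0.995)

# A fake transition batch with the keys the default TD(0) value estimator expects.
batch = TensorDict(
    {
        "observation": torch.randn(8, 4),
        "action": torch.nn.functional.one_hot(torch.randint(2, (8,)), 2),
        ("next", "observation"): torch.randn(8, 4),
        ("next", "reward"): torch.randn(8, 1),
        ("next", "done"): torch.zeros(8, 1, dtype=torch.bool),
        ("next", "terminated"): torch.zeros(8, 1, dtype=torch.bool),
    },
    batch_size=[8],
)
loss_vals = loss_module(batch)  # returns a tensordict with a "loss" entry
loss_vals["loss"].backward()
target_updater.step()  # Polyak update of the target parameters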