torchrl 0.11.0__cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,570 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from functools import partial
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
|
|
14
|
+
from omegaconf import MISSING
|
|
15
|
+
|
|
16
|
+
from torchrl.trainers.algorithms.configs.common import ConfigBase
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class ActivationConfig(ConfigBase):
    """Configuration entry describing an activation function.

    The ``_target_`` field stores the dotted import path of the activation
    class. When not overridden it points to :class:`torch.nn.Tanh`.

    .. seealso:: :class:`torch.nn.Tanh`
    """

    # Dotted import path of the activation class.
    _target_: str = "torch.nn.Tanh"
    # Off by default; callers that need a deferred constructor pass True.
    _partial_: bool = False

    def __post_init__(self) -> None:
        """Hook run after dataclass initialization; intentionally a no-op."""
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class LayerConfig(ConfigBase):
    """Configuration entry describing a layer class.

    The ``_target_`` field stores the dotted import path of the layer class.
    When not overridden it points to :class:`torch.nn.Linear`.

    .. seealso:: :class:`torch.nn.Linear`
    """

    # Dotted import path of the layer class.
    _target_: str = "torch.nn.Linear"
    # Off by default; callers that need a deferred constructor pass True.
    _partial_: bool = False

    def __post_init__(self) -> None:
        """Hook run after dataclass initialization; intentionally a no-op."""
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass
class NetworkConfig(ConfigBase):
    """Common base class for network configurations.

    Concrete network configs (e.g. :class:`MLPConfig`, :class:`ConvNetConfig`)
    derive from this class.
    """

    _partial_: bool = False

    def __post_init__(self) -> None:
        """Hook run after dataclass initialization; intentionally a no-op."""
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@dataclass
class MLPConfig(NetworkConfig):
    """Configuration for a multi-layer perceptron.

    ``activation_class`` and ``layer_class`` may be given either as config
    objects or as plain strings. Strings are wrapped after initialization
    into :class:`ActivationConfig` / :class:`LayerConfig` entries with
    ``_partial_=True``, using the string as ``_target_``.

    Example:
        >>> cfg = MLPConfig(in_features=10, out_features=5, depth=2, num_cells=32)
        >>> net = instantiate(cfg)
        >>> y = net(torch.randn(1, 10))
        >>> assert y.shape == (1, 5)

    .. seealso:: :class:`torchrl.modules.MLP`
    """

    in_features: int | None = None
    out_features: Any = None
    depth: int | None = None
    num_cells: Any = None
    # Default activation entry: Tanh, flagged as partial.
    activation_class: ActivationConfig = field(
        default_factory=partial(
            ActivationConfig,
            _target_="torch.nn.Tanh",
            _partial_=True,
        )
    )
    activation_kwargs: Any = None
    norm_class: Any = None
    norm_kwargs: Any = None
    dropout: float | None = None
    bias_last_layer: bool = True
    single_bias_last_layer: bool = False
    # Default layer entry: Linear, flagged as partial.
    layer_class: LayerConfig = field(
        default_factory=partial(
            LayerConfig,
            _target_="torch.nn.Linear",
            _partial_=True,
        )
    )
    layer_kwargs: dict | None = None
    activate_last_layer: bool = False
    device: Any = None
    _target_: str = "torchrl.modules.MLP"

    def __post_init__(self):
        """Wrap string-valued class fields into their config dataclasses."""
        wrappers = (
            ("activation_class", ActivationConfig),
            ("layer_class", LayerConfig),
        )
        for attr_name, config_cls in wrappers:
            current = getattr(self, attr_name)
            if isinstance(current, str):
                # The string is taken as the import path of the target class.
                setattr(
                    self, attr_name, config_cls(_target_=current, _partial_=True)
                )
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@dataclass
class NormConfig(ConfigBase):
    """Configuration entry describing a normalization layer.

    The ``_target_`` field stores the dotted import path of the normalization
    class. When not overridden it points to :class:`torch.nn.BatchNorm1d`.

    .. seealso:: :class:`torch.nn.BatchNorm1d`
    """

    # Dotted import path of the normalization class.
    _target_: str = "torch.nn.BatchNorm1d"
    # Off by default; callers that need a deferred constructor pass True.
    _partial_: bool = False

    def __post_init__(self) -> None:
        """Hook run after dataclass initialization; intentionally a no-op."""
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
@dataclass
class AggregatorConfig(ConfigBase):
    """Configuration entry describing an aggregator layer.

    The ``_target_`` field stores the dotted import path of the aggregator
    class. When not overridden it points to
    :class:`torchrl.modules.models.utils.SquashDims`.

    .. seealso:: :class:`torchrl.modules.models.utils.SquashDims`
    """

    # Dotted import path of the aggregator class.
    _target_: str = "torchrl.modules.models.utils.SquashDims"
    # Off by default; callers that need a deferred constructor pass True.
    _partial_: bool = False

    def __post_init__(self) -> None:
        """Hook run after dataclass initialization; intentionally a no-op."""
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
@dataclass
class ConvNetConfig(NetworkConfig):
    """A class to configure a convolutional network.

    Defaults to :class:`torchrl.modules.ConvNet`.

    ``activation_class``, ``norm_class`` and ``aggregator_class`` may be given
    either as config objects or as plain strings. Strings are wrapped after
    initialization into :class:`ActivationConfig`, :class:`NormConfig` and
    :class:`AggregatorConfig` entries respectively, with ``_partial_=True`` and
    the string used as ``_target_`` (same convention as :class:`MLPConfig`).

    Example:
        >>> cfg = ConvNetConfig(in_features=3, depth=2, num_cells=[32, 64], kernel_sizes=[3, 5], strides=[1, 2], paddings=[1, 2])
        >>> net = instantiate(cfg)
        >>> y = net(torch.randn(1, 3, 32, 32))
        >>> assert y.shape == (1, 64)

    .. seealso:: :class:`torchrl.modules.ConvNet`
    """

    in_features: int | None = None
    depth: int | None = None
    num_cells: Any = None
    kernel_sizes: Any = 3
    strides: Any = 1
    paddings: Any = 0
    # Default activation entry: ELU, flagged as partial.
    activation_class: ActivationConfig = field(
        default_factory=partial(
            ActivationConfig, _target_="torch.nn.ELU", _partial_=True
        )
    )
    activation_kwargs: Any = None
    norm_class: NormConfig | None = None
    norm_kwargs: Any = None
    bias_last_layer: bool = True
    # Default aggregator entry: SquashDims, flagged as partial.
    aggregator_class: AggregatorConfig = field(
        default_factory=partial(
            AggregatorConfig,
            _target_="torchrl.modules.models.utils.SquashDims",
            _partial_=True,
        )
    )
    aggregator_kwargs: dict | None = None
    squeeze_output: bool = False
    device: Any = None
    _target_: str = "torchrl.modules.ConvNet"

    def __post_init__(self):
        """Wrap string-valued class fields into their config dataclasses.

        The previous guard was ``value is None and isinstance(value, str)``,
        which can never hold, so string targets were silently left unwrapped.
        Only the ``isinstance`` check is needed: ``None`` is not a ``str`` and
        is therefore left untouched.
        """
        wrappers = (
            ("activation_class", ActivationConfig),
            ("norm_class", NormConfig),
            ("aggregator_class", AggregatorConfig),
        )
        for attr_name, config_cls in wrappers:
            current = getattr(self, attr_name)
            if isinstance(current, str):
                # The string is taken as the import path of the target class.
                setattr(
                    self, attr_name, config_cls(_target_=current, _partial_=True)
                )
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
@dataclass
class ModelConfig(ConfigBase):
    """Common base class for model configurations.

    A model may combine several networks and is always a
    :class:`~tensordict.nn.TensorDictModuleBase` instance.

    .. seealso:: :class:`TanhNormalModelConfig`, :class:`ValueModelConfig`
    """

    _partial_: bool = False
    # Keys read by / written to by the model; None means "not specified here".
    in_keys: Any = None
    out_keys: Any = None
    shared: bool = False

    def __post_init__(self) -> None:
        """Hook run after dataclass initialization; intentionally a no-op."""
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
@dataclass
class TensorDictModuleConfig(ModelConfig):
    """Configuration for a :class:`tensordict.nn.TensorDictModule`.

    Instantiation is delegated to ``_make_tensordict_module`` in this module,
    which builds the wrapped ``module`` when it is itself a config and forwards
    the remaining keyword arguments to ``TensorDictModule``.

    Example:
        >>> cfg = TensorDictModuleConfig(module=MLPConfig(in_features=10, out_features=10, depth=2, num_cells=32), in_keys=["observation"], out_keys=["action"])
        >>> module = instantiate(cfg)
        >>> assert isinstance(module, TensorDictModule)
        >>> assert module(observation=torch.randn(10, 10)).shape == (10, 10)

    .. seealso:: :class:`tensordict.nn.TensorDictModule`
    """

    # Required: config of the module to wrap.
    module: MLPConfig = MISSING
    _target_: str = (
        "torchrl.trainers.algorithms.configs.modules._make_tensordict_module"
    )
    _partial_: bool = False

    def __post_init__(self) -> None:
        """Run the parent post-initialization hook."""
        super().__post_init__()
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
@dataclass
class TensorDictSequentialConfig(ModelConfig):
    """Configuration for a :class:`tensordict.nn.TensorDictSequential`.

    Example:
        >>> cfg = TensorDictSequentialConfig(
        ...     modules=[
        ...         TensorDictModuleConfig(module=MLPConfig(in_features=10, out_features=10, depth=2, num_cells=32), in_keys=["observation"], out_keys=["hidden"]),
        ...         TensorDictModuleConfig(module=MLPConfig(in_features=10, out_features=5, depth=2, num_cells=32), in_keys=["hidden"], out_keys=["action"])
        ...     ]
        ... )
        >>> seq = instantiate(cfg)
        >>> assert isinstance(seq, TensorDictSequential)

    .. seealso:: :class:`tensordict.nn.TensorDictSequential`
    """

    # Sub-module configs making up the sequence (see the example above).
    modules: Any | None = None
    partial_tolerant: bool = False
    selected_out_keys: Any | None = None
    inplace: bool | str | None = None
    _target_: str = (
        "torchrl.trainers.algorithms.configs.modules._make_tensordict_sequential"
    )
    _partial_: bool = False

    def __post_init__(self) -> None:
        """Run the parent post-initialization hook."""
        super().__post_init__()
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
@dataclass
class TanhNormalModelConfig(ModelConfig):
    """Configuration for a TanhNormal model.

    After initialization, any of ``in_keys``, ``param_keys`` and ``out_keys``
    left as ``None`` is replaced by ``["observation"]``, ``["loc", "scale"]``
    and ``["action"]`` respectively.

    Example:
        >>> cfg = TanhNormalModelConfig(network=MLPConfig(in_features=10, out_features=5, depth=2, num_cells=32))
        >>> net = instantiate(cfg)
        >>> y = net(torch.randn(1, 10))
        >>> assert y.shape == (1, 5)

    .. seealso:: :class:`torchrl.modules.TanhNormal`
    """

    # Required: config of the underlying network.
    network: MLPConfig = MISSING
    eval_mode: bool = False

    extract_normal_params: bool = True
    scale_mapping: str = "biased_softplus_1.0"
    scale_lb: float = 1e-4

    param_keys: Any = None

    exploration_type: Any = "RANDOM"

    return_log_prob: bool = False

    _target_: str = (
        "torchrl.trainers.algorithms.configs.modules._make_tanh_normal_model"
    )

    def __post_init__(self):
        """Run the parent hook, then fill in unset key lists."""
        super().__post_init__()
        # Built on every call so each instance receives its own list objects.
        fallbacks = {
            "in_keys": ["observation"],
            "param_keys": ["loc", "scale"],
            "out_keys": ["action"],
        }
        for attr_name, fallback in fallbacks.items():
            if getattr(self, attr_name) is None:
                setattr(self, attr_name, fallback)
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
@dataclass
class ValueModelConfig(ModelConfig):
    """Configuration for a value model.

    Example:
        >>> cfg = ValueModelConfig(network=MLPConfig(in_features=10, out_features=5, depth=2, num_cells=32))
        >>> net = instantiate(cfg)
        >>> y = net(torch.randn(1, 10))
        >>> assert y.shape == (1, 5)

    .. seealso:: :class:`torchrl.modules.ValueOperator`
    """

    _target_: str = "torchrl.trainers.algorithms.configs.modules._make_value_model"
    # Required: config of the underlying network.
    network: NetworkConfig = MISSING

    def __post_init__(self) -> None:
        """Run the parent post-initialization hook."""
        super().__post_init__()
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
@dataclass
class TanhModuleConfig(ModelConfig):
    """Configuration for a :class:`torchrl.modules.TanhModule`.

    Example:
        >>> cfg = TanhModuleConfig(in_keys=["action"], out_keys=["action"], low=-1.0, high=1.0)
        >>> module = instantiate(cfg)
        >>> assert isinstance(module, TanhModule)

    .. seealso:: :class:`torchrl.modules.TanhModule`
    """

    spec: Any = None
    low: Any = None
    high: Any = None
    clamp: bool = False
    _target_: str = "torchrl.trainers.algorithms.configs.modules._make_tanh_module"

    def __post_init__(self) -> None:
        """Run the parent post-initialization hook."""
        super().__post_init__()
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
@dataclass
class AdditiveGaussianModuleConfig(ModelConfig):
    """Configuration for a :class:`torchrl.modules.AdditiveGaussianModule`.

    Example:
        >>> cfg = AdditiveGaussianModuleConfig(
        ...     spec=None,
        ...     sigma_init=1.0,
        ...     sigma_end=0.1,
        ...     mean=0.0,
        ...     std=1.0,
        ...     action_key="action",
        ... )
        >>> module = instantiate(cfg)
        >>> assert isinstance(module, AdditiveGaussianModule)

    .. seealso:: :class:`torchrl.modules.AdditiveGaussianModule`
    """

    # NOTE(review): the fields below are presumably forwarded to
    # AdditiveGaussianModule by the ``_target_`` factory -- confirm there.
    spec: Any = None
    sigma_init: float = 1.0
    sigma_end: float = 0.1
    annealing_num_steps: int = 1000
    mean: float = 0.0
    std: float = 1.0
    action_key: Any = "action"
    safe: bool = False
    device: Any = None
    _target_: str = (
        "torchrl.trainers.algorithms.configs.modules._make_additive_gaussian_module"
    )
    _partial_: bool = False

    def __post_init__(self) -> None:
        """Run the parent post-initialization hook."""
        super().__post_init__()
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
def _make_tensordict_module(*args, **kwargs):
    """Build a ``TensorDictModule`` from keyword arguments.

    Positional arguments are accepted but ignored.

    Keyword Args:
        module: Required. Either a config exposing ``_target_`` (built with
            ``hydra.utils.instantiate``), a callable exposing ``func`` such as
            a ``functools.partial`` (called with no arguments), or an object
            that is used as-is.
        shared (bool, optional): When truthy, ``share_memory()`` is called on
            the resulting module. Defaults to ``False``.
        **kwargs: Everything else is forwarded to ``TensorDictModule``.

    Returns:
        The constructed ``TensorDictModule``, or the result of its
        ``share_memory()`` call when ``shared`` is truthy.

    Raises:
        KeyError: If ``module`` is not supplied.
    """
    from hydra.utils import instantiate
    from tensordict.nn import TensorDictModule

    inner = kwargs.pop("module")
    use_shared_memory = kwargs.pop("shared", False)

    # Resolve the wrapped module: configs are instantiated, partials are called.
    if hasattr(inner, "_target_"):
        inner = instantiate(inner)
    elif callable(inner) and hasattr(inner, "func"):
        inner = inner()

    wrapped = TensorDictModule(inner, **kwargs)
    return wrapped.share_memory() if use_shared_memory else wrapped
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def _make_tensordict_sequential(*args, **kwargs):
    """Build a ``TensorDictSequential`` from keyword arguments.

    Positional arguments are accepted but ignored, as are keyword arguments
    other than the ones listed below.

    Keyword Args:
        modules: Required. A ``dict``/``DictConfig`` or ``list``/``ListConfig``
            of sub-modules. Each entry is instantiated if it exposes
            ``_target_``, called if it is a callable exposing ``func``, and
            kept unchanged otherwise.
        shared (bool, optional): When truthy, ``share_memory()`` is called on
            the resulting sequence. Defaults to ``False``.
        partial_tolerant (bool, optional): Forwarded to
            ``TensorDictSequential``. Defaults to ``False``.
        selected_out_keys (optional): Forwarded to ``TensorDictSequential``.
            Defaults to ``None``.
        inplace (optional): Forwarded to ``TensorDictSequential``.
            Defaults to ``None``.

    Returns:
        The constructed ``TensorDictSequential``, or the result of its
        ``share_memory()`` call when ``shared`` is truthy.

    Raises:
        KeyError: If ``modules`` is not supplied.
        ValueError: If ``modules`` is neither dict-like nor list-like.
    """
    from hydra.utils import instantiate
    from omegaconf import DictConfig, ListConfig
    from tensordict.nn import TensorDictSequential

    entries = kwargs.pop("modules")
    use_shared_memory = kwargs.pop("shared", False)
    sequential_options = {
        "partial_tolerant": kwargs.pop("partial_tolerant", False),
        "selected_out_keys": kwargs.pop("selected_out_keys", None),
        "inplace": kwargs.pop("inplace", None),
    }

    def _resolve(entry):
        # Configs are instantiated, partial-like callables are invoked,
        # anything else is assumed to be a ready-to-use module.
        if hasattr(entry, "_target_"):
            return instantiate(entry)
        if callable(entry) and hasattr(entry, "func"):
            return entry()
        return entry

    if isinstance(entries, (dict, DictConfig)):
        resolved = {name: _resolve(entry) for name, entry in entries.items()}
    elif isinstance(entries, (list, ListConfig)):
        resolved = [_resolve(entry) for entry in entries]
    else:
        raise ValueError(
            f"modules must be a dict or list, got {type(entries).__name__}"
        )

    sequence = TensorDictSequential(resolved, **sequential_options)
    return sequence.share_memory() if use_shared_memory else sequence
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
def _make_tanh_normal_model(*args, **kwargs):
    """Build a TanhNormal policy as a ``ProbabilisticTensorDictSequential``.

    The result chains a ``TensorDictModule`` wrapping the network (writing the
    distribution parameters under ``param_keys``) with a
    ``ProbabilisticTensorDictModule`` that uses ``TanhNormal`` as its
    distribution class. Positional arguments are accepted but ignored.

    Keyword Args:
        network: Required. A config exposing ``_target_`` (instantiated), a
            callable exposing ``func`` such as a ``functools.partial``
            (called with no arguments), or an object used as-is.
        in_keys (optional): Input keys of the network module. ``None`` or a
            missing value falls back to ``["observation"]``; a single string
            is treated as one key.
        param_keys (optional): Keys holding the distribution parameters.
            ``None`` or a missing value falls back to ``["loc", "scale"]``.
        out_keys (optional): Output keys of the probabilistic module.
            ``None`` or a missing value falls back to ``["action"]``.
        extract_normal_params (bool, optional): When truthy, a
            ``NormalParamExtractor`` is appended to the network.
            Defaults to ``True``.
        scale_mapping (optional): Passed to ``NormalParamExtractor``.
            Defaults to ``"biased_softplus_1.0"``.
        scale_lb (optional): Passed to ``NormalParamExtractor``.
            Defaults to ``1e-4``.
        return_log_prob (bool, optional): Passed to the probabilistic module.
            Defaults to ``False``.
        eval_mode (bool, optional): When truthy, ``eval()`` is called on the
            result. Defaults to ``False``.
        exploration_type (optional): Passed to the probabilistic module as
            ``default_interaction_type``. Defaults to ``"RANDOM"``.
        shared (bool, optional): When truthy, ``share_memory()`` is called on
            the network module. Defaults to ``False``.
        **kwargs: Everything else is forwarded to
            ``ProbabilisticTensorDictModule``.

    Returns:
        The assembled ``ProbabilisticTensorDictSequential``.

    Raises:
        KeyError: If ``network`` is not supplied.
    """
    from hydra.utils import instantiate
    from tensordict.nn import (
        ProbabilisticTensorDictModule,
        ProbabilisticTensorDictSequential,
        TensorDictModule,
    )
    from torchrl.modules import NormalParamExtractor, TanhNormal

    def _pop_keys(name, default):
        # An explicit ``None`` means "use the default", matching the config
        # class convention; calling ``list(None)`` would raise TypeError.
        value = kwargs.pop(name, None)
        if value is None:
            return list(default)
        # ``list("action")`` would explode the key into single characters.
        if isinstance(value, str):
            return [value]
        return list(value)

    # Extract parameters
    network = kwargs.pop("network")
    in_keys = _pop_keys("in_keys", ["observation"])
    param_keys = _pop_keys("param_keys", ["loc", "scale"])
    out_keys = _pop_keys("out_keys", ["action"])
    extract_normal_params = kwargs.pop("extract_normal_params", True)
    scale_mapping = kwargs.pop("scale_mapping", "biased_softplus_1.0")
    scale_lb = kwargs.pop("scale_lb", 1e-4)
    return_log_prob = kwargs.pop("return_log_prob", False)
    eval_mode = kwargs.pop("eval_mode", False)
    exploration_type = kwargs.pop("exploration_type", "RANDOM")
    shared = kwargs.pop("shared", False)

    # Now instantiate the network
    if hasattr(network, "_target_"):
        network = instantiate(network)
    elif callable(network) and hasattr(network, "func"):  # partial function
        network = network()

    if extract_normal_params:
        # Add NormalParamExtractor to split the output
        network = torch.nn.Sequential(
            network,
            NormalParamExtractor(scale_mapping=scale_mapping, scale_lb=scale_lb),
        )

    module = TensorDictModule(network, in_keys=in_keys, out_keys=param_keys)
    if shared:
        module = module.share_memory()

    # Any keyword argument not consumed above goes to the probabilistic module.
    prob_module = ProbabilisticTensorDictModule(
        in_keys=param_keys,
        out_keys=out_keys,
        distribution_class=TanhNormal,
        return_log_prob=return_log_prob,
        default_interaction_type=exploration_type,
        **kwargs,
    )

    result = ProbabilisticTensorDictSequential(module, prob_module)
    if eval_mode:
        result.eval()
    return result
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
def _make_value_model(*args, **kwargs):
    """Build a ``ValueOperator`` around the supplied network.

    Positional arguments are accepted but ignored.

    Keyword Args:
        network: Required. A config exposing ``_target_`` (built with
            ``hydra.utils.instantiate``), a callable exposing ``func`` such as
            a ``functools.partial`` (called with no arguments), or an object
            that is used as-is.
        shared (bool, optional): When truthy, ``share_memory()`` is called on
            the resulting operator. Defaults to ``False``.
        **kwargs: Everything else is forwarded to ``ValueOperator``.

    Returns:
        The constructed ``ValueOperator``, or the result of its
        ``share_memory()`` call when ``shared`` is truthy.

    Raises:
        KeyError: If ``network`` is not supplied.
    """
    from hydra.utils import instantiate

    from torchrl.modules import ValueOperator

    backbone = kwargs.pop("network")
    use_shared_memory = kwargs.pop("shared", False)

    # Resolve the network: configs are instantiated, partials are called.
    if hasattr(backbone, "_target_"):
        backbone = instantiate(backbone)
    elif callable(backbone) and hasattr(backbone, "func"):
        backbone = backbone()

    operator = ValueOperator(backbone, **kwargs)
    return operator.share_memory() if use_shared_memory else operator
|
|
535
|
+
|
|
536
|
+
|
|
537
|
+
def _make_tanh_module(*args, **kwargs):
    """Build a ``TanhModule`` from keyword arguments.

    Positional arguments are accepted but ignored. ``shared`` is removed from
    the keyword arguments and not used. ``in_keys`` and ``out_keys`` are
    converted to plain lists when they arrive as ``ListConfig`` objects.
    Everything that remains is forwarded to ``TanhModule``.

    Returns:
        The constructed ``TanhModule``.
    """
    from omegaconf import ListConfig

    from torchrl.modules import TanhModule

    # Accepted for interface uniformity with the other factories, then dropped.
    kwargs.pop("shared", False)

    for key_field in ("in_keys", "out_keys"):
        # A missing entry yields None, which is never a ListConfig.
        if isinstance(kwargs.get(key_field), ListConfig):
            kwargs[key_field] = list(kwargs[key_field])

    return TanhModule(**kwargs)
|
|
551
|
+
|
|
552
|
+
|
|
553
|
+
def _make_additive_gaussian_module(*args, **kwargs):
    """Build an ``AdditiveGaussianModule`` from keyword arguments.

    Positional arguments are accepted but ignored. ``shared``, ``in_keys`` and
    ``out_keys`` are removed from the keyword arguments and not used. When
    ``action_key`` arrives as a ``ListConfig`` it is replaced by its only
    element if it has exactly one, and by a tuple of its elements otherwise.
    Everything that remains is forwarded to ``AdditiveGaussianModule``.

    Returns:
        The constructed ``AdditiveGaussianModule``.
    """
    from omegaconf import ListConfig

    from torchrl.modules.tensordict_module.exploration import AdditiveGaussianModule

    # These entries are dropped rather than passed to the module constructor.
    for unused in ("shared", "in_keys", "out_keys"):
        kwargs.pop(unused, None)

    raw_action_key = kwargs.get("action_key")
    if isinstance(raw_action_key, ListConfig):
        parts = list(raw_action_key)
        kwargs["action_key"] = parts[0] if len(parts) == 1 else tuple(parts)

    return AdditiveGaussianModule(**kwargs)
|