torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/trainers/helpers/models.py
@@ -0,0 +1,658 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import itertools
from dataclasses import dataclass

import torch
from tensordict import set_lazy_legacy
from tensordict.nn import InteractionType
from torch import nn
from torchrl.data.tensor_specs import Categorical, Composite, Unbounded
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.common import EnvBase
from torchrl.envs.model_based.dreamer import DreamerEnv
from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import (
    NoisyLinear,
    SafeModule,
    SafeProbabilisticModule,
    SafeProbabilisticTensorDictSequential,
    SafeSequential,
)
from torchrl.modules.distributions import (
    Delta,
    OneHotCategorical,
    TanhDelta,
    TanhNormal,
)
from torchrl.modules.models.model_based import (
    DreamerActor,
    ObsDecoder,
    ObsEncoder,
    RSSMPosterior,
    RSSMPrior,
    RSSMRollout,
)
from torchrl.modules.models.models import DuelingCnnDQNet, DuelingMlpDQNet, MLP
from torchrl.modules.tensordict_module import (
    Actor,
    DistributionalQValueActor,
    QValueActor,
)
from torchrl.modules.tensordict_module.world_models import WorldModelWrapper
from torchrl.trainers.helpers import transformed_env_constructor

DISTRIBUTIONS = {
    "delta": Delta,
    "tanh-normal": TanhNormal,
    "categorical": OneHotCategorical,
    "tanh-delta": TanhDelta,
}

ACTIVATIONS = {
    "elu": nn.ELU,
    "tanh": nn.Tanh,
    "relu": nn.ReLU,
}


def make_dqn_actor(
    proof_environment: EnvBase, cfg: DictConfig, device: torch.device  # noqa: F821
) -> Actor:
    """DQN constructor helper function.

    Args:
        proof_environment (EnvBase): a dummy environment to retrieve the observation and action spec.
        cfg (DictConfig): contains arguments of the DQN script
        device (torch.device): device on which the model must be cast

    Returns:
        A DQN policy operator.

    Examples:
        >>> from torchrl.trainers.helpers.models import make_dqn_actor, DiscreteModelConfig
        >>> from torchrl.trainers.helpers.envs import EnvConfig
        >>> from torchrl.envs.libs.gym import GymEnv
        >>> from torchrl.envs.transforms import ToTensorImage, TransformedEnv
        >>> import hydra
        >>> from hydra.core.config_store import ConfigStore
        >>> import dataclasses
        >>> proof_environment = TransformedEnv(GymEnv("ALE/Pong-v5",
        ...    pixels_only=True), ToTensorImage())
        >>> device = torch.device("cpu")
        >>> config_fields = [(config_field.name, config_field.type, config_field) for config_cls in
        ...    (DiscreteModelConfig, EnvConfig)
        ...   for config_field in dataclasses.fields(config_cls)]
        >>> Config = dataclasses.make_dataclass(cls_name="Config", fields=config_fields)
        >>> cs = ConfigStore.instance()
        >>> cs.store(name="config", node=Config)
        >>> with initialize(config_path=None):
        >>>     cfg = compose(config_name="config")
        >>> actor = make_dqn_actor(proof_environment, cfg, device)
        >>> td = proof_environment.reset()
        >>> print(actor(td))
        TensorDict(
            fields={
                done: Tensor(torch.Size([1]), dtype=torch.bool),
                pixels: Tensor(torch.Size([3, 210, 160]), dtype=torch.float32),
                action: Tensor(torch.Size([6]), dtype=torch.int64),
                action_value: Tensor(torch.Size([6]), dtype=torch.float32),
                chosen_action_value: Tensor(torch.Size([1]), dtype=torch.float32)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)


    """
    env_specs = proof_environment.specs

    atoms = cfg.atoms if cfg.distributional else None
    linear_layer_class = torch.nn.Linear if not cfg.noisy else NoisyLinear

    action_spec = env_specs["input_spec", "full_action_spec", "action"]
    if action_spec.domain != "discrete":
        raise ValueError(
            f"env {proof_environment} has an action domain "
            f"{action_spec.domain} which is incompatible with "
            f"DQN. Make sure your environment has a discrete "
            f"domain."
        )

    if cfg.from_pixels:
        net_class = DuelingCnnDQNet
        default_net_kwargs = {
            "cnn_kwargs": {
                "bias_last_layer": True,
                "depth": None,
                "num_cells": [32, 64, 64],
                "kernel_sizes": [8, 4, 3],
                "strides": [4, 2, 1],
            },
            "mlp_kwargs": {"num_cells": 512, "layer_class": linear_layer_class},
        }
        in_key = "pixels"

    else:
        net_class = DuelingMlpDQNet
        default_net_kwargs = {
            "mlp_kwargs_feature": {},  # see class for details
            "mlp_kwargs_output": {"num_cells": 512, "layer_class": linear_layer_class},
        }
        # automatically infer in key
        (in_key,) = itertools.islice(
            env_specs["output_spec", "full_observation_spec"], 1
        )

    actor_class = QValueActor
    actor_kwargs = {}

    if isinstance(action_spec, Categorical):
        # if action spec is modeled as categorical variable, we still need to have features equal
        # to the number of possible choices and also set categorical behavioral for actors.
        actor_kwargs.update({"action_space": "categorical"})
        out_features = env_specs["input_spec", "full_action_spec", "action"].space.n
    else:
        out_features = action_spec.shape[0]

    if cfg.distributional:
        if not atoms:
            raise RuntimeError(
                "Expected atoms to be a positive integer, " f"got {atoms}"
            )
        vmin = -3
        vmax = 3

        out_features = (atoms, out_features)
        support = torch.linspace(vmin, vmax, atoms)
        actor_class = DistributionalQValueActor
        actor_kwargs.update({"support": support})
        default_net_kwargs.update({"out_features_value": (atoms, 1)})

    net = net_class(
        out_features=out_features,
        **default_net_kwargs,
    )

    model = actor_class(
        module=net,
        spec=Composite(action=action_spec),
        in_keys=[in_key],
        safe=True,
        **actor_kwargs,
    ).to(device)

    # init
    with torch.no_grad():
        td = proof_environment.fake_tensordict()
        td = td.unsqueeze(-1)
        model(td.to(device))
    return model


@set_lazy_legacy(False)
def make_dreamer(
    cfg: DictConfig,  # noqa: F821
    proof_environment: EnvBase = None,
    device: DEVICE_TYPING = "cpu",
    action_key: str = "action",
    value_key: str = "state_value",
    use_decoder_in_env: bool = False,
    obs_norm_state_dict=None,
) -> nn.ModuleList:
    """Create Dreamer components.

    Args:
        cfg (DictConfig): Config object.
        proof_environment (EnvBase): Environment to initialize the model.
        device (DEVICE_TYPING, optional): Device to use.
            Defaults to "cpu".
        action_key (str, optional): Key to use for the action.
            Defaults to "action".
        value_key (str, optional): Key to use for the value.
            Defaults to "state_value".
        use_decoder_in_env (bool, optional): Whether to use the decoder in the model based dreamer env.
            Defaults to `False`.
        obs_norm_state_dict (dict, optional): the state_dict of the ObservationNorm transform used
            when proof_environment is missing. Defaults to None.

    Returns:
        nn.TensorDictModel: Dreamer World model.
        nn.TensorDictModel: Dreamer Model based environment.
        nn.TensorDictModel: Dreamer Actor the world model space.
        nn.TensorDictModel: Dreamer Value model.
        nn.TensorDictModel: Dreamer Actor for the real world space.

    """
    proof_env_is_none = proof_environment is None
    if proof_env_is_none:
        proof_environment = transformed_env_constructor(
            cfg=cfg, use_env_creator=False, obs_norm_state_dict=obs_norm_state_dict
        )()

    # Modules
    obs_encoder = ObsEncoder()
    obs_decoder = ObsDecoder()

    rssm_prior = RSSMPrior(
        hidden_dim=cfg.rssm_hidden_dim,
        rnn_hidden_dim=cfg.rssm_hidden_dim,
        state_dim=cfg.state_dim,
        action_spec=proof_environment.action_spec,
    )
    rssm_posterior = RSSMPosterior(
        hidden_dim=cfg.rssm_hidden_dim, state_dim=cfg.state_dim
    )
    reward_module = MLP(
        out_features=1, depth=2, num_cells=cfg.mlp_num_units, activation_class=nn.ELU
    )

    world_model = _dreamer_make_world_model(
        obs_encoder, obs_decoder, rssm_prior, rssm_posterior, reward_module
    ).to(device)
    with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
        tensordict = proof_environment.fake_tensordict().unsqueeze(-1)
        tensordict = tensordict.to(device)
        world_model(tensordict)

    model_based_env = _dreamer_make_mbenv(
        reward_module,
        rssm_prior,
        obs_decoder,
        proof_environment,
        use_decoder_in_env,
        cfg.state_dim,
        cfg.rssm_hidden_dim,
    )
    model_based_env = model_based_env.to(device)

    actor_simulator, actor_realworld = _dreamer_make_actors(
        obs_encoder,
        rssm_prior,
        rssm_posterior,
        cfg.mlp_num_units,
        action_key,
        proof_environment,
    )
    actor_simulator = actor_simulator.to(device)

    value_model = _dreamer_make_value_model(cfg.mlp_num_units, value_key)
    value_model = value_model.to(device)
    with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
        tensordict = model_based_env.fake_tensordict().unsqueeze(-1)
        tensordict = tensordict.to(device)
        tensordict = actor_simulator(tensordict)
        value_model(tensordict)

    actor_realworld = actor_realworld.to(device)
    if proof_env_is_none:
        proof_environment.close()
        torch.cuda.empty_cache()
        del proof_environment

    del tensordict
    return world_model, model_based_env, actor_simulator, value_model, actor_realworld


def _dreamer_make_world_model(
    obs_encoder, obs_decoder, rssm_prior, rssm_posterior, reward_module
):
    # World Model and reward model
    rssm_rollout = RSSMRollout(
        SafeModule(
            rssm_prior,
            in_keys=["state", "belief", "action"],
            out_keys=[
                ("next", "prior_mean"),
                ("next", "prior_std"),
                "_",
                ("next", "belief"),
            ],
        ),
        SafeModule(
            rssm_posterior,
            in_keys=[("next", "belief"), ("next", "encoded_latents")],
            out_keys=[
                ("next", "posterior_mean"),
                ("next", "posterior_std"),
                ("next", "state"),
            ],
        ),
    )

    transition_model = SafeSequential(
        SafeModule(
            obs_encoder,
            in_keys=[("next", "pixels")],
            out_keys=[("next", "encoded_latents")],
        ),
        rssm_rollout,
        SafeModule(
            obs_decoder,
            in_keys=[("next", "state"), ("next", "belief")],
            out_keys=[("next", "reco_pixels")],
        ),
    )
    reward_model = SafeModule(
        reward_module,
        in_keys=[("next", "state"), ("next", "belief")],
        out_keys=[("next", "reward")],
    )
    world_model = WorldModelWrapper(
        transition_model,
        reward_model,
    )
    return world_model


def _dreamer_make_actors(
    obs_encoder,
    rssm_prior,
    rssm_posterior,
    mlp_num_units,
    action_key,
    proof_environment,
):
    actor_module = DreamerActor(
        out_features=proof_environment.action_spec.shape[0],
        depth=3,
        num_cells=mlp_num_units,
        activation_class=nn.ELU,
    )
    actor_simulator = _dreamer_make_actor_sim(
        action_key, proof_environment, actor_module
    )
    actor_realworld = _dreamer_make_actor_real(
        obs_encoder,
        rssm_prior,
        rssm_posterior,
        actor_module,
        action_key,
        proof_environment,
    )
    return actor_simulator, actor_realworld


def _dreamer_make_actor_sim(action_key, proof_environment, actor_module):
    actor_simulator = SafeProbabilisticTensorDictSequential(
        SafeModule(
            actor_module,
            in_keys=["state", "belief"],
            out_keys=["loc", "scale"],
            spec=Composite(
                **{
                    "loc": Unbounded(
                        proof_environment.action_spec.shape,
                        device=proof_environment.action_spec.device,
                    ),
                    "scale": Unbounded(
                        proof_environment.action_spec.shape,
                        device=proof_environment.action_spec.device,
                    ),
                }
            ),
        ),
        SafeProbabilisticModule(
            in_keys=["loc", "scale"],
            out_keys=[action_key],
            default_interaction_type=InteractionType.RANDOM,
            distribution_class=TanhNormal,
            distribution_kwargs={"tanh_loc": True},
            spec=Composite(**{action_key: proof_environment.action_spec}),
        ),
    )
    return actor_simulator


def _dreamer_make_actor_real(
    obs_encoder, rssm_prior, rssm_posterior, actor_module, action_key, proof_environment
):
    # actor for real world: interacts with states ~ posterior
    # Out actor differs from the original paper where first they compute prior and posterior and then act on it
    # but we found that this approach worked better.
    actor_realworld = SafeSequential(
        SafeModule(
            obs_encoder,
            in_keys=["pixels"],
            out_keys=["encoded_latents"],
        ),
        SafeModule(
            rssm_posterior,
            in_keys=["belief", "encoded_latents"],
            out_keys=[
                "_",
                "_",
                "state",
            ],
        ),
        SafeProbabilisticTensorDictSequential(
            SafeModule(
                actor_module,
                in_keys=["state", "belief"],
                out_keys=["loc", "scale"],
                spec=Composite(
                    **{
                        "loc": Unbounded(
                            proof_environment.action_spec.shape,
                        ),
                        "scale": Unbounded(
                            proof_environment.action_spec.shape,
                        ),
                    }
                ),
            ),
            SafeProbabilisticModule(
                in_keys=["loc", "scale"],
                out_keys=[action_key],
                default_interaction_type=InteractionType.DETERMINISTIC,
                distribution_class=TanhNormal,
                distribution_kwargs={"tanh_loc": True},
                spec=Composite(**{action_key: proof_environment.action_spec.to("cpu")}),
            ),
        ),
        SafeModule(
            rssm_prior,
            in_keys=["state", "belief", action_key],
            out_keys=[
                "_",
                "_",
                "_",  # we don't need the prior state
                ("next", "belief"),
            ],
        ),
    )
    return actor_realworld


def _dreamer_make_value_model(mlp_num_units, value_key):
    # actor for simulator: interacts with states ~ prior
    value_model = SafeModule(
        MLP(
            out_features=1,
            depth=3,
            num_cells=mlp_num_units,
            activation_class=nn.ELU,
        ),
        in_keys=["state", "belief"],
        out_keys=[value_key],
    )
    return value_model


def _dreamer_make_mbenv(
    reward_module,
    rssm_prior,
    obs_decoder,
    proof_environment,
    use_decoder_in_env,
    state_dim,
    rssm_hidden_dim,
):
    # MB environment
    if use_decoder_in_env:
        mb_env_obs_decoder = SafeModule(
            obs_decoder,
            in_keys=[("next", "state"), ("next", "belief")],
            out_keys=[("next", "reco_pixels")],
        )
    else:
        mb_env_obs_decoder = None

    transition_model = SafeSequential(
        SafeModule(
            rssm_prior,
            in_keys=["state", "belief", "action"],
            out_keys=[
                "_",
                "_",
                "state",
                "belief",
            ],
        ),
    )
    reward_model = SafeModule(
        reward_module,
        in_keys=["state", "belief"],
        out_keys=["reward"],
    )
    model_based_env = DreamerEnv(
        world_model=WorldModelWrapper(
            transition_model,
            reward_model,
        ),
        prior_shape=torch.Size([state_dim]),
        belief_shape=torch.Size([rssm_hidden_dim]),
        obs_decoder=mb_env_obs_decoder,
    )

    model_based_env.set_specs_from_env(proof_environment)
    model_based_env = TransformedEnv(model_based_env)
    default_dict = {
        "state": Unbounded(state_dim),
        "belief": Unbounded(rssm_hidden_dim),
        # "action": proof_environment.action_spec,
    }
    model_based_env.append_transform(
        TensorDictPrimer(random=False, default_value=0, **default_dict)
    )
    return model_based_env


@dataclass
class DreamerConfig:
    """Dreamer model config struct."""

    batch_length: int = 50
    state_dim: int = 30
    rssm_hidden_dim: int = 200
    mlp_num_units: int = 400
    grad_clip: int = 100
    world_model_lr: float = 6e-4
    actor_value_lr: float = 8e-5
    imagination_horizon: int = 15
    model_device: str = ""
    # Decay of the reward moving averaging
    exploration: str = "additive_gaussian"
    # One of "additive_gaussian", "ou_exploration" or ""


@dataclass
class REDQModelConfig:
    """REDQ model config struct."""

    annealing_frames: int = 1000000
    # float of frames used for annealing of the OrnsteinUhlenbeckProcess. Default=1e6.
    noisy: bool = False
    # whether to use NoisyLinearLayers in the value network.
    ou_exploration: bool = False
    # wraps the policy in an OU exploration wrapper, similar to DDPG. SAC being designed for
    # efficient entropy-based exploration, this should be left for experimentation only.
    ou_sigma: float = 0.2
    # Ornstein-Uhlenbeck sigma
    ou_theta: float = 0.15
    # Aimed at superseding --ou_exploration.
    distributional: bool = False
    # whether a distributional loss should be used (TODO: not implemented yet).
    atoms: int = 51
    # number of atoms used for the distributional loss (TODO)
    gSDE: bool = False
    # if True, exploration is achieved using the gSDE technique.
    tanh_loc: bool = False
    # if True, uses a Tanh-Normal transform for the policy location of the form
    # upscale * tanh(loc/upscale) (only available with TanhTransform and TruncatedGaussian distributions)
    default_policy_scale: float = 1.0
    # Default policy scale parameter
    distribution: str = "tanh_normal"
    # if True, uses a Tanh-Normal-Tanh distribution for the policy
    actor_cells: int = 256
    # cells of the actor
    qvalue_cells: int = 256
    # cells of the qvalue net
    scale_lb: float = 0.1
    # min value of scale
    value_cells: int = 256
    # cells of the value net
    activation: str = "tanh"
    # activation function, either relu or elu or tanh, Default=tanh


@dataclass
class ContinuousModelConfig:
    """Continuous control model config struct."""

    annealing_frames: int = 1000000
    # float of frames used for annealing of the OrnsteinUhlenbeckProcess. Default=1e6.
    noisy: bool = False
    # whether to use NoisyLinearLayers in the value network.
    ou_exploration: bool = False
    # wraps the policy in an OU exploration wrapper, similar to DDPG. SAC being designed for
    # efficient entropy-based exploration, this should be left for experimentation only.
    ou_sigma: float = 0.2
    # Ornstein-Uhlenbeck sigma
    ou_theta: float = 0.15
    # Aimed at superseding --ou_exploration.
    distributional: bool = False
    # whether a distributional loss should be used (TODO: not implemented yet).
    atoms: int = 51
    # number of atoms used for the distributional loss (TODO)
    gSDE: bool = False
    # if True, exploration is achieved using the gSDE technique.
    tanh_loc: bool = False
    # if True, uses a Tanh-Normal transform for the policy location of the form
    # upscale * tanh(loc/upscale) (only available with TanhTransform and TruncatedGaussian distributions)
    default_policy_scale: float = 1.0
    # Default policy scale parameter
    distribution: str = "tanh_normal"
    # if True, uses a Tanh-Normal-Tanh distribution for the policy
    lstm: bool = False
    # if True, uses an LSTM for the policy.
    shared_mapping: bool = False
    # if True, the first layers of the actor-critic are shared.
    actor_cells: int = 256
    # cells of the actor
    qvalue_cells: int = 256
    # cells of the qvalue net
    scale_lb: float = 0.1
    # min value of scale
    value_cells: int = 256
    # cells of the value net
    activation: str = "tanh"
    # activation function, either relu or elu or tanh, Default=tanh


@dataclass
class DiscreteModelConfig:
    """Discrete model config struct."""

    annealing_frames: int = 1000000
    # Number of frames used for annealing of the EGreedy exploration. Default=1e6.
    noisy: bool = False
    # whether to use NoisyLinearLayers in the value network
    distributional: bool = False
    # whether a distributional loss should be used.
    atoms: int = 51
    # number of atoms used for the distributional loss