torchrl 0.11.0__cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/modules/models/model_based.py

@@ -0,0 +1,596 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import warnings

import torch
from packaging import version
from tensordict.nn import (
    NormalParamExtractor,
    TensorDictModule,
    TensorDictModuleBase,
    TensorDictSequential,
)
from torch import nn

# from torchrl.modules.tensordict_module.rnn import GRUCell
from torch.nn import GRUCell

from torchrl.modules.models.models import MLP

UNSQUEEZE_RNN_INPUT = version.parse(torch.__version__) < version.parse("1.11")


class DreamerActor(nn.Module):
    """Dreamer actor network.

    This network is used to predict the action distribution given the
    stochastic state and the deterministic belief at the current
    time step.
    It outputs the mean and the scale of the action distribution.

    Reference: https://arxiv.org/abs/1912.01603

    Args:
        out_features (int): Number of output features.
        depth (int, optional): Number of hidden layers.
            Defaults to 4.
        num_cells (int, optional): Number of hidden units per layer.
            Defaults to 200.
        activation_class (nn.Module, optional): Activation class.
            Defaults to nn.ELU.
        std_bias (:obj:`float`, optional): Bias of the softplus transform.
            Defaults to 5.0.
        std_min_val (:obj:`float`, optional): Minimum value of the standard deviation.
            Defaults to 1e-4.
        device (torch.device, optional): Device to create the module on.
            Defaults to None (uses default device).
    """

    def __init__(
        self,
        out_features,
        depth=4,
        num_cells=200,
        activation_class=nn.ELU,
        std_bias=5.0,
        std_min_val=1e-4,
        device=None,
    ):
        super().__init__()
        self.backbone = MLP(
            out_features=2 * out_features,
            depth=depth,
            num_cells=num_cells,
            activation_class=activation_class,
            device=device,
        )
        self.backbone.append(
            NormalParamExtractor(
                scale_mapping=f"biased_softplus_{std_bias}_{std_min_val}",
                # scale_mapping="relu",
            ),
        )

    def forward(self, state, belief):
        loc, scale = self.backbone(state, belief)
        return loc, scale


class ObsEncoder(nn.Module):
    """Observation encoder network.

    Takes a pixel observation and encodes it into a latent space.

    Reference: https://arxiv.org/abs/1803.10122

    Args:
        channels (int, optional): Number of hidden units in the first layer.
            Defaults to 32.
        num_layers (int, optional): Depth of the network. Defaults to 4.
        in_channels (int, optional): Number of input channels. If None, uses LazyConv2d.
            Defaults to None for backward compatibility.
        device (torch.device, optional): Device to create the module on.
            Defaults to None (uses default device).
    """

    def __init__(
        self, channels=32, num_layers=4, in_channels=None, depth=None, device=None
    ):
        if depth is not None:
            warnings.warn(
                f"The depth argument in {type(self)} will soon be deprecated and "
                f"used for the depth of the network instead. Please use channels "
                f"for the layer size and num_layers for the depth until depth "
                f"replaces num_layers."
            )
            channels = depth
        if num_layers < 1:
            raise RuntimeError("num_layers cannot be smaller than 1.")
        super().__init__()
        # Use explicit Conv2d if in_channels provided, else LazyConv2d for backward compat
        if in_channels is not None:
            first_conv = nn.Conv2d(in_channels, channels, 4, stride=2, device=device)
        else:
            first_conv = nn.LazyConv2d(channels, 4, stride=2, device=device)
        layers = [
            first_conv,
            nn.ReLU(),
        ]
        k = 1
        for _ in range(1, num_layers):
            layers += [
                nn.Conv2d(channels * k, channels * (k * 2), 4, stride=2, device=device),
                nn.ReLU(),
            ]
            k = k * 2
        self.encoder = nn.Sequential(*layers)

    def forward(self, observation):
        *batch_sizes, C, H, W = observation.shape
        if len(batch_sizes) == 0:
            end_dim = 0
        else:
            end_dim = len(batch_sizes) - 1
        observation = torch.flatten(observation, start_dim=0, end_dim=end_dim)
        obs_encoded = self.encoder(observation)
        latent = obs_encoded.reshape(*batch_sizes, -1)
        return latent


class ObsDecoder(nn.Module):
    """Observation decoder network.

    Takes the stochastic state and the deterministic belief and decodes them into a pixel observation.

    Reference: https://arxiv.org/abs/1803.10122

    Args:
        channels (int, optional): Number of hidden units in the last layer.
            Defaults to 32.
        num_layers (int, optional): Depth of the network. Defaults to 4.
        kernel_sizes (int or list of int, optional): the kernel_size of each layer.
            Defaults to ``[5, 5, 6, 6]`` if num_layers is 4, else ``[5] * num_layers``.
        latent_dim (int, optional): Input dimension (state_dim + rnn_hidden_dim).
            If None, uses LazyLinear. Defaults to None for backward compatibility.
        device (torch.device, optional): Device to create the module on.
            Defaults to None (uses default device).
    """

    def __init__(
        self,
        channels=32,
        num_layers=4,
        kernel_sizes=None,
        latent_dim=None,
        depth=None,
        device=None,
    ):
        if depth is not None:
            warnings.warn(
                f"The depth argument in {type(self)} will soon be deprecated and "
                f"used for the depth of the network instead. Please use channels "
                f"for the layer size and num_layers for the depth until depth "
                f"replaces num_layers."
            )
            channels = depth
        if num_layers < 1:
            raise RuntimeError("num_layers cannot be smaller than 1.")

        super().__init__()
        # Use explicit Linear if latent_dim provided, else LazyLinear for backward compat
        linear_out = channels * 8 * 2 * 2
        if latent_dim is not None:
            first_linear = nn.Linear(latent_dim, linear_out, device=device)
        else:
            first_linear = nn.LazyLinear(linear_out, device=device)
        self.state_to_latent = nn.Sequential(
            first_linear,
            nn.ReLU(),
        )
        if kernel_sizes is None and num_layers == 4:
            kernel_sizes = [5, 5, 6, 6]
        elif kernel_sizes is None:
            kernel_sizes = 5
        if isinstance(kernel_sizes, int):
            kernel_sizes = [kernel_sizes] * num_layers
        layers = [
            nn.ReLU(),
            nn.ConvTranspose2d(channels, 3, kernel_sizes[-1], stride=2, device=device),
        ]
        kernel_sizes = kernel_sizes[:-1]
        k = 1
        for j in range(1, num_layers):
            if j != num_layers - 1:
                layers = [
                    nn.ConvTranspose2d(
                        channels * k * 2,
                        channels * k,
                        kernel_sizes[-1],
                        stride=2,
                        device=device,
                    ),
                ] + layers
                kernel_sizes = kernel_sizes[:-1]
                k = k * 2
                layers = [nn.ReLU()] + layers
            else:
                # Use explicit ConvTranspose2d - input is always channels * 8 from state_to_latent
                layers = [
                    nn.ConvTranspose2d(
                        linear_out,
                        channels * k,
                        kernel_sizes[-1],
                        stride=2,
                        device=device,
                    )
                ] + layers

        self.decoder = nn.Sequential(*layers)
        self._depth = channels

    def forward(self, state, rnn_hidden):
        latent = self.state_to_latent(torch.cat([state, rnn_hidden], dim=-1))
        *batch_sizes, D = latent.shape
        latent = latent.view(-1, D, 1, 1)
        obs_decoded = self.decoder(latent)
        _, C, H, W = obs_decoded.shape
        obs_decoded = obs_decoded.view(*batch_sizes, C, H, W)
        return obs_decoded


class RSSMRollout(TensorDictModuleBase):
    """Rollout the RSSM network.

    Given a set of encoded observations and actions, this module will rollout the RSSM network to compute all the intermediate
    states and beliefs.
    The previous posterior is used as the prior for the next time step.
    The forward method returns a stack of all intermediate states and beliefs.

    Reference: https://arxiv.org/abs/1811.04551

    Args:
        rssm_prior (TensorDictModule): Prior network.
        rssm_posterior (TensorDictModule): Posterior network.
        use_scan (bool, optional): If True, uses torch._higher_order_ops.scan for
            the rollout loop. This is more torch.compile friendly but may have
            different performance characteristics. Defaults to False.
        compile_step (bool, optional): If True, compiles the individual step function.
            Only used when use_scan=False. Defaults to False.
        compile_backend (str, optional): Backend to use for compilation.
            Defaults to "inductor".
        compile_mode (str, optional): Mode to use for compilation.
            Defaults to None (uses PyTorch default).

    """

    def __init__(
        self,
        rssm_prior: TensorDictModule,
        rssm_posterior: TensorDictModule,
        use_scan: bool = False,
        compile_step: bool = False,
        compile_backend: str = "inductor",
        compile_mode: str | None = None,
    ):
        super().__init__()
        _module = TensorDictSequential(rssm_prior, rssm_posterior)
        self.in_keys = _module.in_keys
        self.out_keys = _module.out_keys
        self.rssm_prior = rssm_prior
        self.rssm_posterior = rssm_posterior
        self.use_scan = use_scan
        self.compile_step = compile_step
        self.compile_backend = compile_backend
        self.compile_mode = compile_mode
        self._compiled_step = None

    def _get_step_fn(self):
        """Get the step function, optionally compiled."""
        if self.compile_step and self._compiled_step is None:
            self._compiled_step = torch.compile(
                self._step,
                backend=self.compile_backend,
                mode=self.compile_mode,
            )
        return self._compiled_step if self.compile_step else self._step

    def _step(self, _tensordict):
        """Single RSSM step: prior + posterior."""
        self.rssm_prior(_tensordict)
        self.rssm_posterior(_tensordict)
        return _tensordict

    def forward(self, tensordict):
        """Runs a rollout of simulated transitions in the latent space given a sequence of actions and environment observations.

        The rollout requires a belief and posterior state primer.

        At each step, two probability distributions are built and sampled from:

        - A prior distribution p(s_{t+1} | s_t, a_t, b_t) where b_t is a
          deterministic transform of the form b_t(s_{t-1}, a_{t-1}). The
          previous state s_t is sampled according to the posterior
          distribution (see below), creating a chain of posterior-to-priors
          that accumulates evidence to compute a prior distribution over
          the current event distribution:
          p(s_{t+1}, s_t | o_t, a_t, s_{t-1}, a_{t-1}) = p(s_{t+1} | s_t, a_t, b_t) q(s_t | b_t, o_t)

        - A posterior distribution of the form q(s_{t+1} | b_{t+1}, o_{t+1}),
          which amounts to q(s_{t+1} | s_t, a_t, o_{t+1})

        """
        if self.use_scan:
            return self._forward_scan(tensordict)
        return self._forward_loop(tensordict)

    def _forward_loop(self, tensordict):
        """Traditional loop-based forward."""
        tensordict_out = []
        *batch, time_steps = tensordict.shape

        update_values = tensordict.exclude(*self.out_keys).unbind(-1)
        _tensordict = update_values[0]
        step_fn = self._get_step_fn()

        # Determine output keys from first timestep to ensure consistent stacking.
        # Root state/belief may be added by carry_forward for t>0 but won't exist
        # for t=0, so we use the original input structure as reference.
        output_keys = list(
            update_values[0].keys(include_nested=True, leaves_only=True)
        ) + list(self.out_keys)

        for t in range(time_steps):
            _tensordict = step_fn(_tensordict)

            # Select consistent keys for stacking (excludes root state/belief
            # that may have been added by carry_forward for t>0)
            tensordict_out.append(_tensordict.select(*output_keys, strict=False))
            if t < time_steps - 1:
                # Propagate state/belief from ("next", ...) to root level for next iteration
                # The posterior outputs ("next", "state") which should become "state" for t+1
                # The prior outputs ("next", "belief") which should become "belief" for t+1
                next_state = _tensordict.get(("next", "state"))
                next_belief = _tensordict.get(("next", "belief"))

                # Get next timestep's input data (action, encoded_latents, etc.)
                _tensordict = update_values[t + 1]

                # Set the propagated state/belief (overwriting original data's initial values)
                _tensordict.set("state", next_state)
                _tensordict.set("belief", next_belief)

        out = torch.stack(tensordict_out, tensordict.ndim - 1)
        return out

    def _forward_scan(self, tensordict):
        """Scan-based forward using torch._higher_order_ops.scan.

        This is more torch.compile friendly as it avoids Python control flow.
        """
        from torch._higher_order_ops.scan import scan

        *batch, time_steps = tensordict.shape

        update_values = tensordict.exclude(*self.out_keys).unbind(-1)
        init_td = update_values[0]

        # Determine output keys from first timestep to ensure consistent stacking.
        output_keys = list(
            update_values[0].keys(include_nested=True, leaves_only=True)
        ) + list(self.out_keys)

        # Stack the update values for scan input
        stacked_updates = torch.stack(list(update_values), dim=0)

        def scan_fn(carry, x):
            # carry is the current tensordict with propagated state/belief
            # x is the next timestep's input data (action, encoded_latents, etc.)

            # Get propagated state/belief from previous step's output
            next_state = carry.get(("next", "state"), None)
            next_belief = carry.get(("next", "belief"), None)

            # Start with next timestep's data
            _td = x

            # Propagate state/belief if available (not first step)
            if next_state is not None:
                _td.set("state", next_state)
            if next_belief is not None:
                _td.set("belief", next_belief)

            # Run prior and posterior
            self.rssm_prior(_td)
            self.rssm_posterior(_td)

            # Select consistent keys for stacking
            output_td = _td.select(*output_keys, strict=False)

            # Return output for stacking and full _td as carry for propagation
            return _td, output_td

        # Run scan
        _, outputs = scan(scan_fn, [init_td], [stacked_updates])

        # outputs is stacked along dim 0, move to time dimension
        out = outputs.transpose(0, tensordict.ndim - 1)
        return out


class RSSMPrior(nn.Module):
    """The prior network of the RSSM.

    This network takes as input the previous state and belief and the current action.
    It returns the next prior state and belief, as well as the parameters of the prior state distribution.
    State is by construction stochastic and belief is deterministic. In "Dream to Control", these are called the "stochastic state" and the "deterministic state", respectively.

    Reference: https://arxiv.org/abs/1811.04551

    Args:
        action_spec (TensorSpec): Action spec.
        hidden_dim (int, optional): Number of hidden units in the linear network. Input size of the recurrent network.
            Defaults to 200.
        rnn_hidden_dim (int, optional): Number of hidden units in the recurrent network. Also size of the belief.
            Defaults to 200.
        state_dim (int, optional): Size of the state.
            Defaults to 30.
        scale_lb (:obj:`float`, optional): Lower bound of the scale of the state distribution.
            Defaults to 0.1.
        action_dim (int, optional): Dimension of the action. If provided along with state_dim,
            uses explicit Linear instead of LazyLinear. Defaults to None for backward compatibility.
        device (torch.device, optional): Device to create the module on.
            Defaults to None (uses default device).

    """

    def __init__(
        self,
        action_spec,
        hidden_dim=200,
        rnn_hidden_dim=200,
        state_dim=30,
        scale_lb=0.1,
        action_dim=None,
        device=None,
    ):
        super().__init__()

        # Prior - use explicit Linear if action_dim provided, else LazyLinear
        self.rnn = GRUCell(hidden_dim, rnn_hidden_dim, device=device)
        if action_dim is not None:
            projector_in = state_dim + action_dim
            first_linear = nn.Linear(projector_in, hidden_dim, device=device)
        else:
            first_linear = nn.LazyLinear(hidden_dim, device=device)
        self.action_state_projector = nn.Sequential(first_linear, nn.ELU())
        self.rnn_to_prior_projector = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim, device=device),
            nn.ELU(),
            nn.Linear(hidden_dim, 2 * state_dim, device=device),
            NormalParamExtractor(
                scale_lb=scale_lb,
                scale_mapping="softplus",
            ),
        )

        self.state_dim = state_dim
        self.rnn_hidden_dim = rnn_hidden_dim
        self.action_shape = action_spec.shape

    def forward(self, state, belief, action, noise=None):
        """Forward pass through the prior network.

        Args:
            state: Previous stochastic state.
            belief: Previous deterministic belief.
            action: Action to condition on.
            noise: Optional pre-sampled noise for the prior state.
                If None, samples from standard normal. Used for deterministic testing.

        Returns:
            Tuple of (prior_mean, prior_std, state, belief).
        """
        projector_input = torch.cat([state, action], dim=-1)
        action_state = self.action_state_projector(projector_input)
        unsqueeze = False
        if UNSQUEEZE_RNN_INPUT and action_state.ndimension() == 1:
            if belief is not None:
                belief = belief.unsqueeze(0)
            action_state = action_state.unsqueeze(0)
            unsqueeze = True

        # GRUCell can have issues with bfloat16 autocast on some GPU/cuBLAS combinations.
        # Run the RNN in full precision to avoid CUBLAS_STATUS_INVALID_VALUE errors.
        dtype = action_state.dtype
        device_type = action_state.device.type
        with torch.amp.autocast(device_type=device_type, enabled=False):
            belief = self.rnn(
                action_state.float(), belief.float() if belief is not None else None
            )
        belief = belief.to(dtype)
        if unsqueeze:
            belief = belief.squeeze(0)

        prior_mean, prior_std = self.rnn_to_prior_projector(belief)
        if noise is None:
            noise = torch.randn_like(prior_std)
        state = prior_mean + noise * prior_std
        return prior_mean, prior_std, state, belief


class RSSMPosterior(nn.Module):
    """The posterior network of the RSSM.

    This network takes as input the belief and the associated encoded observation.
    It returns the parameters of the posterior as well as a state sampled according to this distribution.

    Reference: https://arxiv.org/abs/1811.04551

    Args:
        hidden_dim (int, optional): Number of hidden units in the linear network.
            Defaults to 200.
        state_dim (int, optional): Size of the state.
            Defaults to 30.
        scale_lb (:obj:`float`, optional): Lower bound of the scale of the state distribution.
            Defaults to 0.1.
        rnn_hidden_dim (int, optional): Dimension of the belief/rnn hidden state.
            If provided along with obs_embed_dim, uses explicit Linear. Defaults to None.
        obs_embed_dim (int, optional): Dimension of the observation embedding.
            If provided along with rnn_hidden_dim, uses explicit Linear. Defaults to None.
        device (torch.device, optional): Device to create the module on.
            Defaults to None (uses default device).

    """

    def __init__(
        self,
        hidden_dim=200,
        state_dim=30,
        scale_lb=0.1,
        rnn_hidden_dim=None,
        obs_embed_dim=None,
        device=None,
    ):
        super().__init__()
        # Use explicit Linear if both dims provided, else LazyLinear for backward compat
        if rnn_hidden_dim is not None and obs_embed_dim is not None:
            projector_in = rnn_hidden_dim + obs_embed_dim
            first_linear = nn.Linear(projector_in, hidden_dim, device=device)
        else:
            first_linear = nn.LazyLinear(hidden_dim, device=device)
        self.obs_rnn_to_post_projector = nn.Sequential(
            first_linear,
            nn.ELU(),
            nn.Linear(hidden_dim, 2 * state_dim, device=device),
            NormalParamExtractor(
                scale_lb=scale_lb,
                scale_mapping="softplus",
            ),
        )
        self.hidden_dim = hidden_dim

    def forward(self, belief, obs_embedding, noise=None):
        """Forward pass through the posterior network.

        Args:
            belief: Deterministic belief from the prior.
            obs_embedding: Encoded observation.
            noise: Optional pre-sampled noise for the posterior state.
                If None, samples from standard normal. Used for deterministic testing.

        Returns:
            Tuple of (posterior_mean, posterior_std, state).
        """
        posterior_mean, posterior_std = self.obs_rnn_to_post_projector(
            torch.cat([belief, obs_embedding], dim=-1)
        )
        if noise is None:
            noise = torch.randn_like(posterior_std)
        state = posterior_mean + noise * posterior_std
        return posterior_mean, posterior_std, state