torchrl 0.11.0__cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/dreamer/dreamer_utils.py
@@ -0,0 +1,1107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import functools
import tempfile
from contextlib import nullcontext

import torch
import torch.nn as nn
from tensordict import NestedKey
from tensordict.nn import (
    InteractionType,
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
    TensorDictModule,
    TensorDictSequential,
)
from torchrl import logger as torchrl_logger
from torchrl._utils import set_profiling_enabled
from torchrl.collectors import MultiCollector

from torchrl.data import (
    Composite,
    LazyMemmapStorage,
    SliceSampler,
    TensorDictReplayBuffer,
    Unbounded,
)

from torchrl.envs import (
    Compose,
    DMControlEnv,
    DoubleToFloat,
    DreamerDecoder,
    DreamerEnv,
    EnvCreator,
    ExcludeTransform,
    FrameSkipTransform,
    GrayScale,
    GymEnv,
    ParallelEnv,
    RenameTransform,
    Resize,
    RewardSum,
    set_gym_backend,
    StepCounter,
    TensorDictPrimer,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.modules import (
    AdditiveGaussianModule,
    DreamerActor,
    IndependentNormal,
    MLP,
    ObsDecoder,
    ObsEncoder,
    RSSMPosterior,
    RSSMPrior,
    RSSMRollout,
    SafeModule,
    SafeProbabilisticModule,
    SafeProbabilisticTensorDictSequential,
    SafeSequential,
    TanhNormal,
    WorldModelWrapper,
)
from torchrl.record import VideoRecorder


def allocate_collector_devices(
    num_collectors: int, training_device: torch.device
) -> list[torch.device]:
    """Allocate CUDA devices for collectors, reserving cuda:0 for training.

    Device allocation strategy:
    - Training always uses cuda:0
    - Collectors use cuda:1, cuda:2, ..., cuda:N-1 if available
    - If only 1 CUDA device, colocate training and inference on cuda:0
    - If num_collectors >= num_cuda_devices, raise an exception

    Args:
        num_collectors: Number of collector workers requested
        training_device: The device used for training (determines if CUDA is used)

    Returns:
        List of devices for each collector worker

    Raises:
        ValueError: If num_collectors >= num_cuda_devices (no device left for training)
    """
    if training_device.type != "cuda":
        # CPU training: all collectors on CPU
        return [torch.device("cpu")] * num_collectors

    num_cuda_devices = torch.cuda.device_count()

    if num_cuda_devices == 0:
        # No CUDA devices available, fall back to CPU
        return [torch.device("cpu")] * num_collectors

    if num_cuda_devices == 1:
        # Single GPU: colocate training and inference
        torchrl_logger.info(
            f"Single CUDA device available. Colocating {num_collectors} collectors "
            "with training on cuda:0"
        )
        return [torch.device("cuda:0")] * num_collectors

    # Multiple GPUs available
    # Reserve cuda:0 for training, use cuda:1..cuda:N-1 for inference
    inference_devices = num_cuda_devices - 1  # Devices available for collectors

    if num_collectors > inference_devices:
        raise ValueError(
            f"Requested {num_collectors} collectors but only {inference_devices} "
            f"CUDA devices available for inference (cuda:1 to cuda:{num_cuda_devices - 1}). "
            f"cuda:0 is reserved for training. Either reduce num_collectors to "
            f"{inference_devices} or add more GPUs."
        )

    # Distribute collectors across available inference devices (round-robin)
    collector_devices = []
    for i in range(num_collectors):
        device_idx = (i % inference_devices) + 1  # +1 to skip cuda:0
        collector_devices.append(torch.device(f"cuda:{device_idx}"))

    device_str = ", ".join(str(d) for d in collector_devices)
    torchrl_logger.info(
        f"Allocated {num_collectors} collectors to devices: [{device_str}]. "
        f"Training on cuda:0."
    )

    return collector_devices


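# Worked example (illustrative, not part of the packaged file): with 4 visible
# GPUs, allocate_collector_devices(3, torch.device("cuda")) returns
# [cuda:1, cuda:2, cuda:3] and leaves cuda:0 to the optimizer; asking for 5
# collectors on the same machine raises ValueError, since only cuda:1..cuda:3
# are available for inference.

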
class DreamerProfiler:
    """Helper class for PyTorch profiling in Dreamer training.

    Encapsulates profiler setup, stepping, and trace export logic.

    Args:
        cfg: Hydra config with profiling section.
        device: Training device (used to determine CUDA profiling).
        pbar: Progress bar to update total when profiling.

    Keyword Args:
        compile_warmup: Number of eager warmup calls used by compile_with_warmup;
            skip_first is extended by (compile_warmup + 1) so that compilation
            itself is not profiled.
    """

    def __init__(self, cfg, device, pbar=None, *, compile_warmup: int = 0):
        self.enabled = cfg.profiling.enabled
        self.cfg = cfg
        self.total_optim_steps = 0
        self._profiler = None
        self._stopped = False
        self._compile_warmup = compile_warmup

        # Enable detailed profiling instrumentation in torchrl when profiling
        set_profiling_enabled(self.enabled)

        if not self.enabled:
            return

        # Override total_optim_steps for profiling runs
        torchrl_logger.info(
            f"Profiling enabled: running {cfg.profiling.total_optim_steps} optim steps "
            f"(skip_first={cfg.profiling.skip_first}, warmup={cfg.profiling.warmup_steps}, "
            f"active={cfg.profiling.active_steps})"
        )
        if pbar is not None:
            pbar.total = cfg.profiling.total_optim_steps

        # Setup profiler schedule
        # - skip_first: steps to skip entirely (no profiling)
        # - warmup: steps to warm up profiler (data discarded)
        # - active: steps to actually profile (data kept)
        #
        # When torch.compile is enabled via compile_with_warmup, the first `compile_warmup`
        # calls run eagerly and the *next* call typically triggers compilation. Profiling
        # these steps is usually undesirable because it captures compilation overhead and
        # non-representative eager execution.
        #
        # Therefore we automatically extend skip_first by (compile_warmup + 1) optim steps.
        extra_skip = self._compile_warmup + 1 if self._compile_warmup else 0
        skip_first = cfg.profiling.skip_first + extra_skip
        profiler_schedule = torch.profiler.schedule(
            skip_first=skip_first,
            wait=0,
            warmup=cfg.profiling.warmup_steps,
            active=cfg.profiling.active_steps,
            repeat=1,
        )

        # Determine profiler activities
        activities = [torch.profiler.ProfilerActivity.CPU]
        if cfg.profiling.profile_cuda and device.type == "cuda":
            activities.append(torch.profiler.ProfilerActivity.CUDA)

        self._profiler = torch.profiler.profile(
            activities=activities,
            schedule=profiler_schedule,
            on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiler_logs")
            if not cfg.profiling.trace_file
            else None,
            record_shapes=cfg.profiling.record_shapes,
            profile_memory=cfg.profiling.profile_memory,
            with_stack=cfg.profiling.with_stack,
            with_flops=cfg.profiling.with_flops,
        )
        self._profiler.start()

    def step(self) -> bool:
        """Step the profiler and check if profiling is complete.

        Returns:
            True if profiling is complete and training should exit.
        """
        if not self.enabled or self._stopped:
            return False

        self.total_optim_steps += 1
        self._profiler.step()

        # Check if we should stop profiling
        extra_skip = self._compile_warmup + 1 if self._compile_warmup else 0
        target_steps = (
            self.cfg.profiling.skip_first
            + extra_skip
            + self.cfg.profiling.warmup_steps
            + self.cfg.profiling.active_steps
        )
        if self.total_optim_steps >= target_steps:
            torchrl_logger.info(
                f"Profiling complete after {self.total_optim_steps} optim steps. "
                f"Exporting trace to {self.cfg.profiling.trace_file}"
            )
            self._profiler.stop()
            self._stopped = True
            # Export trace if trace_file is set
            if self.cfg.profiling.trace_file:
                self._profiler.export_chrome_trace(self.cfg.profiling.trace_file)
            return True

        return False

    def should_exit(self) -> bool:
        """Check if training loop should exit due to profiling completion."""
        if not self.enabled:
            return False
        extra_skip = self._compile_warmup + 1 if self._compile_warmup else 0
        target_steps = (
            self.cfg.profiling.skip_first
            + extra_skip
            + self.cfg.profiling.warmup_steps
            + self.cfg.profiling.active_steps
        )
        return self.total_optim_steps >= target_steps


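# Typical loop integration (illustrative sketch based on the docstrings above;
# `cfg`, `collector` and `optim_step` are placeholders):
#
#   profiler = DreamerProfiler(cfg, device, pbar)
#   for batch in collector:
#       optim_step(batch)
#       if profiler.step():  # True once skip/warmup/active steps are exhausted
#           break

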
def _make_env(cfg, device, from_pixels=False):
    lib = cfg.env.backend
    if lib in ("gym", "gymnasium"):
        with set_gym_backend(lib):
            env = GymEnv(
                cfg.env.name,
                device=device,
                from_pixels=cfg.env.from_pixels or from_pixels,
                pixels_only=cfg.env.from_pixels,
            )
        # Gym doesn't support native frame_skip, apply transform inside worker
        if cfg.env.frame_skip > 1:
            env = TransformedEnv(env, FrameSkipTransform(cfg.env.frame_skip))
    elif lib == "dm_control":
        env = DMControlEnv(
            cfg.env.name,
            cfg.env.task,
            from_pixels=cfg.env.from_pixels or from_pixels,
            pixels_only=cfg.env.from_pixels,
            device=device,
            frame_skip=cfg.env.frame_skip,  # Native frame skip inside worker
        )
    else:
        raise NotImplementedError(f"Unknown lib {lib}.")
    default_dict = {
        "state": Unbounded(shape=(cfg.networks.state_dim,)),
        "belief": Unbounded(shape=(cfg.networks.rssm_hidden_dim,)),
    }
    env = env.append_transform(
        TensorDictPrimer(random=False, default_value=0, **default_dict)
    )
    return env


def transform_env(cfg, env):
    if not isinstance(env, TransformedEnv):
        env = TransformedEnv(env)
    if cfg.env.from_pixels:
        # transforms pixels from 0-255 to 0-1 (uint8 to float32)
        env.append_transform(
            RenameTransform(in_keys=["pixels"], out_keys=["pixels_int"])
        )
        env.append_transform(
            ToTensorImage(from_int=True, in_keys=["pixels_int"], out_keys=["pixels"])
        )
        if cfg.env.grayscale:
            env.append_transform(GrayScale())

        image_size = cfg.env.image_size
        env.append_transform(Resize(image_size, image_size))

    env.append_transform(DoubleToFloat())
    env.append_transform(RewardSum())
    # Note: FrameSkipTransform is now applied inside workers (in _make_env) to avoid
    # extra IPC round-trips. DMControl uses native frame_skip, Gym uses the transform.
    env.append_transform(StepCounter(cfg.env.horizon))

    return env


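# Editor's note (illustrative): transform_env keeps the raw uint8 frames under
# "pixels_int" and lets ToTensorImage write the float image to "pixels".
# make_storage_transform() further below mirrors that split: it drops the
# transient float "pixels" at extend-time and rebuilds them from "pixels_int",
# so the heavy image pipeline runs once per stored frame, not once per sample.

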
def make_environments(cfg, parallel_envs=1, logger=None):
    """Make environments for training and evaluation.

    Returns:
        train_env_factory: A callable that creates a training environment (for MultiCollector)
        eval_env: The evaluation environment instance
    """

    def train_env_factory():
        """Factory function for creating training environments."""
        func = functools.partial(
            _make_env, cfg=cfg, device=_default_device(cfg.env.device)
        )
        train_env = ParallelEnv(
            parallel_envs,
            EnvCreator(func),
            serial_for_single=True,
        )
        train_env = transform_env(cfg, train_env)
        train_env.set_seed(cfg.env.seed)
        return train_env

    # Create eval env directly (not a factory)
    func = functools.partial(
        _make_env,
        cfg=cfg,
        device=_default_device(cfg.env.device),
        from_pixels=cfg.logger.video,
    )
    eval_env = ParallelEnv(
        1,
        EnvCreator(func),
        serial_for_single=True,
    )
    eval_env = transform_env(cfg, eval_env)
    eval_env.set_seed(cfg.env.seed + 1)
    if cfg.logger.video:
        eval_env.insert_transform(
            0,
            VideoRecorder(
                logger,
                tag="eval/video",
                in_keys=["pixels"],
                skip=cfg.logger.video_skip,
            ),
        )

    # Check specs on a temporary train env
    temp_train_env = train_env_factory()
    check_env_specs(temp_train_env)
    temp_train_env.close()
    del temp_train_env

    check_env_specs(eval_env)
    return train_env_factory, eval_env


def dump_video(module, step: int | None = None):
    """Dump video from VideoRecorder transforms.

    Args:
        module: The transform module to check.
        step: Optional step to log the video at. If not provided,
            the VideoRecorder uses its internal counter.
    """
    if isinstance(module, VideoRecorder):
        module.dump(step=step)


def _compute_encoder_output_size(image_size, channels=32, num_layers=4):
    """Compute the flattened output size of ObsEncoder."""
    # Compute spatial size after each conv layer (kernel=4, stride=2)
    size = image_size
    for _ in range(num_layers):
        size = (size - 4) // 2 + 1
    # Final channels = channels * 2^(num_layers-1)
    final_channels = channels * (2 ** (num_layers - 1))
    return final_channels * size * size


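# Worked example (illustrative): for the common 64x64 Dreamer input, the spatial
# size shrinks 64 -> 31 -> 14 -> 6 -> 2 across the four conv layers
# ((size - 4) // 2 + 1 each), the channel count grows to 32 * 2**3 = 256, and
# the flattened embedding is therefore 256 * 2 * 2 = 1024.

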
def make_dreamer(
    cfg,
    device,
    action_key: str = "action",
    value_key: str = "state_value",
    use_decoder_in_env: bool = False,
    compile: bool = True,
    logger=None,
):
    test_env = _make_env(cfg, device="cpu")
    test_env = transform_env(cfg, test_env)

    # Get dimensions for explicit module instantiation (avoids lazy modules)
    state_dim = cfg.networks.state_dim
    rssm_hidden_dim = cfg.networks.rssm_hidden_dim
    action_dim = test_env.action_spec.shape[-1]

    # Make encoder and decoder
    if cfg.env.from_pixels:
        # Determine input channels (1 for grayscale, 3 for RGB)
        in_channels = 1 if cfg.env.grayscale else 3
        image_size = cfg.env.image_size

        # Compute encoder output size for explicit posterior input
        obs_embed_dim = _compute_encoder_output_size(
            image_size, channels=32, num_layers=4
        )

        encoder = ObsEncoder(in_channels=in_channels, device=device)
        decoder = ObsDecoder(latent_dim=state_dim + rssm_hidden_dim, device=device)

        observation_in_key = "pixels"
        observation_out_key = "reco_pixels"
    else:
        obs_embed_dim = 1024  # MLP output size
        encoder = MLP(
            out_features=obs_embed_dim,
            depth=2,
            num_cells=cfg.networks.hidden_dim,
            activation_class=get_activation(cfg.networks.activation),
            device=device,
        )
        decoder = MLP(
            out_features=test_env.observation_spec["observation"].shape[-1],
            depth=2,
            num_cells=cfg.networks.hidden_dim,
            activation_class=get_activation(cfg.networks.activation),
            device=device,
        )

        observation_in_key = "observation"
        observation_out_key = "reco_observation"

    # Make RSSM with explicit input sizes (no lazy modules)
    rssm_prior = RSSMPrior(
        hidden_dim=rssm_hidden_dim,
        rnn_hidden_dim=rssm_hidden_dim,
        state_dim=state_dim,
        action_spec=test_env.action_spec,
        action_dim=action_dim,
        device=device,
    )
    rssm_posterior = RSSMPosterior(
        hidden_dim=rssm_hidden_dim,
        state_dim=state_dim,
        rnn_hidden_dim=rssm_hidden_dim,
        obs_embed_dim=obs_embed_dim,
        device=device,
    )

    # When use_scan=True or rssm_rollout.compile=True, replace C++ GRU with Python-based GRU
    # for torch.compile compatibility. The C++ GRU (cuBLAS) cannot be traced by torch.compile.
    if cfg.networks.use_scan or cfg.networks.rssm_rollout.compile:
        from torchrl.modules.tensordict_module.rnn import GRUCell as PythonGRUCell

        old_rnn = rssm_prior.rnn
        python_rnn = PythonGRUCell(
            old_rnn.input_size, old_rnn.hidden_size, device=device
        )
        python_rnn.load_state_dict(old_rnn.state_dict())
        rssm_prior.rnn = python_rnn
        torchrl_logger.info(
            "Switched RSSMPrior to Python-based GRU for torch.compile compatibility"
        )
    # Make reward module
    reward_module = MLP(
        out_features=1,
        depth=2,
        num_cells=cfg.networks.hidden_dim,
        activation_class=get_activation(cfg.networks.activation),
        device=device,
    )

    # Make combined world model (modules already on device)
    world_model = _dreamer_make_world_model(
        encoder,
        decoder,
        rssm_prior,
        rssm_posterior,
        reward_module,
        observation_in_key=observation_in_key,
        observation_out_key=observation_out_key,
        use_scan=cfg.networks.use_scan,
        rssm_rollout_compile=cfg.networks.rssm_rollout.compile,
        rssm_rollout_compile_backend=cfg.networks.rssm_rollout.compile_backend,
        rssm_rollout_compile_mode=cfg.networks.rssm_rollout.compile_mode,
    )

    # Initialize world model (already on device)
    with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
        tensordict = (
            test_env.rollout(5, auto_cast_to_device=True).unsqueeze(-1).to(device)
        )
        tensordict = tensordict.to_tensordict()
        world_model(tensordict)

    # Create model-based environment
    model_based_env = _dreamer_make_mbenv(
        reward_module=reward_module,
        rssm_prior=rssm_prior,
        decoder=decoder,
        observation_out_key=observation_out_key,
        test_env=test_env,
        use_decoder_in_env=use_decoder_in_env,
        state_dim=cfg.networks.state_dim,
        rssm_hidden_dim=cfg.networks.rssm_hidden_dim,
    )

    # def detach_state_and_belief(data):
    #     data.set("state", data.get("state").detach())
    #     data.set("belief", data.get("belief").detach())
    #     return data
    #
    # model_based_env = model_based_env.append_transform(detach_state_and_belief)
    check_env_specs(model_based_env)

    # Make actor (modules already on device)
    actor_simulator, actor_realworld = _dreamer_make_actors(
        encoder=encoder,
        observation_in_key=observation_in_key,
        rssm_prior=rssm_prior,
        rssm_posterior=rssm_posterior,
        mlp_num_units=cfg.networks.hidden_dim,
        activation=get_activation(cfg.networks.activation),
        action_key=action_key,
        test_env=test_env,
        device=device,
    )
    # Exploration noise to be added to the actor_realworld
    actor_realworld = TensorDictSequential(
        actor_realworld,
        AdditiveGaussianModule(
            spec=test_env.action_spec,
            sigma_init=1.0,
            sigma_end=1.0,
            annealing_num_steps=1,
            mean=0.0,
            std=cfg.networks.exploration_noise,
            device=device,
        ),
    )

    # Make Critic (on device)
    value_model = _dreamer_make_value_model(
        hidden_dim=cfg.networks.hidden_dim,
        activation=cfg.networks.activation,
        value_key=value_key,
        device=device,
    )

    # Move model_based_env to device (it contains references to modules already on device)
    model_based_env.to(device)

    # Initialize model-based environment, actor and critic
    with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
        tensordict = (
            model_based_env.fake_tensordict().unsqueeze(-1).to(value_model.device)
        )
        tensordict = actor_simulator(tensordict)
        value_model(tensordict)

    if cfg.logger.video:
        model_based_env_eval = model_based_env.append_transform(DreamerDecoder())

        def float_to_int(data):
            reco_pixels_float = data.get("reco_pixels")
            reco_pixels = (reco_pixels_float * 255).floor()
            # assert (reco_pixels < 256).all() and (reco_pixels > 0).all(), (reco_pixels.min(), reco_pixels.max())
            reco_pixels = reco_pixels.to(torch.uint8)
            data.set("reco_pixels_float", reco_pixels_float)
            return data.set("reco_pixels", reco_pixels)

        model_based_env_eval.append_transform(float_to_int)
        model_based_env_eval.append_transform(
            VideoRecorder(
                logger=logger,
                tag="eval/simulated_video",
                in_keys=["reco_pixels"],
                skip=cfg.logger.video_skip,
            )
        )
    else:
        model_based_env_eval = None
    return (
        world_model,
        model_based_env,
        model_based_env_eval,
        actor_simulator,
        value_model,
        actor_realworld,
    )


def make_collector(
    cfg,
    train_env_factory,
    actor_model_explore,
    training_device: torch.device,
    replay_buffer=None,
    storage_transform=None,
    track_policy_version=False,
):
    """Make async multi-collector for parallel data collection.

    Args:
        cfg: Configuration object
        train_env_factory: A callable that creates a training environment
        actor_model_explore: The exploration policy
        training_device: Device used for training (used to allocate collector devices)
        replay_buffer: Optional replay buffer for true async collection with start()
        storage_transform: Optional transform to apply before storing in buffer
        track_policy_version: If True, track policy version using integer versioning.
            Can also be a PolicyVersion instance for custom versioning.

    Returns:
        MultiCollector in async mode with multiple worker processes

    Device allocation:
        - If training on CUDA with multiple GPUs: collectors use cuda:1, cuda:2, etc.
        - If training on CUDA with single GPU: collectors colocate on cuda:0
        - If training on CPU: collectors use CPU
    """
    num_collectors = cfg.collector.num_collectors
    init_random_frames = (
        cfg.collector.init_random_frames
        if not cfg.profiling.enabled
        else cfg.profiling.collector.init_random_frames_override
    )

    # Allocate devices for collectors (reserves cuda:0 for training if multi-GPU)
    collector_devices = allocate_collector_devices(num_collectors, training_device)

    collector = MultiCollector(
        create_env_fn=[train_env_factory] * num_collectors,
        policy=actor_model_explore,
        frames_per_batch=cfg.collector.frames_per_batch,
        total_frames=-1,  # Run indefinitely until async_shutdown() is called
        init_random_frames=init_random_frames,
        policy_device=collector_devices,
        env_device=collector_devices,  # Match env output device to policy device for CUDA transforms
        storing_device="cpu",
        sync=False,  # Async mode for overlapping collection with training
        update_at_each_batch=False,  # We manually call update_policy_weights_() in training loop
        replay_buffer=replay_buffer,
        postproc=storage_transform,
        track_policy_version=track_policy_version,
        # Skip fake data initialization - storage handles coordination
        local_init_rb=True,
    )
    collector.set_seed(cfg.env.seed)

    return collector


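# Lifecycle sketch (illustrative; the consuming training script is not part of
# this file):
#
#   collector = make_collector(cfg, train_env_factory, policy, training_device,
#                              replay_buffer=rb)
#   collector.start()           # workers stream batches into rb in the background
#   ...                         # training loop
#   collector.async_shutdown()  # total_frames=-1 above means: run until this call

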
def make_storage_transform(
    *,
    pixel_obs=True,
    grayscale=True,
    image_size,
):
    """Create transforms to be applied at extend-time (once per frame).

    These heavy transforms (ToTensorImage, GrayScale, Resize) are applied once
    when data is added to the buffer, rather than on every sample.
    """
    if not pixel_obs:
        return None

    storage_transforms = Compose(
        ExcludeTransform("pixels", ("next", "pixels"), inverse=True),
        ToTensorImage(
            in_keys=["pixels_int", ("next", "pixels_int")],
            out_keys=["pixels", ("next", "pixels")],
        ),
    )
    if grayscale:
        storage_transforms.append(GrayScale(in_keys=["pixels", ("next", "pixels")]))
    storage_transforms.append(
        Resize(image_size, image_size, in_keys=["pixels", ("next", "pixels")])
    )
    return storage_transforms


def _to_device(td, device):
    return td.to(device=device, non_blocking=True)


def make_replay_buffer(
    *,
    batch_size,
    batch_seq_len,
    buffer_size=1000000,
    buffer_scratch_dir=None,
    device=None,
    prefetch=8,
    pixel_obs=True,
    grayscale=True,
    image_size,
):
    """Create replay buffer with minimal sample-time transforms.

    Heavy image transforms are expected to be applied at extend-time using
    make_storage_transform(). Only a device cast is applied at sample-time.

    Note: We don't compile the SliceSampler because:
    1. Sampler operations (index computation) happen on CPU and are already fast
    2. torch.compile with inductor has bugs with the sampler's vectorized int64 operations
    """
    with (
        tempfile.TemporaryDirectory()
        if buffer_scratch_dir is None
        else nullcontext(buffer_scratch_dir)
    ) as scratch_dir:
        # Sample-time transforms: only device transfer (fast)
        sample_transforms = Compose(
            functools.partial(_to_device, device=device),
        )

        replay_buffer = TensorDictReplayBuffer(
            pin_memory=False,
            prefetch=prefetch,
            storage=LazyMemmapStorage(
                buffer_size,
                scratch_dir=scratch_dir,
                device="cpu",
                ndim=2,
                shared_init=True,  # Allow remote processes to initialize storage
            ),
            sampler=SliceSampler(
                slice_len=batch_seq_len,
                strict_length=False,
                traj_key=("collector", "traj_ids"),
                cache_values=False,  # Disabled for async collection (cache not synced across processes)
                # Don't compile the sampler - inductor has C++ codegen bugs for int64 ops
            ),
            transform=sample_transforms,
            batch_size=batch_size,
        )
        return replay_buffer


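# Pairing the two halves (illustrative sketch; `B` and `L` are placeholders):
# the heavy image pipeline runs once per frame via the collector's postproc,
# while sampling only pays for the device cast.
#
#   rb = make_replay_buffer(batch_size=B, batch_seq_len=L,
#                           image_size=cfg.env.image_size, device=device)
#   st = make_storage_transform(image_size=cfg.env.image_size)
#   collector = make_collector(cfg, env_factory, policy, device,
#                              replay_buffer=rb, storage_transform=st)

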
def _dreamer_make_value_model(
    hidden_dim: int = 400,
    activation: str = "elu",
    value_key: str = "state_value",
    device=None,
):
    value_model = MLP(
        out_features=1,
        depth=3,
        num_cells=hidden_dim,
        activation_class=get_activation(activation),
        device=device,
    )
    value_model = ProbabilisticTensorDictSequential(
        TensorDictModule(
            value_model,
            in_keys=["state", "belief"],
            out_keys=["loc"],
        ),
        ProbabilisticTensorDictModule(
            in_keys=["loc"],
            out_keys=[value_key],
            distribution_class=IndependentNormal,
            distribution_kwargs={"scale": 1.0, "event_dim": 1},
        ),
    )

    return value_model


def _dreamer_make_actors(
    encoder,
    observation_in_key,
    rssm_prior,
    rssm_posterior,
    mlp_num_units,
    activation,
    action_key,
    test_env,
    device=None,
):
    actor_module = DreamerActor(
        out_features=test_env.action_spec.shape[-1],
        depth=3,
        num_cells=mlp_num_units,
        activation_class=activation,
        device=device,
    )
    actor_simulator = _dreamer_make_actor_sim(action_key, test_env, actor_module)
    actor_realworld = _dreamer_make_actor_real(
        encoder,
        observation_in_key,
        rssm_prior,
        rssm_posterior,
        actor_module,
        action_key,
        test_env,
    )
    return actor_simulator, actor_realworld


def _dreamer_make_actor_sim(action_key, proof_environment, actor_module):
    actor_simulator = SafeProbabilisticTensorDictSequential(
        SafeModule(
            actor_module,
            in_keys=["state", "belief"],
            out_keys=["loc", "scale"],
            spec=Composite(
                **{
                    "loc": Unbounded(
                        proof_environment.action_spec_unbatched.shape,
                        device=proof_environment.action_spec_unbatched.device,
                    ),
                    "scale": Unbounded(
                        proof_environment.action_spec_unbatched.shape,
                        device=proof_environment.action_spec_unbatched.device,
                    ),
                }
            ),
        ),
        SafeProbabilisticModule(
            in_keys=["loc", "scale"],
            out_keys=[action_key],
            default_interaction_type=InteractionType.RANDOM,
            distribution_class=TanhNormal,
            distribution_kwargs={"tanh_loc": True},
            spec=Composite(**{action_key: proof_environment.action_spec_unbatched}),
        ),
    )
    return actor_simulator


def _dreamer_make_actor_real(
    encoder,
    observation_in_key,
    rssm_prior,
    rssm_posterior,
    actor_module,
    action_key,
    proof_environment,
):
    # actor for real world: interacts with states ~ posterior
    # Our actor differs from the original paper, where the prior and posterior are
    # computed first and the action is taken on top of them; we found that this
    # approach worked better.
    actor_realworld = SafeSequential(
        SafeModule(
            encoder,
            in_keys=[observation_in_key],
            out_keys=["encoded_latents"],
        ),
        SafeModule(
            rssm_posterior,
            in_keys=["belief", "encoded_latents"],
            out_keys=[
                "_",
                "_",
                "state",
            ],
        ),
        SafeProbabilisticTensorDictSequential(
            SafeModule(
                actor_module,
                in_keys=["state", "belief"],
                out_keys=["loc", "scale"],
                spec=Composite(
                    **{
                        "loc": Unbounded(
                            proof_environment.action_spec_unbatched.shape,
                        ),
                        "scale": Unbounded(
                            proof_environment.action_spec_unbatched.shape,
                        ),
                    }
                ),
            ),
            SafeProbabilisticModule(
                in_keys=["loc", "scale"],
                out_keys=[action_key],
                default_interaction_type=InteractionType.DETERMINISTIC,
                distribution_class=TanhNormal,
                distribution_kwargs={"tanh_loc": True},
                spec=proof_environment.full_action_spec_unbatched.to("cpu"),
            ),
        ),
        SafeModule(
            rssm_prior,
            in_keys=["state", "belief", action_key],
            out_keys=[
                "_",
                "_",
                "_",  # we don't need the prior state
                ("next", "belief"),
            ],
        ),
    )
    return actor_realworld


def _dreamer_make_mbenv(
    reward_module,
    rssm_prior,
    test_env,
    decoder,
    observation_out_key: str = "reco_pixels",
    use_decoder_in_env: bool = False,
    state_dim: int = 30,
    rssm_hidden_dim: int = 200,
):
    # MB environment
    if use_decoder_in_env:
        mb_env_obs_decoder = SafeModule(
            decoder,
            in_keys=["state", "belief"],
            out_keys=[observation_out_key],
        )
    else:
        mb_env_obs_decoder = None

    transition_model = SafeSequential(
        SafeModule(
            rssm_prior,
            in_keys=["state", "belief", "action"],
            out_keys=[
                "_",
                "_",
                "state",
                "belief",
            ],
        ),
    )

    reward_model = SafeProbabilisticTensorDictSequential(
        SafeModule(
            reward_module,
            in_keys=["state", "belief"],
            out_keys=["loc"],
        ),
        SafeProbabilisticModule(
            in_keys=["loc"],
            out_keys=["reward"],
            distribution_class=IndependentNormal,
            distribution_kwargs={"scale": 1.0, "event_dim": 1},
        ),
    )

    model_based_env = DreamerEnv(
        world_model=WorldModelWrapper(
            transition_model,
            reward_model,
        ),
        prior_shape=torch.Size([state_dim]),
        belief_shape=torch.Size([rssm_hidden_dim]),
        obs_decoder=mb_env_obs_decoder,
    )

    model_based_env.set_specs_from_env(test_env)
    return model_based_env


def _dreamer_make_world_model(
    encoder,
    decoder,
    rssm_prior,
    rssm_posterior,
    reward_module,
    observation_in_key: NestedKey = "pixels",
    observation_out_key: NestedKey = "reco_pixels",
    use_scan: bool = False,
    rssm_rollout_compile: bool = False,
    rssm_rollout_compile_backend: str = "inductor",
    rssm_rollout_compile_mode: str | None = "reduce-overhead",
):
    # World Model and reward model
    # Note: in_keys uses dict form with out_to_in_map=True to map function args to tensordict keys.
    # {"noise": "prior_noise"} means: read "prior_noise" from tensordict, pass as `noise` kwarg.
    # With strict=False (default), missing noise keys pass None to the module.
    rssm_rollout = RSSMRollout(
        TensorDictModule(
            rssm_prior,
            in_keys={
                "state": "state",
                "belief": "belief",
                "action": "action",
                "noise": "prior_noise",
            },
            out_keys=[
                ("next", "prior_mean"),
                ("next", "prior_std"),
                "_",
                ("next", "belief"),
            ],
            out_to_in_map=True,
        ),
        TensorDictModule(
            rssm_posterior,
            in_keys={
                "belief": ("next", "belief"),
                "obs_embedding": ("next", "encoded_latents"),
                "noise": "posterior_noise",
            },
            out_keys=[
                ("next", "posterior_mean"),
                ("next", "posterior_std"),
                ("next", "state"),
            ],
            out_to_in_map=True,
        ),
        use_scan=use_scan,
        compile_step=rssm_rollout_compile,
        compile_backend=rssm_rollout_compile_backend,
        compile_mode=rssm_rollout_compile_mode,
    )
    event_dim = 3 if observation_out_key == "reco_pixels" else 1  # 3 for RGB
    decoder = ProbabilisticTensorDictSequential(
        TensorDictModule(
            decoder,
            in_keys=[("next", "state"), ("next", "belief")],
            out_keys=["loc"],
        ),
        ProbabilisticTensorDictModule(
            in_keys=["loc"],
            out_keys=[("next", observation_out_key)],
            distribution_class=IndependentNormal,
            distribution_kwargs={"scale": 1.0, "event_dim": event_dim},
        ),
    )

    transition_model = TensorDictSequential(
        TensorDictModule(
            encoder,
            in_keys=[("next", observation_in_key)],
            out_keys=[("next", "encoded_latents")],
        ),
        rssm_rollout,
        decoder,
    )

    reward_model = ProbabilisticTensorDictSequential(
        TensorDictModule(
            reward_module,
            in_keys=[("next", "state"), ("next", "belief")],
            out_keys=[("next", "loc")],
        ),
        ProbabilisticTensorDictModule(
            in_keys=[("next", "loc")],
            out_keys=[("next", "reward")],
            distribution_class=IndependentNormal,
            distribution_kwargs={"scale": 1.0, "event_dim": 1},
        ),
    )

    world_model = WorldModelWrapper(
        transition_model,
        reward_model,
    )
    return world_model


def log_metrics(logger, metrics, step):
    for metric_name, metric_value in metrics.items():
        logger.log_scalar(metric_name, metric_value, step)


def get_activation(name):
    if name == "relu":
        return nn.ReLU
    elif name == "tanh":
        return nn.Tanh
    elif name == "leaky_relu":
        return nn.LeakyReLU
    elif name == "elu":
        return nn.ELU
    else:
        raise NotImplementedError(f"Unknown activation {name!r}")


def _default_device(device=None):
    if device in ("", None):
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")
    return torch.device(device)
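

# End-to-end wiring (editor's illustrative sketch, consistent with the
# signatures above but not a verbatim excerpt of dreamer.py):
#
#   device = _default_device()
#   train_env_factory, eval_env = make_environments(cfg, parallel_envs=1, logger=logger)
#   world_model, mb_env, mb_env_eval, actor_sim, value_model, actor_real = make_dreamer(
#       cfg, device, logger=logger
#   )
#   rb = make_replay_buffer(batch_size=..., batch_seq_len=...,
#                           image_size=cfg.env.image_size, device=device)
#   collector = make_collector(cfg, train_env_factory, actor_real, device,
#                              replay_buffer=rb)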