torchrl 0.11.0__cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,1060 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from collections.abc import Callable, Sequence

from copy import copy

import torch
from omegaconf import OmegaConf
from tensordict.nn import (
    InteractionType,
    ProbabilisticTensorDictSequential,
    TensorDictModule,
    TensorDictModuleWrapper,
)
from torch import distributions as d, nn, optim
from torch.optim.lr_scheduler import CosineAnnealingLR

from torchrl._utils import logger as torchrl_logger, VERBOSE
from torchrl.collectors import DataCollectorBase
from torchrl.data import (
    LazyMemmapStorage,
    MultiStep,
    PrioritizedSampler,
    RandomSampler,
    ReplayBuffer,
    TensorDictReplayBuffer,
)
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs import (
    CatFrames,
    CatTensors,
    CenterCrop,
    Compose,
    DMControlEnv,
    DoubleToFloat,
    env_creator,
    EnvBase,
    EnvCreator,
    FlattenObservation,
    GrayScale,
    gSDENoise,
    GymEnv,
    InitTracker,
    NoopResetEnv,
    ObservationNorm,
    ParallelEnv,
    Resize,
    RewardScaling,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
    VecNorm,
)
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import (
    ActorCriticOperator,
    ActorValueOperator,
    DdpgCnnActor,
    DdpgCnnQNet,
    MLP,
    NoisyLinear,
    NormalParamExtractor,
    ProbabilisticActor,
    SafeModule,
    SafeSequential,
    TanhNormal,
    ValueOperator,
)
from torchrl.modules.distributions.continuous import SafeTanhTransform
from torchrl.modules.models.exploration import LazygSDEModule
from torchrl.objectives import HardUpdate, LossModule, SoftUpdate, TargetNetUpdater
from torchrl.objectives.deprecated import REDQLoss_deprecated
from torchrl.record.loggers import Logger
from torchrl.record.recorder import VideoRecorder
from torchrl.trainers.helpers import sync_async_collector, sync_sync_collector
from torchrl.trainers.trainers import (
    BatchSubSampler,
    ClearCudaCache,
    CountFramesLog,
    LogScalar,
    LogValidationReward,
    ReplayBufferTrainer,
    RewardNormalizer,
    Trainer,
    UpdateWeights,
)

LIBS = {
    "gym": GymEnv,
    "dm_control": DMControlEnv,
}
ACTIVATIONS = {
    "elu": nn.ELU,
    "tanh": nn.Tanh,
    "relu": nn.ReLU,
}
OPTIMIZERS = {
    "adam": optim.Adam,
    "sgd": optim.SGD,
    "adamax": optim.Adamax,
}


def correct_for_frame_skip(cfg: DictConfig) -> DictConfig:  # noqa: F821
    """Correct the arguments for the input frame_skip, by dividing all the arguments that reflect a count of frames by the frame_skip.

    This is aimed at avoiding unknowingly over-sampling from the environment, i.e. targeting a total number of frames
    of 1M but actually collecting frame_skip * 1M frames.

    Args:
        cfg (DictConfig): DictConfig containing some frame-counting argument, including:
            "max_frames_per_traj", "total_frames", "frames_per_batch", "record_frames", "annealing_frames",
            "init_random_frames", "init_env_steps"

    Returns:
        the input DictConfig, modified in-place.

    """

    def _hasattr(field):
        local_cfg = cfg
        fields = field.split(".")
        for f in fields:
            if not hasattr(local_cfg, f):
                return False
            local_cfg = getattr(local_cfg, f)
        else:
            return True

    def _getattr(field):
        local_cfg = cfg
        fields = field.split(".")
        for f in fields:
            local_cfg = getattr(local_cfg, f)
        return local_cfg

    def _setattr(field, val):
        local_cfg = cfg
        fields = field.split(".")
        for f in fields[:-1]:
            local_cfg = getattr(local_cfg, f)
        setattr(local_cfg, field[-1], val)

    # Adapt all frame counts wrt frame_skip
    frame_skip = cfg.env.frame_skip
    if frame_skip != 1:
        fields = [
            "collector.max_frames_per_traj",
            "collector.total_frames",
            "collector.frames_per_batch",
            "logger.record_frames",
            "exploration.annealing_frames",
            "collector.init_random_frames",
            "env.init_env_steps",
            "env.noops",
        ]
        for field in fields:
            if _hasattr(cfg, field):
                _setattr(field, _getattr(field) // frame_skip)
    return cfg


def make_trainer(
    collector: DataCollectorBase,
    loss_module: LossModule,
    recorder: EnvBase | None,
    target_net_updater: TargetNetUpdater | None,
    policy_exploration: TensorDictModuleWrapper | TensorDictModule | None,
    replay_buffer: ReplayBuffer | None,
    logger: Logger | None,
    cfg: DictConfig,  # noqa: F821
) -> Trainer:
    """Creates a Trainer instance given its constituents.

    Args:
        collector (DataCollectorBase): A data collector to be used to collect data.
        loss_module (LossModule): A TorchRL loss module
        recorder (EnvBase, optional): a recorder environment.
        target_net_updater (TargetNetUpdater): A target network update object.
        policy_exploration (TDModule or TensorDictModuleWrapper): a policy to be used for recording and exploration
            updates (should be synced with the learnt policy).
        replay_buffer (ReplayBuffer): a replay buffer to be used to collect data.
        logger (Logger): a Logger to be used for logging.
        cfg (DictConfig): a DictConfig containing the arguments of the script.

    Returns:
        A trainer built with the input objects. The optimizer is built by this helper function using the cfg provided.

    Examples:
        >>> import torch
        >>> import tempfile
        >>> from torchrl.trainers.loggers import TensorboardLogger
        >>> from torchrl.trainers import Trainer
        >>> from torchrl.envs import EnvCreator
        >>> from torchrl.collectors import SyncDataCollector
        >>> from torchrl.data import TensorDictReplayBuffer
        >>> from torchrl.envs.libs.gym import GymEnv
        >>> from torchrl.modules import TensorDictModuleWrapper, SafeModule, ValueOperator, EGreedyWrapper
        >>> from torchrl.objectives.common import LossModule
        >>> from torchrl.objectives.utils import TargetNetUpdater
        >>> from torchrl.objectives import DDPGLoss
        >>> env_maker = EnvCreator(lambda: GymEnv("Pendulum-v0"))
        >>> env_proof = env_maker()
        >>> obs_spec = env_proof.observation_spec
        >>> action_spec = env_proof.action_spec
        >>> net = torch.nn.Linear(env_proof.observation_spec.shape[-1], action_spec.shape[-1])
        >>> net_value = torch.nn.Linear(env_proof.observation_spec.shape[-1], 1)  # for the purpose of testing
        >>> policy = SafeModule(action_spec, net, in_keys=["observation"], out_keys=["action"])
        >>> value = ValueOperator(net_value, in_keys=["observation"], out_keys=["state_action_value"])
        >>> collector = SyncDataCollector(env_maker, policy, total_frames=100)
        >>> loss_module = DDPGLoss(policy, value, gamma=0.99)
        >>> recorder = env_proof
        >>> target_net_updater = None
        >>> policy_exploration = EGreedyWrapper(policy)
        >>> replay_buffer = TensorDictReplayBuffer()
        >>> dir = tempfile.gettempdir()
        >>> logger = TensorboardLogger(exp_name=dir)
        >>> trainer = make_trainer(collector, loss_module, recorder, target_net_updater, policy_exploration,
        ...    replay_buffer, logger)
        >>> torchrl_logger.info(trainer)

    """

    optimizer = OPTIMIZERS[cfg.optim.optimizer](
        loss_module.parameters(),
        lr=cfg.optim.lr,
        weight_decay=cfg.optim.weight_decay,
        eps=cfg.optim.eps,
        **OmegaConf.to_container(cfg.optim.kwargs),
    )
    device = next(loss_module.parameters()).device
    if cfg.optim.lr_scheduler == "cosine":
        optim_scheduler = CosineAnnealingLR(
            optimizer,
            T_max=int(
                cfg.collector.total_frames
                / cfg.collector.frames_per_batch
                * cfg.optim.steps_per_batch
            ),
        )
    elif cfg.optim.lr_scheduler == "":
        optim_scheduler = None
    else:
        raise NotImplementedError(f"lr scheduler {cfg.optim.lr_scheduler}")

    if VERBOSE:
        torchrl_logger.info(
            f"collector = {collector}; \n"
            f"loss_module = {loss_module}; \n"
            f"recorder = {recorder}; \n"
            f"target_net_updater = {target_net_updater}; \n"
            f"policy_exploration = {policy_exploration}; \n"
            f"replay_buffer = {replay_buffer}; \n"
            f"logger = {logger}; \n"
            f"cfg = {cfg}; \n"
        )

    if logger is not None:
        # log hyperparams
        logger.log_hparams(cfg)

    trainer = Trainer(
        collector=collector,
        frame_skip=cfg.env.frame_skip,
        total_frames=cfg.collector.total_frames * cfg.env.frame_skip,
        loss_module=loss_module,
        optimizer=optimizer,
        logger=logger,
        optim_steps_per_batch=cfg.optim.steps_per_batch,
        clip_grad_norm=cfg.optim.clip_grad_norm,
        clip_norm=cfg.optim.clip_norm,
    )

    if torch.cuda.device_count() > 0:
        trainer.register_op("pre_optim_steps", ClearCudaCache(1))

    trainer.register_op("batch_process", lambda batch: batch.cpu())

    if replay_buffer is not None:
        # replay buffer is used 2 or 3 times: to register data, to sample
        # data and to update priorities
        rb_trainer = ReplayBufferTrainer(
            replay_buffer,
            cfg.buffer.batch_size,
            flatten_tensordicts=True,
            memmap=False,
            device=device,
        )

        trainer.register_op("batch_process", rb_trainer.extend)
        trainer.register_op("process_optim_batch", rb_trainer.sample)
        trainer.register_op("post_loss", rb_trainer.update_priority)
    else:
        # trainer.register_op("batch_process", mask_batch)
        trainer.register_op(
            "process_optim_batch",
            BatchSubSampler(
                batch_size=cfg.buffer.batch_size, sub_traj_len=cfg.buffer.sub_traj_len
            ),
        )
        trainer.register_op("process_optim_batch", lambda batch: batch.to(device))

    if optim_scheduler is not None:
        trainer.register_op("post_optim", optim_scheduler.step)

    if target_net_updater is not None:
        trainer.register_op("post_optim", target_net_updater.step)

    if cfg.env.normalize_rewards_online:
        # if used the running statistics of the rewards are computed and the
        # rewards used for training will be normalized based on these.
        reward_normalizer = RewardNormalizer(
            scale=cfg.env.normalize_rewards_online_scale,
            decay=cfg.env.normalize_rewards_online_decay,
        )
        trainer.register_op("batch_process", reward_normalizer.update_reward_stats)
        trainer.register_op("process_optim_batch", reward_normalizer.normalize_reward)

    if policy_exploration is not None and hasattr(policy_exploration, "step"):
        trainer.register_op(
            "post_steps", policy_exploration.step, frames=cfg.collector.frames_per_batch
        )

    trainer.register_op(
        "post_steps_log", lambda *cfg: {"lr": optimizer.param_groups[0]["lr"]}
    )

    if recorder is not None:
        # create recorder object
        recorder_obj = LogValidationReward(
            record_frames=cfg.logger.record_frames,
            frame_skip=cfg.env.frame_skip,
            policy_exploration=policy_exploration,
            environment=recorder,
            record_interval=cfg.logger.record_interval,
            log_keys=cfg.logger.recorder_log_keys,
        )
        # register recorder
        trainer.register_op(
            "post_steps_log",
            recorder_obj,
        )
        # call recorder - could be removed
        recorder_obj(None)
        # create explorative recorder - could be optional
        recorder_obj_explore = LogValidationReward(
            record_frames=cfg.logger.record_frames,
            frame_skip=cfg.env.frame_skip,
            policy_exploration=policy_exploration,
            environment=recorder,
            record_interval=cfg.logger.record_interval,
            exploration_type=ExplorationType.RANDOM,
            suffix="exploration",
            out_keys={("next", "reward"): "r_evaluation_exploration"},
        )
        # register recorder
        trainer.register_op(
            "post_steps_log",
            recorder_obj_explore,
        )
        # call recorder - could be removed
        recorder_obj_explore(None)

    trainer.register_op(
        "post_steps", UpdateWeights(collector, update_weights_interval=1)
    )

    trainer.register_op("pre_steps_log", LogScalar())
    trainer.register_op("pre_steps_log", CountFramesLog(frame_skip=cfg.env.frame_skip))

    return trainer


def make_redq_model(
    proof_environment: EnvBase,
    cfg: DictConfig,  # noqa: F821
    device: DEVICE_TYPING = "cpu",
    in_keys: Sequence[str] | None = None,
    actor_net_kwargs=None,
    qvalue_net_kwargs=None,
    observation_key=None,
    **kwargs,
) -> nn.ModuleList:
    """Actor and Q-value model constructor helper function for REDQ.

    Follows default parameters proposed in REDQ original paper: https://openreview.net/pdf?id=AY8zfZm0tDd.
    Other configurations can easily be implemented by modifying this function at will.
    A single instance of the Q-value model is returned. It will be multiplicated by the loss function.

    Args:
        proof_environment (EnvBase): a dummy environment to retrieve the observation and action spec
        cfg (DictConfig): contains arguments of the REDQ script
        device (torch.device, optional): device on which the model must be cast. Default is "cpu".
        in_keys (iterable of strings, optional): observation key to be read by the actor, usually one of
            `'observation_vector'` or `'pixels'`. If none is provided, one of these two keys is chosen
            based on the `cfg.from_pixels` argument.
        actor_net_kwargs (dict, optional): kwargs of the actor MLP.
        qvalue_net_kwargs (dict, optional): kwargs of the qvalue MLP.

    Returns:
        A nn.ModuleList containing the actor, qvalue operator(s) and the value operator.

    """
    torch.manual_seed(cfg.seed)
    tanh_loc = cfg.network.tanh_loc
    default_policy_scale = cfg.network.default_policy_scale
    gSDE = cfg.exploration.gSDE

    action_spec = proof_environment.action_spec_unbatched

    if actor_net_kwargs is None:
        actor_net_kwargs = {}
    if qvalue_net_kwargs is None:
        qvalue_net_kwargs = {}

    linear_layer_class = torch.nn.Linear if not cfg.exploration.noisy else NoisyLinear

    out_features_actor = (2 - gSDE) * action_spec.shape[-1]
    if cfg.env.from_pixels:
        if in_keys is None:
            in_keys_actor = ["pixels"]
        else:
            in_keys_actor = in_keys
        actor_net_kwargs_default = {
            "mlp_net_kwargs": {
                "layer_class": linear_layer_class,
                "activation_class": ACTIVATIONS[cfg.network.activation],
            },
            "conv_net_kwargs": {
                "activation_class": ACTIVATIONS[cfg.network.activation]
            },
        }
        actor_net_kwargs_default.update(actor_net_kwargs)
        actor_net = DdpgCnnActor(out_features_actor, **actor_net_kwargs_default)
        gSDE_state_key = "hidden"
        out_keys_actor = ["param", "hidden"]

        value_net_default_kwargs = {
            "mlp_net_kwargs": {
                "layer_class": linear_layer_class,
                "activation_class": ACTIVATIONS[cfg.network.activation],
            },
            "conv_net_kwargs": {
                "activation_class": ACTIVATIONS[cfg.network.activation]
            },
        }
        value_net_default_kwargs.update(qvalue_net_kwargs)

        in_keys_qvalue = ["pixels", "action"]
        qvalue_net = DdpgCnnQNet(**value_net_default_kwargs)
    else:
        if in_keys is None:
            in_keys_actor = ["observation_vector"]
        else:
            in_keys_actor = in_keys

        actor_net_kwargs_default = {
            "num_cells": [cfg.network.actor_cells] * cfg.network.actor_depth,
            "out_features": out_features_actor,
            "activation_class": ACTIVATIONS[cfg.network.activation],
        }
        actor_net_kwargs_default.update(actor_net_kwargs)
        actor_net = MLP(**actor_net_kwargs_default)
        out_keys_actor = ["param"]
        gSDE_state_key = in_keys_actor[0]

        qvalue_net_kwargs_default = {
            "num_cells": [cfg.network.qvalue_cells] * cfg.network.qvalue_depth,
            "out_features": 1,
            "activation_class": ACTIVATIONS[cfg.network.activation],
        }
        qvalue_net_kwargs_default.update(qvalue_net_kwargs)
        qvalue_net = MLP(
            **qvalue_net_kwargs_default,
        )
        in_keys_qvalue = in_keys_actor + ["action"]

    dist_class = TanhNormal
    dist_kwargs = {
        "low": action_spec.space.low,
        "high": action_spec.space.high,
        "tanh_loc": tanh_loc,
    }

    if not gSDE:
        actor_net = nn.Sequential(
            actor_net,
            NormalParamExtractor(
                scale_mapping=f"biased_softplus_{default_policy_scale}",
                scale_lb=cfg.network.scale_lb,
            ),
        )
        actor_module = SafeModule(
            actor_net,
            in_keys=in_keys_actor,
            out_keys=["loc", "scale"] + out_keys_actor[1:],
        )

    else:
        actor_module = SafeModule(
            actor_net,
            in_keys=in_keys_actor,
            out_keys=["action"] + out_keys_actor[1:],  # will be overwritten
        )

        if action_spec.domain == "continuous":
            min = action_spec.space.low
            max = action_spec.space.high
            transform = SafeTanhTransform()
            if (min != -1).any() or (max != 1).any():
                transform = d.ComposeTransform(
                    transform,
                    d.AffineTransform(loc=(max + min) / 2, scale=(max - min) / 2),
                )
        else:
            raise RuntimeError("cannot use gSDE with discrete actions")

        actor_module = SafeSequential(
            actor_module,
            SafeModule(
                LazygSDEModule(transform=transform, device=device),
                in_keys=["action", gSDE_state_key, "_eps_gSDE"],
                out_keys=["loc", "scale", "action", "_eps_gSDE"],
            ),
        )

    actor = ProbabilisticActor(
        spec=action_spec,
        in_keys=["loc", "scale"],
        module=actor_module,
        distribution_class=dist_class,
        distribution_kwargs=dist_kwargs,
        default_interaction_type=InteractionType.RANDOM,
        return_log_prob=True,
    )
    qvalue = ValueOperator(
        in_keys=in_keys_qvalue,
        module=qvalue_net,
    )
    model = nn.ModuleList([actor, qvalue]).to(device)

    # init nets
    with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
        td = proof_environment.fake_tensordict()
        td = td.unsqueeze(-1)
        td = td.to(device)
        for net in model:
            net(td)
    del td
    return model


def transformed_env_constructor(
    cfg: DictConfig,  # noqa: F821
    video_tag: str = "",
    logger: Logger | None = None,
    stats: dict | None = None,
    norm_obs_only: bool = False,
    use_env_creator: bool = False,
    custom_env_maker: Callable | None = None,
    custom_env: EnvBase | None = None,
    return_transformed_envs: bool = True,
    action_dim_gsde: int | None = None,
    state_dim_gsde: int | None = None,
    batch_dims: int | None = 0,
    obs_norm_state_dict: dict | None = None,
) -> Callable | EnvCreator:
    """Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor.

    Args:
        cfg (DictConfig): a DictConfig containing the arguments of the script.
        video_tag (str, optional): video tag to be passed to the Logger object
        logger (Logger, optional): logger associated with the script
        stats (dict, optional): a dictionary containing the :obj:`loc` and :obj:`scale` for the `ObservationNorm` transform
        norm_obs_only (bool, optional): If `True` and `VecNorm` is used, the reward won't be normalized online.
            Default is `False`.
        use_env_creator (bool, optional): whether the `EnvCreator` class should be used. By using `EnvCreator`,
            one can make sure that running statistics will be put in shared memory and accessible for all workers
            when using a `VecNorm` transform. Default is `True`.
        custom_env_maker (callable, optional): if your env maker is not part
            of torchrl env wrappers, a custom callable
            can be passed instead. In this case it will override the
            constructor retrieved from `args`.
        custom_env (EnvBase, optional): if an existing environment needs to be
            transformed_in, it can be passed directly to this helper. `custom_env_maker`
            and `custom_env` are exclusive features.
        return_transformed_envs (bool, optional): if ``True``, a transformed_in environment
            is returned.
        action_dim_gsde (int, Optional): if gSDE is used, this can present the action dim to initialize the noise.
            Make sure this is indicated in environment executed in parallel.
        state_dim_gsde: if gSDE is used, this can present the state dim to initialize the noise.
            Make sure this is indicated in environment executed in parallel.
        batch_dims (int, optional): number of dimensions of a batch of data. If a single env is
            used, it should be 0 (default). If multiple envs are being transformed in parallel,
            it should be set to 1 (or the number of dims of the batch).
        obs_norm_state_dict (dict, optional): the state_dict of the ObservationNorm transform to be loaded into the
            environment
    """

    def make_transformed_env(**kwargs) -> TransformedEnv:
        env_name = cfg.env.name
        env_task = cfg.env.task
        env_library = LIBS[cfg.env.library]
        frame_skip = cfg.env.frame_skip
        from_pixels = cfg.env.from_pixels
        categorical_action_encoding = cfg.env.categorical_action_encoding

        if custom_env is None and custom_env_maker is None:
            if cfg.collector.device in ("", None):
                device = "cpu" if not torch.cuda.is_available() else "cuda:0"
            elif isinstance(cfg.collector.device, str):
                device = cfg.collector.device
            elif isinstance(cfg.collector.device, Sequence):
                device = cfg.collector.device[0]
            else:
                raise ValueError(
                    "collector_device must be either a string or a sequence of strings"
                )
            env_kwargs = {
                "env_name": env_name,
                "device": device,
                "frame_skip": frame_skip,
                "from_pixels": from_pixels or len(video_tag),
                "pixels_only": from_pixels,
            }
            if env_library is GymEnv:
                env_kwargs.update(
                    {"categorical_action_encoding": categorical_action_encoding}
                )
            elif categorical_action_encoding:
                raise NotImplementedError(
                    "categorical_action_encoding=True is currently only compatible with GymEnvs."
                )
            if env_library is DMControlEnv:
                env_kwargs.update({"task_name": env_task})
            env_kwargs.update(kwargs)
            env = env_library(**env_kwargs)
        elif custom_env is None and custom_env_maker is not None:
            env = custom_env_maker(**kwargs)
        elif custom_env_maker is None and custom_env is not None:
            env = custom_env
        else:
            raise RuntimeError("cannot provide both custom_env and custom_env_maker")

        if cfg.env.noops and custom_env is None:
            # this is a bit hacky: if custom_env is not None, it is probably a ParallelEnv
            # that already has its NoopResetEnv set for the contained envs.
            # There is a risk however that we're just skipping the NoopsReset instantiation
            env = TransformedEnv(env, NoopResetEnv(cfg.env.noops))
        if not return_transformed_envs:
            return env

        return make_env_transforms(
            env,
            cfg,
            video_tag,
            logger,
            env_name,
            stats,
            norm_obs_only,
            env_library,
            action_dim_gsde,
            state_dim_gsde,
            batch_dims=batch_dims,
            obs_norm_state_dict=obs_norm_state_dict,
        )

    if use_env_creator:
        return env_creator(make_transformed_env)
    return make_transformed_env


def get_norm_state_dict(env):
    """Gets the normalization loc and scale from the env state_dict."""
    sd = env.state_dict()
    sd = {
        key: val
        for key, val in sd.items()
        if key.endswith("loc") or key.endswith("scale")
    }
    return sd


def initialize_observation_norm_transforms(
    proof_environment: EnvBase,
    num_iter: int = 1000,
    key: str | tuple[str, ...] = None,
):
    """Calls :obj:`ObservationNorm.init_stats` on all uninitialized :obj:`ObservationNorm` instances of a :obj:`TransformedEnv`.

    If an :obj:`ObservationNorm` already has non-null :obj:`loc` or :obj:`scale`, a call to :obj:`initialize_observation_norm_transforms` will be a no-op.
    Similarly, if the transformed environment does not contain any :obj:`ObservationNorm`, a call to this function will have no effect.
    If no key is provided but the observations of the :obj:`EnvBase` contains more than one key, an exception will
    be raised.

    Args:
        proof_environment (EnvBase instance, optional): if provided, this env will
            be used to execute the rollouts. If not, it will be created using
            the cfg object.
        num_iter (int): Number of iterations used for initializing the :obj:`ObservationNorms`
        key (str, optional): if provided, the stats of this key will be gathered.
            If not, it is expected that only one key exists in `env.observation_spec`.

    """
    if not isinstance(proof_environment.transform, Compose) and not isinstance(
        proof_environment.transform, ObservationNorm
    ):
        return

    if key is None:
        keys = list(proof_environment.base_env.observation_spec.keys(True, True))
        key = keys.pop()
        if len(keys):
            raise RuntimeError(
                f"More than one key exists in the observation_specs: {[key] + keys} were found, "
                "thus initialize_observation_norm_transforms cannot infer which to compute the stats of."
            )

    if isinstance(proof_environment.transform, Compose):
        for transform in proof_environment.transform:
            if isinstance(transform, ObservationNorm) and not transform.initialized:
                transform.init_stats(num_iter=num_iter, key=key)
    elif not proof_environment.transform.initialized:
        proof_environment.transform.init_stats(num_iter=num_iter, key=key)


def parallel_env_constructor(
    cfg: DictConfig, **kwargs  # noqa: F821
) -> ParallelEnv | EnvCreator:
    """Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor.

    Args:
        cfg (DictConfig): config containing user-defined arguments
        kwargs: keyword arguments for the `transformed_env_constructor` method.
    """
    batch_transform = cfg.env.batch_transform
    if not batch_transform:
        raise NotImplementedError(
            "batch_transform must be set to True for the recorder to be synced "
            "with the collection envs."
        )
    if cfg.collector.env_per_collector == 1:
        kwargs.update({"cfg": cfg, "use_env_creator": True})
        make_transformed_env = transformed_env_constructor(**kwargs)
        return make_transformed_env
    kwargs.update({"cfg": cfg, "use_env_creator": True})
    make_transformed_env = transformed_env_constructor(
        return_transformed_envs=not batch_transform, **kwargs
    )
    parallel_env = ParallelEnv(
        num_workers=cfg.collector.env_per_collector,
        create_env_fn=make_transformed_env,
        create_env_kwargs=None,
        serial_for_single=True,
        pin_memory=False,
    )
    if batch_transform:
        kwargs.update(
            {
                "cfg": cfg,
                "use_env_creator": False,
                "custom_env": parallel_env,
                "batch_dims": 1,
            }
        )
        env = transformed_env_constructor(**kwargs)()
        return env
    return parallel_env


def retrieve_observation_norms_state_dict(proof_environment: TransformedEnv):
    """Traverses the transforms of the environment and retrieves the :obj:`ObservationNorm` state dicts.

    Returns a list of tuple (idx, state_dict) for each :obj:`ObservationNorm` transform in proof_environment
    If the environment transforms do not contain any :obj:`ObservationNorm`, returns an empty list

    Args:
        proof_environment (EnvBase instance, optional): the :obj:`TransformedEnv` to retrieve the :obj:`ObservationNorm`
            state dict from
    """
    obs_norm_state_dicts = []

    if isinstance(proof_environment.transform, Compose):
        for idx, transform in enumerate(proof_environment.transform):
            if isinstance(transform, ObservationNorm):
                obs_norm_state_dicts.append((idx, transform.state_dict()))

    if isinstance(proof_environment.transform, ObservationNorm):
        obs_norm_state_dicts.append((0, proof_environment.transform.state_dict()))

    return obs_norm_state_dicts


def make_env_transforms(
    env,
    cfg,
    video_tag,
    logger,
    env_name,
    stats,
    norm_obs_only,
    env_library,
    action_dim_gsde,
    state_dim_gsde,
    batch_dims=0,
    obs_norm_state_dict=None,
):
    """Creates the typical transforms for and env."""
    env = TransformedEnv(env)

    from_pixels = cfg.env.from_pixels
    vecnorm = cfg.env.vecnorm
    norm_rewards = vecnorm and cfg.env.norm_rewards
    _norm_obs_only = norm_obs_only or not norm_rewards
    reward_scaling = cfg.env.reward_scaling
    reward_loc = cfg.env.reward_loc

    if len(video_tag):
        center_crop = cfg.env.center_crop
        if center_crop:
            center_crop = center_crop[0]
        env.append_transform(
            VideoRecorder(
                logger=logger,
                tag=f"{video_tag}_{env_name}_video",
                center_crop=center_crop,
            ),
        )

    if from_pixels:
        if not cfg.env.catframes:
            raise RuntimeError(
                "this env builder currently only accepts positive catframes values"
                "when pixels are being used."
            )
        env.append_transform(ToTensorImage())
        if cfg.env.center_crop:
            env.append_transform(CenterCrop(*cfg.env.center_crop))
        env.append_transform(Resize(cfg.env.image_size, cfg.env.image_size))
        if cfg.env.grayscale:
            env.append_transform(GrayScale())
        env.append_transform(FlattenObservation(0, -3, allow_positive_dim=True))
        env.append_transform(CatFrames(N=cfg.env.catframes, in_keys=["pixels"], dim=-3))
        if stats is None and obs_norm_state_dict is None:
            obs_stats = {}
        elif stats is None:
            obs_stats = copy(obs_norm_state_dict)
        else:
            obs_stats = copy(stats)
        obs_stats["standard_normal"] = True
        obs_norm = ObservationNorm(**obs_stats, in_keys=["pixels"])
        env.append_transform(obs_norm)
    if norm_rewards:
        reward_scaling = 1.0
        reward_loc = 0.0
    if norm_obs_only:
        reward_scaling = 1.0
        reward_loc = 0.0
    if reward_scaling is not None:
        env.append_transform(RewardScaling(reward_loc, reward_scaling))

    if not from_pixels:
        selected_keys = [
            key
            for key in env.observation_spec.keys(True, True)
            if ("pixels" not in key) and (key not in env.state_spec.keys(True, True))
        ]

        # even if there is a single tensor, it'll be renamed in "observation_vector"
        out_key = "observation_vector"
        env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key))

        if not vecnorm:
            if stats is None and obs_norm_state_dict is None:
                _stats = {}
            elif stats is None:
                _stats = copy(obs_norm_state_dict)
            else:
                _stats = copy(stats)
            _stats.update({"standard_normal": True})
            obs_norm = ObservationNorm(
                **_stats,
                in_keys=[out_key],
            )
            env.append_transform(obs_norm)
        else:
            env.append_transform(
                VecNorm(
                    in_keys=[out_key, "reward"] if not _norm_obs_only else [out_key],
                    decay=0.9999,
                )
            )

        env.append_transform(DoubleToFloat())

        if hasattr(cfg, "catframes") and cfg.env.catframes:
            env.append_transform(
                CatFrames(N=cfg.env.catframes, in_keys=[out_key], dim=-1)
            )

    else:
        env.append_transform(DoubleToFloat())

    if hasattr(cfg, "gSDE") and cfg.exploration.gSDE:
        env.append_transform(
            gSDENoise(action_dim=action_dim_gsde, state_dim=state_dim_gsde)
        )

    env.append_transform(StepCounter())
    env.append_transform(InitTracker())

    return env


def make_redq_loss(model, cfg) -> tuple[REDQLoss_deprecated, TargetNetUpdater | None]:
    """Builds the REDQ loss module."""
    loss_kwargs = {}
    loss_kwargs.update({"loss_function": cfg.loss.loss_function})
    loss_kwargs.update({"delay_qvalue": cfg.loss.type == "double"})
    loss_class = REDQLoss_deprecated
    if isinstance(model, ActorValueOperator):
        actor_model = model.get_policy_operator()
        qvalue_model = model.get_value_operator()
    elif isinstance(model, ActorCriticOperator):
        raise RuntimeError(
            "Although REDQ Q-value depends upon selected actions, using the"
            "ActorCriticOperator will lead to resampling of the actions when"
            "computing the Q-value loss, which we don't want. Please use the"
            "ActorValueOperator instead."
        )
    else:
        actor_model, qvalue_model = model

    loss_module = loss_class(
        actor_network=actor_model,
        qvalue_network=qvalue_model,
        num_qvalue_nets=cfg.loss.num_q_values,
        gSDE=cfg.exploration.gSDE,
        **loss_kwargs,
    )
    loss_module.make_value_estimator(gamma=cfg.loss.gamma)
    target_net_updater = make_target_updater(cfg, loss_module)
    return loss_module, target_net_updater


def make_target_updater(
    cfg: DictConfig, loss_module: LossModule  # noqa: F821
) -> TargetNetUpdater | None:
    """Builds a target network weight update object."""
    if cfg.loss.type == "double":
        if not cfg.loss.hard_update:
            target_net_updater = SoftUpdate(
                loss_module, eps=1 - 1 / cfg.loss.value_network_update_interval
            )
        else:
            target_net_updater = HardUpdate(
                loss_module,
                value_network_update_interval=cfg.loss.value_network_update_interval,
            )
    else:
        if cfg.hard_update:
            raise RuntimeError(
                "hard/soft-update are supposed to be used with double SAC loss. "
                "Consider using --loss=double or discarding the hard_update flag."
            )
        target_net_updater = None
    return target_net_updater


def make_collector_offpolicy(
    make_env: Callable[[], EnvBase],
    actor_model_explore: TensorDictModuleWrapper | ProbabilisticTensorDictSequential,
    cfg: DictConfig,  # noqa: F821
    make_env_kwargs: dict | None = None,
) -> DataCollectorBase:
    """Returns a data collector for off-policy sota-implementations.

    Args:
        make_env (Callable): environment creator
        actor_model_explore (SafeModule): Model instance used for evaluation and exploration update
        cfg (DictConfig): config for creating collector object
        make_env_kwargs (dict): kwargs for the env creator

    """
    if cfg.collector.async_collection:
        collector_helper = sync_async_collector
    else:
        collector_helper = sync_sync_collector

    if cfg.collector.multi_step:
        ms = MultiStep(
            gamma=cfg.loss.gamma,
            n_steps=cfg.collector.n_steps_return,
        )
    else:
        ms = None

    env_kwargs = {}
    if make_env_kwargs is not None and isinstance(make_env_kwargs, dict):
        env_kwargs.update(make_env_kwargs)
    elif make_env_kwargs is not None:
        env_kwargs = make_env_kwargs
    if cfg.collector.device in ("", None):
        cfg.collector.device = "cpu" if not torch.cuda.is_available() else "cuda:0"
    else:
        cfg.collector.device = (
            cfg.collector.device
            if len(cfg.collector.device) > 1
            else cfg.collector.device[0]
        )
    collector_helper_kwargs = {
        "env_fns": make_env,
        "env_kwargs": env_kwargs,
        "policy": actor_model_explore,
        "max_frames_per_traj": cfg.collector.max_frames_per_traj,
        "frames_per_batch": cfg.collector.frames_per_batch,
        "total_frames": cfg.collector.total_frames,
        "postproc": ms,
        "num_env_per_collector": 1,
        # we already took care of building the make_parallel_env function
        "num_collectors": -cfg.num_workers // -cfg.collector.env_per_collector,
        "device": cfg.collector.device,
        "init_random_frames": cfg.collector.init_random_frames,
        "split_trajs": True,
        # trajectories must be separated if multi-step is used
    }

    collector = collector_helper(**collector_helper_kwargs)
    collector.set_seed(cfg.seed)
    return collector


def make_replay_buffer(
    device: DEVICE_TYPING, cfg: DictConfig  # noqa: F821
) -> ReplayBuffer:  # noqa: F821
    """Builds a replay buffer using the config built from ReplayArgsConfig."""
    device = torch.device(device)
    if not cfg.buffer.prb:
        sampler = RandomSampler()
    else:
        sampler = PrioritizedSampler(
            max_capacity=cfg.buffer.size,
            alpha=0.7,
            beta=0.5,
        )
    buffer = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(
            cfg.buffer.size,
            scratch_dir=cfg.buffer.scratch_dir,
        ),
        sampler=sampler,
        pin_memory=device != torch.device("cpu"),
        prefetch=cfg.buffer.prefetch,
        batch_size=cfg.buffer.batch_size,
    )
    return buffer