torchrl-0.11.0-cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0

sota-implementations/ddpg/ddpg.py
@@ -0,0 +1,231 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""DDPG Example.
+
+This is a simple self-contained example of a DDPG training script.
+
+It supports state environments like MuJoCo.
+
+The helper functions are coded in the utils.py associated with this script.
+"""
+from __future__ import annotations
+
+import warnings
+
+import hydra
+import numpy as np
+import torch
+import torch.cuda
+import tqdm
+from tensordict import TensorDict
+from tensordict.nn import CudaGraphModule
+from torchrl._utils import get_available_device, timeit
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import group_optimizers
+from torchrl.record.loggers import generate_exp_name, get_logger
+from utils import (
+    dump_video,
+    log_metrics,
+    make_collector,
+    make_ddpg_agent,
+    make_environment,
+    make_loss_module,
+    make_optimizer,
+    make_replay_buffer,
+)
+
+
+@hydra.main(version_base="1.1", config_path="", config_name="config")
+def main(cfg: DictConfig):  # noqa: F821
+    device = (
+        torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
+    )
+    collector_device = (
+        torch.device(cfg.collector.device)
+        if cfg.collector.device
+        else get_available_device()
+    )
+
+    # Create logger
+    exp_name = generate_exp_name("DDPG", cfg.logger.exp_name)
+    logger = None
+    if cfg.logger.backend:
+        logger = get_logger(
+            logger_type=cfg.logger.backend,
+            logger_name="ddpg_logging",
+            experiment_name=exp_name,
+            wandb_kwargs={
+                "mode": cfg.logger.mode,
+                "config": dict(cfg),
+                "project": cfg.logger.project_name,
+                "group": cfg.logger.group_name,
+            },
+        )
+
+    # Set seeds
+    torch.manual_seed(cfg.env.seed)
+    np.random.seed(cfg.env.seed)
+
+    # Create environments
+    train_env, eval_env = make_environment(cfg, logger=logger)
+
+    # Create agent
+    model, exploration_policy = make_ddpg_agent(cfg, train_env, eval_env, device)
+
+    # Create DDPG loss
+    loss_module, target_net_updater = make_loss_module(cfg, model)
+
+    compile_mode = None
+    if cfg.compile.compile:
+        if cfg.compile.compile_mode not in (None, ""):
+            compile_mode = cfg.compile.compile_mode
+        elif cfg.compile.cudagraphs:
+            compile_mode = "default"
+        else:
+            compile_mode = "reduce-overhead"
+
+    # Create off-policy collector
+    collector = make_collector(
+        cfg,
+        train_env,
+        exploration_policy,
+        compile=cfg.compile.compile,
+        compile_mode=compile_mode,
+        cudagraph=cfg.compile.cudagraphs,
+        device=collector_device,
+    )
+
+    # Create replay buffer
+    replay_buffer = make_replay_buffer(
+        batch_size=cfg.optim.batch_size,
+        prb=cfg.replay_buffer.prb,
+        buffer_size=cfg.replay_buffer.size,
+        scratch_dir=cfg.replay_buffer.scratch_dir,
+        device="cpu",
+    )
+
+    # Create optimizers
+    optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module)
+    optimizer = group_optimizers(optimizer_actor, optimizer_critic)
+
+    def update(sampled_tensordict):
+        optimizer.zero_grad(set_to_none=True)
+
+        td_loss: TensorDict = loss_module(sampled_tensordict)
+        td_loss.sum(reduce=True).backward()
+        optimizer.step()
+
+        # Update qnet_target params
+        target_net_updater.step()
+        return td_loss.detach()
+
+    if cfg.compile.compile:
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
+
+    # Main loop
+    collected_frames = 0
+    pbar = tqdm.tqdm(total=cfg.collector.total_frames)
+
+    init_random_frames = cfg.collector.init_random_frames
+    num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
+    prb = cfg.replay_buffer.prb
+    frames_per_batch = cfg.collector.frames_per_batch
+    eval_iter = cfg.logger.eval_iter
+    eval_rollout_steps = cfg.env.max_episode_steps
+
+    c_iter = iter(collector)
+    total_iter = len(collector)
+    for _ in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
+        with timeit("collecting"):
+            tensordict = next(c_iter)
+        # Update exploration policy
+        exploration_policy[1].step(tensordict.numel())
+
+        # Update weights of the inference policy
+        collector.update_policy_weights_()
+
+        current_frames = tensordict.numel()
+        pbar.update(current_frames)
+
+        # Add to replay buffer
+        with timeit("rb - extend"):
+            tensordict = tensordict.reshape(-1)
+            replay_buffer.extend(tensordict)
+
+        collected_frames += current_frames
+
+        # Optimization steps
+        if collected_frames >= init_random_frames:
+            tds = []
+            for _ in range(num_updates):
+                # Sample from replay buffer
+                with timeit("rb - sample"):
+                    sampled_tensordict = replay_buffer.sample().to(device)
+                with timeit("update"):
+                    torch.compiler.cudagraph_mark_step_begin()
+                    td_loss = update(sampled_tensordict)
+                tds.append(td_loss.clone())
+
+                # Update priority
+                if prb:
+                    replay_buffer.update_priority(sampled_tensordict)
+            tds = torch.stack(tds)
+
+        episode_end = (
+            tensordict["next", "done"]
+            if tensordict["next", "done"].any()
+            else tensordict["next", "truncated"]
+        )
+        episode_rewards = tensordict["next", "episode_reward"][episode_end]
+
+        # Logging
+        metrics_to_log = {}
+        if len(episode_rewards) > 0:
+            episode_length = tensordict["next", "step_count"][episode_end]
+            metrics_to_log["train/reward"] = episode_rewards.mean().item()
+            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+                episode_length
+            )
+
+        if collected_frames >= init_random_frames:
+            tds = TensorDict(train=tds).flatten_keys("/").mean()
+            metrics_to_log.update(tds.to_dict())
+
+        # Evaluation
+        if abs(collected_frames % eval_iter) < frames_per_batch:
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("eval"):
+                eval_rollout = eval_env.rollout(
+                    eval_rollout_steps,
+                    exploration_policy,
+                    auto_cast_to_device=True,
+                    break_when_any_done=True,
+                )
+                eval_env.apply(dump_video)
+                eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
+                metrics_to_log["eval/reward"] = eval_reward
+
+        if logger is not None:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            log_metrics(logger, metrics_to_log, collected_frames)
+
+    collector.shutdown()
+    if not eval_env.is_closed:
+        eval_env.close()
+    if not train_env.is_closed:
+        train_env.close()
+
+
+if __name__ == "__main__":
+    main()
sota-implementations/ddpg/utils.py
@@ -0,0 +1,325 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import functools
+
+import torch
+
+from tensordict.nn import TensorDictModule, TensorDictSequential
+
+from torch import nn, optim
+from torchrl.collectors import SyncDataCollector
+from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer
+from torchrl.data.replay_buffers.storages import LazyMemmapStorage
+from torchrl.envs import (
+    CatTensors,
+    Compose,
+    DMControlEnv,
+    DoubleToFloat,
+    EnvCreator,
+    InitTracker,
+    ParallelEnv,
+    RewardSum,
+    StepCounter,
+    TransformedEnv,
+)
+from torchrl.envs.libs.gym import GymEnv, set_gym_backend
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.modules import (
+    AdditiveGaussianModule,
+    MLP,
+    OrnsteinUhlenbeckProcessModule,
+    TanhModule,
+    ValueOperator,
+)
+
+from torchrl.objectives import SoftUpdate
+from torchrl.objectives.ddpg import DDPGLoss
+from torchrl.record import VideoRecorder
+
+
+# ====================================================================
+# Environment utils
+# -----------------
+
+
+def env_maker(cfg, device="cpu", from_pixels=False):
+    lib = cfg.env.library
+    if lib in ("gym", "gymnasium"):
+        with set_gym_backend(lib):
+            return GymEnv(
+                cfg.env.name, device=device, from_pixels=from_pixels, pixels_only=False
+            )
+    elif lib == "dm_control":
+        env = DMControlEnv(
+            cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False
+        )
+        return TransformedEnv(
+            env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation")
+        )
+    else:
+        raise NotImplementedError(f"Unknown lib {lib}.")
+
+
+def apply_env_transforms(env, max_episode_steps=1000):
+    transformed_env = TransformedEnv(
+        env,
+        Compose(
+            InitTracker(),
+            StepCounter(max_episode_steps),
+            DoubleToFloat(),
+            RewardSum(),
+        ),
+    )
+    return transformed_env
+
+
+def make_environment(cfg, logger):
+    """Make environments for training and evaluation."""
+    maker = functools.partial(env_maker, cfg, from_pixels=False)
+    parallel_env = ParallelEnv(
+        cfg.collector.env_per_collector,
+        EnvCreator(maker),
+        serial_for_single=True,
+    )
+    parallel_env.set_seed(cfg.env.seed)
+
+    train_env = apply_env_transforms(
+        parallel_env, max_episode_steps=cfg.env.max_episode_steps
+    )
+
+    maker = functools.partial(env_maker, cfg, from_pixels=cfg.logger.video)
+    eval_env = TransformedEnv(
+        ParallelEnv(
+            cfg.logger.num_eval_envs,
+            EnvCreator(maker),
+            serial_for_single=True,
+        ),
+        train_env.transform.clone(),
+    )
+    eval_env.set_seed(0)
+    if cfg.logger.video:
+        eval_env = eval_env.append_transform(
+            VideoRecorder(logger, tag="rendered", in_keys=["pixels"])
+        )
+    return train_env, eval_env
+
+
+# ====================================================================
+# Collector and replay buffer
+# ---------------------------
+
+
+def make_collector(
+    cfg,
+    train_env,
+    actor_model_explore,
+    compile=False,
+    compile_mode=None,
+    cudagraph=False,
+    device: torch.device | None = None,
+):
+    """Make collector."""
+    collector = SyncDataCollector(
+        train_env,
+        actor_model_explore,
+        frames_per_batch=cfg.collector.frames_per_batch,
+        init_random_frames=cfg.collector.init_random_frames,
+        reset_at_each_iter=cfg.collector.reset_at_each_iter,
+        total_frames=cfg.collector.total_frames,
+        device=device,
+        compile_policy={"mode": compile_mode, "fullgraph": True} if compile else False,
+        cudagraph_policy=cudagraph,
+    )
+    collector.set_seed(cfg.env.seed)
+    return collector
+
+
+def make_replay_buffer(
+    batch_size,
+    prb=False,
+    buffer_size=1000000,
+    scratch_dir=None,
+    device="cpu",
+    prefetch=3,
+):
+    if prb:
+        replay_buffer = TensorDictPrioritizedReplayBuffer(
+            alpha=0.7,
+            beta=0.5,
+            pin_memory=False,
+            prefetch=prefetch,
+            storage=LazyMemmapStorage(
+                buffer_size,
+                scratch_dir=scratch_dir,
+                device=device,
+            ),
+            batch_size=batch_size,
+        )
+    else:
+        replay_buffer = TensorDictReplayBuffer(
+            pin_memory=False,
+            prefetch=prefetch,
+            storage=LazyMemmapStorage(
+                buffer_size,
+                scratch_dir=scratch_dir,
+                device=device,
+            ),
+            batch_size=batch_size,
+        )
+    return replay_buffer
+
+
+# ====================================================================
+# Model
+# -----
+
+
+def make_ddpg_agent(cfg, train_env, eval_env, device):
+    """Make DDPG agent."""
+    # Define Actor Network
+    in_keys = ["observation"]
+    action_spec = train_env.action_spec_unbatched
+    actor_net_kwargs = {
+        "num_cells": cfg.network.hidden_sizes,
+        "out_features": action_spec.shape[-1],
+        "activation_class": get_activation(cfg),
+    }
+
+    actor_net = MLP(**actor_net_kwargs)
+
+    in_keys_actor = in_keys
+    actor_module = TensorDictModule(
+        actor_net,
+        in_keys=in_keys_actor,
+        out_keys=["param"],
+    )
+    actor = TensorDictSequential(
+        actor_module,
+        TanhModule(
+            in_keys=["param"],
+            out_keys=["action"],
+        ),
+    )
+
+    # Define Critic Network
+    qvalue_net_kwargs = {
+        "num_cells": cfg.network.hidden_sizes,
+        "out_features": 1,
+        "activation_class": get_activation(cfg),
+    }
+
+    qvalue_net = MLP(
+        **qvalue_net_kwargs,
+    )
+
+    qvalue = ValueOperator(
+        in_keys=["action"] + in_keys,
+        module=qvalue_net,
+    )
+
+    model = nn.ModuleList([actor, qvalue]).to(device)
+
+    # init nets
+    with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
+        td = eval_env.reset()
+        td = td.to(device)
+        for net in model:
+            net(td)
+    del td
+    eval_env.close()
+
+    # Exploration wrappers:
+    if cfg.network.noise_type == "ou":
+        actor_model_explore = TensorDictSequential(
+            model[0],
+            OrnsteinUhlenbeckProcessModule(
+                spec=action_spec,
+                annealing_num_steps=1_000_000,
+                device=device,
+                safe=False,
+            ),
+        )
+    elif cfg.network.noise_type == "gaussian":
+        actor_model_explore = TensorDictSequential(
+            model[0],
+            AdditiveGaussianModule(
+                spec=action_spec,
+                sigma_end=1.0,
+                sigma_init=1.0,
+                mean=0.0,
+                std=0.1,
+                device=device,
+                safe=False,
+            ),
+        )
+    else:
+        raise NotImplementedError
+
+    return model, actor_model_explore
+
+
+# ====================================================================
+# DDPG Loss
+# ---------
+
+
+def make_loss_module(cfg, model):
+    """Make loss module and target network updater."""
+    # Create DDPG loss
+    loss_module = DDPGLoss(
+        actor_network=model[0],
+        value_network=model[1],
+        loss_function=cfg.optim.loss_function,
+        delay_actor=True,
+        delay_value=True,
+    )
+    loss_module.make_value_estimator(gamma=cfg.optim.gamma)
+
+    # Define Target Network Updater
+    target_net_updater = SoftUpdate(loss_module, eps=cfg.optim.target_update_polyak)
+    return loss_module, target_net_updater
+
+
+def make_optimizer(cfg, loss_module):
+    critic_params = list(loss_module.value_network_params.flatten_keys().values())
+    actor_params = list(loss_module.actor_network_params.flatten_keys().values())
+
+    optimizer_actor = optim.Adam(
+        actor_params, lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay
+    )
+    optimizer_critic = optim.Adam(
+        critic_params,
+        lr=cfg.optim.lr,
+        weight_decay=cfg.optim.weight_decay,
+    )
+    return optimizer_actor, optimizer_critic
+
+
+# ====================================================================
+# General utils
+# ---------
+
+
+def log_metrics(logger, metrics, step):
+    for metric_name, metric_value in metrics.items():
+        logger.log_scalar(metric_name, metric_value, step)
+
+
+def get_activation(cfg):
+    if cfg.network.activation == "relu":
+        return nn.ReLU
+    elif cfg.network.activation == "tanh":
+        return nn.Tanh
+    elif cfg.network.activation == "leaky_relu":
+        return nn.LeakyReLU
+    else:
+        raise NotImplementedError
+
+
+def dump_video(module):
+    if isinstance(module, VideoRecorder):
+        module.dump()
sota-implementations/decision_transformer/dt.py
@@ -0,0 +1,163 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""Decision Transformer Example.
+This is a self-contained example of an offline Decision Transformer training script.
+The helper functions are coded in the utils.py associated with this script.
+"""
+
+from __future__ import annotations
+
+import warnings
+
+import hydra
+import numpy as np
+import torch
+import tqdm
+from tensordict import TensorDict
+from tensordict.nn import CudaGraphModule
+from torchrl._utils import get_available_device, logger as torchrl_logger, timeit
+from torchrl.envs.libs.gym import set_gym_backend
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
+from torchrl.record import VideoRecorder
+from utils import (
+    dump_video,
+    log_metrics,
+    make_dt_loss,
+    make_dt_model,
+    make_dt_optimizer,
+    make_env,
+    make_logger,
+    make_offline_replay_buffer,
+)
+
+
+@hydra.main(config_path="", config_name="dt_config", version_base="1.1")
+def main(cfg: DictConfig):  # noqa: F821
+    set_gym_backend(cfg.env.backend).set()
+
+    model_device = (
+        torch.device(cfg.optim.device) if cfg.optim.device else get_available_device()
+    )
+
+    # Set seeds
+    torch.manual_seed(cfg.env.seed)
+    np.random.seed(cfg.env.seed)
+
+    # Create logger
+    logger = make_logger(cfg)
+
+    # Create offline replay buffer
+    offline_buffer, obs_loc, obs_std = make_offline_replay_buffer(
+        cfg.replay_buffer, cfg.env.reward_scaling
+    )
+
+    # Create test environment
+    test_env = make_env(
+        cfg.env, obs_loc, obs_std, from_pixels=cfg.logger.video, device=model_device
+    )
+    if cfg.logger.video:
+        test_env = test_env.append_transform(
+            VideoRecorder(logger, tag="rendered", in_keys=["pixels"])
+        )
+
+    # Create policy model
+    actor = make_dt_model(cfg, device=model_device)
+
+    # Create loss
+    loss_module = make_dt_loss(cfg.loss, actor, device=model_device)
+
+    # Create optimizer
+    transformer_optim, scheduler = make_dt_optimizer(
+        cfg.optim, loss_module, model_device
+    )
+
+    # Create inference policy
+    inference_policy = DecisionTransformerInferenceWrapper(
+        policy=actor,
+        inference_context=cfg.env.inference_context,
+        device=model_device,
+    )
+    inference_policy.set_tensor_keys(
+        observation="observation_cat",
+        action="action_cat",
+        return_to_go="return_to_go_cat",
+    )
+
+    pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps
+    clip_grad = cfg.optim.clip_grad
+
+    def update(data: TensorDict) -> TensorDict:
+        transformer_optim.zero_grad(set_to_none=True)
+        # Compute loss
+        loss_vals = loss_module(data)
+        transformer_loss = loss_vals["loss"]
+
+        transformer_loss.backward()
+        torch.nn.utils.clip_grad_norm_(actor.parameters(), clip_grad)
+        transformer_optim.step()
+
+        return loss_vals
+
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+        update = torch.compile(update, mode=compile_mode, dynamic=True)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
+
+    eval_steps = cfg.logger.eval_steps
+    pretrain_log_interval = cfg.logger.pretrain_log_interval
+    reward_scaling = cfg.env.reward_scaling
+
+    torchrl_logger.info(" ***Pretraining*** ")
+    # Pretraining
+    pbar = tqdm.tqdm(range(pretrain_gradient_steps))
+    for i in pbar:
+        timeit.printevery(1000, pretrain_gradient_steps, erase=True)
+        # Sample data
+        with timeit("rb - sample"):
+            data = offline_buffer.sample().to(model_device)
+        with timeit("update"):
+            loss_vals = update(data)
+        scheduler.step()
+        # Log metrics
+        metrics_to_log = {"train/loss": loss_vals["loss"]}
+
+        # Evaluation
+        with set_exploration_type(
+            ExplorationType.DETERMINISTIC
+        ), torch.no_grad(), timeit("eval"):
+            if i % pretrain_log_interval == 0:
+                eval_td = test_env.rollout(
+                    max_steps=eval_steps,
+                    policy=inference_policy,
+                    auto_cast_to_device=True,
+                )
+                test_env.apply(dump_video)
+                metrics_to_log["eval/reward"] = (
+                    eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
+                )
+
+        if logger is not None:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            log_metrics(logger, metrics_to_log, i)
+
+    pbar.close()
+    if not test_env.is_closed:
+        test_env.close()
+
+
+if __name__ == "__main__":
+    main()