torchrl-0.11.0-cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/discrete_sac/discrete_sac.py
@@ -0,0 +1,243 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""Discrete SAC Example.
+
+This is a simple self-contained example of a discrete SAC training script.
+
+It supports gym state environments like CartPole.
+
+The helper functions are coded in the utils.py associated with this script.
+"""
+
+from __future__ import annotations
+
+import warnings
+
+import hydra
+import numpy as np
+import torch
+import torch.cuda
+import tqdm
+from tensordict.nn import CudaGraphModule
+from torchrl._utils import timeit
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import group_optimizers
+from torchrl.record.loggers import generate_exp_name, get_logger
+from utils import (
+    dump_video,
+    log_metrics,
+    make_collector,
+    make_environment,
+    make_loss_module,
+    make_optimizer,
+    make_replay_buffer,
+    make_sac_agent,
+)
+
+
+@hydra.main(version_base="1.1", config_path="", config_name="config")
+def main(cfg: DictConfig):  # noqa: F821
+    device = cfg.network.device
+    if device in ("", None):
+        if torch.cuda.is_available():
+            device = "cuda:0"
+        else:
+            device = "cpu"
+    device = torch.device(device)
+
+    # Create logger
+    exp_name = generate_exp_name("DiscreteSAC", cfg.logger.exp_name)
+    logger = None
+    if cfg.logger.backend:
+        logger = get_logger(
+            logger_type=cfg.logger.backend,
+            logger_name="DiscreteSAC_logging",
+            experiment_name=exp_name,
+            wandb_kwargs={
+                "mode": cfg.logger.mode,
+                "config": dict(cfg),
+                "project": cfg.logger.project_name,
+                "group": cfg.logger.group_name,
+            },
+        )
+
+    # Set seeds
+    torch.manual_seed(cfg.env.seed)
+    np.random.seed(cfg.env.seed)
+
+    # Create environments
+    train_env, eval_env = make_environment(cfg, logger=logger)
+
+    # Create agent
+    model = make_sac_agent(cfg, train_env, eval_env, device)
+
+    # Create discrete SAC loss
+    loss_module, target_net_updater = make_loss_module(cfg, model)
+
+    # Create replay buffer
+    replay_buffer = make_replay_buffer(
+        batch_size=cfg.optim.batch_size,
+        prb=cfg.replay_buffer.prb,
+        buffer_size=cfg.replay_buffer.size,
+        scratch_dir=cfg.replay_buffer.scratch_dir,
+        device="cpu",
+    )
+
+    # Create optimizers
+    optimizer_actor, optimizer_critic, optimizer_alpha = make_optimizer(
+        cfg, loss_module
+    )
+    optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha)
+    del optimizer_actor, optimizer_critic, optimizer_alpha
+
+    def update(sampled_tensordict):
+        optimizer.zero_grad(set_to_none=True)
+
+        # Compute loss
+        loss_out = loss_module(sampled_tensordict)
+
+        actor_loss, q_loss, alpha_loss = (
+            loss_out["loss_actor"],
+            loss_out["loss_qvalue"],
+            loss_out["loss_alpha"],
+        )
+
+        # Update critic
+        (q_loss + actor_loss + alpha_loss).backward()
+        optimizer.step()
+
+        # Update target params
+        target_net_updater.step()
+
+        return loss_out.detach()
+
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+        update = torch.compile(update, mode=compile_mode)
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, warmup=50)
+
+    # Create off-policy collector
+    collector = make_collector(
+        cfg,
+        train_env,
+        model[0],
+        compile=compile_mode is not None,
+        compile_mode=compile_mode,
+        cudagraphs=cfg.compile.cudagraphs,
+    )
+
+    # Main loop
+    collected_frames = 0
+    pbar = tqdm.tqdm(total=cfg.collector.total_frames)
+
+    init_random_frames = cfg.collector.init_random_frames
+    num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
+    prb = cfg.replay_buffer.prb
+    eval_rollout_steps = cfg.env.max_episode_steps
+    eval_iter = cfg.logger.eval_iter
+    frames_per_batch = cfg.collector.frames_per_batch
+
+    c_iter = iter(collector)
+    total_iter = len(collector)
+    for i in range(total_iter):
+        timeit.printevery(1000, total_iter, erase=True)
+        with timeit("collecting"):
+            collected_data = next(c_iter)
+
+        # Update weights of the inference policy
+        collector.update_policy_weights_()
+        current_frames = collected_data.numel()
+
+        pbar.update(current_frames)
+
+        collected_data = collected_data.reshape(-1)
+        with timeit("rb - extend"):
+            # Add to replay buffer
+            replay_buffer.extend(collected_data)
+        collected_frames += current_frames
+
+        # Optimization steps
+        if collected_frames >= init_random_frames:
+            tds = []
+            for _ in range(num_updates):
+                with timeit("rb - sample"):
+                    # Sample from replay buffer
+                    sampled_tensordict = replay_buffer.sample()
+
+                with timeit("update"):
+                    torch.compiler.cudagraph_mark_step_begin()
+                    sampled_tensordict = sampled_tensordict.to(device)
+                    loss_out = update(sampled_tensordict).clone()
+
+                tds.append(loss_out)
+
+                # Update priority
+                if prb:
+                    replay_buffer.update_priority(sampled_tensordict)
+            tds = torch.stack(tds).mean()
+
+        # Logging
+        episode_end = (
+            collected_data["next", "done"]
+            if collected_data["next", "done"].any()
+            else collected_data["next", "truncated"]
+        )
+        episode_rewards = collected_data["next", "episode_reward"][episode_end]
+
+        metrics_to_log = {}
+        if len(episode_rewards) > 0:
+            episode_length = collected_data["next", "step_count"][episode_end]
+            metrics_to_log["train/reward"] = episode_rewards.mean().item()
+            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
+                episode_length
+            )
+
+        if collected_frames >= init_random_frames:
+            metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
+            metrics_to_log["train/a_loss"] = tds["loss_actor"]
+            metrics_to_log["train/alpha_loss"] = tds["loss_alpha"]
+
+        # Evaluation
+        prev_test_frame = ((i - 1) * frames_per_batch) // eval_iter
+        cur_test_frame = (i * frames_per_batch) // eval_iter
+        final = current_frames >= collector.total_frames
+        if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("eval"):
+                eval_rollout = eval_env.rollout(
+                    eval_rollout_steps,
+                    model[0],
+                    auto_cast_to_device=True,
+                    break_when_any_done=True,
+                )
+                eval_env.apply(dump_video)
+                eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
+                metrics_to_log["eval/reward"] = eval_reward
+        if logger is not None:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            log_metrics(logger, metrics_to_log, collected_frames)
+
+    collector.shutdown()
+    if not eval_env.is_closed:
+        eval_env.close()
+    if not train_env.is_closed:
+        train_env.close()
+
+
+if __name__ == "__main__":
+    main()
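
`discrete_sac.py` resolves all of its hyperparameters through Hydra (`config_path=""`, `config_name="config"`), so it expects a `config.yaml` next to the script. That config is not reproduced in this diff; the snippet below is only a hedged sketch, built with OmegaConf, that enumerates the keys the script and the helpers in utils.py (shown next) actually read. Every value is illustrative, not the shipped default.

```python
# Hypothetical stand-in for the config.yaml the script loads via Hydra.
# Keys mirror the cfg.* accesses in discrete_sac.py / utils.py; values
# are illustrative only, not the packaged defaults.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "network": {"device": "", "hidden_sizes": [256, 256], "activation": "relu"},
        "logger": {
            "backend": "",  # empty string keeps logger=None in the script
            "exp_name": "CartPole-v1",
            "mode": "online",
            "project_name": "torchrl-examples",
            "group_name": "",
            "eval_iter": 5_000,
            "video": False,
        },
        "env": {
            "name": "CartPole-v1",
            "library": "gymnasium",
            "seed": 0,
            "max_episode_steps": 500,
        },
        "collector": {
            "total_frames": 20_000,
            "frames_per_batch": 200,
            "init_random_frames": 1_000,
            "env_per_collector": 1,
            "reset_at_each_iter": False,
            "device": "",
        },
        "optim": {
            "utd_ratio": 1.0,
            "batch_size": 256,
            "lr": 3.0e-4,
            "weight_decay": 0.0,
            "loss_function": "l2",
            "gamma": 0.99,
            "target_entropy_weight": 0.2,
            "target_update_polyak": 0.995,
        },
        "replay_buffer": {"prb": False, "size": 100_000, "scratch_dir": None},
        "compile": {"compile": False, "compile_mode": "", "cudagraphs": False},
    }
)
```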
sota-implementations/discrete_sac/utils.py
@@ -0,0 +1,324 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import functools
+import tempfile
+from contextlib import nullcontext
+
+import torch
+from tensordict.nn import InteractionType, TensorDictModule
+
+from torch import nn, optim
+from torchrl.collectors import SyncDataCollector
+from torchrl.data import (
+    Composite,
+    TensorDictPrioritizedReplayBuffer,
+    TensorDictReplayBuffer,
+)
+from torchrl.data.replay_buffers.storages import LazyMemmapStorage
+from torchrl.envs import (
+    CatTensors,
+    Compose,
+    DMControlEnv,
+    DoubleToFloat,
+    EnvCreator,
+    InitTracker,
+    ParallelEnv,
+    RewardSum,
+    StepCounter,
+    TransformedEnv,
+)
+from torchrl.envs.libs.gym import GymEnv, set_gym_backend
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.modules import MLP, SafeModule
+from torchrl.modules.distributions import OneHotCategorical
+
+from torchrl.modules.tensordict_module.actors import ProbabilisticActor
+from torchrl.objectives import SoftUpdate
+from torchrl.objectives.sac import DiscreteSACLoss
+from torchrl.record import VideoRecorder
+
+
+# ====================================================================
+# Environment utils
+# -----------------
+
+
+def env_maker(cfg, device="cpu", from_pixels=False):
+    lib = cfg.env.library
+    if lib in ("gym", "gymnasium"):
+        with set_gym_backend(lib):
+            return GymEnv(
+                cfg.env.name, device=device, from_pixels=from_pixels, pixels_only=False
+            )
+    elif lib == "dm_control":
+        env = DMControlEnv(
+            cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False
+        )
+        return TransformedEnv(
+            env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation")
+        )
+    else:
+        raise NotImplementedError(f"Unknown lib {lib}.")
+
+
+def apply_env_transforms(env, max_episode_steps):
+    transformed_env = TransformedEnv(
+        env,
+        Compose(
+            StepCounter(max_steps=max_episode_steps),
+            InitTracker(),
+            DoubleToFloat(),
+            RewardSum(),
+        ),
+    )
+    return transformed_env
+
+
+def make_environment(cfg, logger=None):
+    """Make environments for training and evaluation."""
+    maker = functools.partial(env_maker, cfg)
+    parallel_env = ParallelEnv(
+        cfg.collector.env_per_collector,
+        EnvCreator(maker),
+        serial_for_single=True,
+    )
+    parallel_env.set_seed(cfg.env.seed)
+
+    train_env = apply_env_transforms(
+        parallel_env, max_episode_steps=cfg.env.max_episode_steps
+    )
+
+    maker = functools.partial(env_maker, cfg, from_pixels=cfg.logger.video)
+    eval_env = TransformedEnv(
+        ParallelEnv(
+            cfg.collector.env_per_collector,
+            EnvCreator(maker),
+            serial_for_single=True,
+        ),
+        train_env.transform.clone(),
+    )
+    if cfg.logger.video:
+        eval_env = eval_env.insert_transform(
+            0, VideoRecorder(logger, tag="rendered", in_keys=["pixels"])
+        )
+    return train_env, eval_env
+
+
+# ====================================================================
+# Collector and replay buffer
+# ---------------------------
+
+
+def make_collector(
+    cfg,
+    train_env,
+    actor_model_explore,
+    compile=False,
+    compile_mode=None,
+    cudagraphs=False,
+):
+    """Make collector."""
+    device = cfg.collector.device
+    if device in ("", None):
+        if torch.cuda.is_available():
+            device = "cuda:0"
+        else:
+            device = "cpu"
+    device = torch.device(device)
+    collector = SyncDataCollector(
+        train_env,
+        actor_model_explore,
+        init_random_frames=cfg.collector.init_random_frames,
+        frames_per_batch=cfg.collector.frames_per_batch,
+        total_frames=cfg.collector.total_frames,
+        reset_at_each_iter=cfg.collector.reset_at_each_iter,
+        device=device,
+        storing_device="cpu",
+        compile_policy=False if not compile else {"mode": compile_mode},
+        cudagraph_policy={"warmup": 10} if cudagraphs else False,
+    )
+    collector.set_seed(cfg.env.seed)
+    return collector
+
+
+def make_replay_buffer(
+    batch_size,
+    prb=False,
+    buffer_size=1000000,
+    scratch_dir=None,
+    device="cpu",
+    prefetch=3,
+):
+    with (
+        tempfile.TemporaryDirectory()
+        if scratch_dir is None
+        else nullcontext(scratch_dir)
+    ) as scratch_dir:
+        if prb:
+            replay_buffer = TensorDictPrioritizedReplayBuffer(
+                alpha=0.7,
+                beta=0.5,
+                pin_memory=False,
+                prefetch=prefetch,
+                storage=LazyMemmapStorage(
+                    buffer_size,
+                    scratch_dir=scratch_dir,
+                    device=device,
+                ),
+                batch_size=batch_size,
+            )
+        else:
+            replay_buffer = TensorDictReplayBuffer(
+                pin_memory=False,
+                prefetch=prefetch,
+                storage=LazyMemmapStorage(
+                    buffer_size,
+                    scratch_dir=scratch_dir,
+                    device=device,
+                ),
+                batch_size=batch_size,
+            )
+        return replay_buffer
+
+
+# ====================================================================
+# Model
+# -----
+
+
+def make_sac_agent(cfg, train_env, eval_env, device):
+    """Make discrete SAC agent."""
+    # Define Actor Network
+    in_keys = ["observation"]
+    action_spec = train_env.action_spec
+    if train_env.batch_size:
+        action_spec = action_spec[(0,) * len(train_env.batch_size)]
+    # Define Actor Network
+    in_keys = ["observation"]
+
+    actor_net_kwargs = {
+        "num_cells": cfg.network.hidden_sizes,
+        "out_features": action_spec.shape[-1],
+        "activation_class": get_activation(cfg),
+    }
+
+    actor_net = MLP(**actor_net_kwargs)
+
+    actor_module = SafeModule(
+        module=actor_net,
+        in_keys=in_keys,
+        out_keys=["logits"],
+    )
+    actor = ProbabilisticActor(
+        spec=Composite(action=eval_env.action_spec),
+        module=actor_module,
+        in_keys=["logits"],
+        out_keys=["action"],
+        distribution_class=OneHotCategorical,
+        distribution_kwargs={},
+        default_interaction_type=InteractionType.RANDOM,
+        return_log_prob=False,
+    )
+
+    # Define Critic Network
+    qvalue_net_kwargs = {
+        "num_cells": cfg.network.hidden_sizes,
+        "out_features": action_spec.shape[-1],
+        "activation_class": get_activation(cfg),
+    }
+    qvalue_net = MLP(
+        **qvalue_net_kwargs,
+    )
+
+    qvalue = TensorDictModule(
+        in_keys=in_keys,
+        out_keys=["action_value"],
+        module=qvalue_net,
+    )
+
+    model = torch.nn.ModuleList([actor, qvalue]).to(device)
+    # init nets
+    with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
+        td = eval_env.reset()
+        td = td.to(device)
+        for net in model:
+            net(td)
+    del td
+    eval_env.close()
+
+    return model
+
+
+# ====================================================================
+# Discrete SAC Loss
+# ---------
+
+
+def make_loss_module(cfg, model):
+    """Make loss module and target network updater."""
+    # Create discrete SAC loss
+    loss_module = DiscreteSACLoss(
+        actor_network=model[0],
+        qvalue_network=model[1],
+        num_actions=model[0].spec["action"].space.n,
+        num_qvalue_nets=2,
+        loss_function=cfg.optim.loss_function,
+        target_entropy_weight=cfg.optim.target_entropy_weight,
+        delay_qvalue=True,
+    )
+    loss_module.make_value_estimator(gamma=cfg.optim.gamma)
+
+    # Define Target Network Updater
+    target_net_updater = SoftUpdate(loss_module, eps=cfg.optim.target_update_polyak)
+    return loss_module, target_net_updater
+
+
+def make_optimizer(cfg, loss_module):
+    critic_params = list(loss_module.qvalue_network_params.flatten_keys().values())
+    actor_params = list(loss_module.actor_network_params.flatten_keys().values())
+
+    optimizer_actor = optim.Adam(
+        actor_params,
+        lr=cfg.optim.lr,
+        weight_decay=cfg.optim.weight_decay,
+    )
+    optimizer_critic = optim.Adam(
+        critic_params,
+        lr=cfg.optim.lr,
+        weight_decay=cfg.optim.weight_decay,
+    )
+    optimizer_alpha = optim.Adam(
+        [loss_module.log_alpha],
+        lr=3.0e-4,
+    )
+    return optimizer_actor, optimizer_critic, optimizer_alpha
+
+
+# ====================================================================
+# General utils
+# ---------
+
+
+def log_metrics(logger, metrics, step):
+    for metric_name, metric_value in metrics.items():
+        logger.log_scalar(metric_name, metric_value, step)
+
+
+def get_activation(cfg):
+    if cfg.network.activation == "relu":
+        return nn.ReLU
+    elif cfg.network.activation == "tanh":
+        return nn.Tanh
+    elif cfg.network.activation == "leaky_relu":
+        return nn.LeakyReLU
+    else:
+        raise NotImplementedError
+
+
+def dump_video(module):
+    if isinstance(module, VideoRecorder):
+        module.dump()
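
For illustration only (not part of the package), here is a short sketch of how the `make_replay_buffer` helper above behaves on its own: it returns a TensorDict-aware buffer backed by memory-mapped storage that can be extended with a flat batch of transitions and then sampled at the configured batch size. The keys and sizes below are illustrative, mimicking the `collected_data.reshape(-1)` layout used by the training script.

```python
# Minimal usage sketch (not part of the package); keys and sizes are illustrative.
import torch
from tensordict import TensorDict

rb = make_replay_buffer(batch_size=32, buffer_size=1_000)

# A fake batch of 100 flat transitions.
batch = TensorDict(
    {
        "observation": torch.randn(100, 4),
        "action": torch.nn.functional.one_hot(torch.randint(2, (100,)), 2).float(),
        "next": {
            "observation": torch.randn(100, 4),
            "reward": torch.ones(100, 1),
            "done": torch.zeros(100, 1, dtype=torch.bool),
        },
    },
    batch_size=[100],
)
rb.extend(batch)

sample = rb.sample()  # TensorDict with batch_size [32], ready for a loss module
assert sample.batch_size == torch.Size([32])
```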
sota-implementations/dqn/README.md
@@ -0,0 +1,30 @@
+## Reproducing Deep Q-Learning (DQN) Algorithm Results
+
+This repository contains scripts that enable training agents using the Deep Q-Learning (DQN) Algorithm on CartPole and Atari environments. For Atari, we follow the original paper [Playing Atari with Deep Reinforcement Learning](https://arxiv.org/abs/1312.5602) by Mnih et al. (2013).
+
+
+## Examples Structure
+
+Please note that each example is independent of the others for the sake of simplicity. Each example contains the following files:
+
+1. **Main Script:** The definition of algorithm components and the training loop can be found in the main script (e.g. dqn_atari.py).
+
+2. **Utils File:** A utility file is provided to contain various helper functions, generally to create the environment and the models (e.g. utils_atari.py).
+
+3. **Configuration File:** This file includes default hyperparameters specified in the original paper. Users can modify these hyperparameters to customize their experiments (e.g. config_atari.yaml).
+
+
+## Running the Examples
+
+You can execute the DQN algorithm on the CartPole environment by running the following command:
+
+```bash
+python dqn_cartpole.py
+
+```
+
+You can execute the DQN algorithm on Atari environments by running the following command:
+
+```bash
+python dqn_atari.py
+```