torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/impala/README.md
@@ -0,0 +1,33 @@
## Reproducing Importance Weighted Actor-Learner Architecture (IMPALA) Algorithm Results

This repository contains scripts that enable training agents with the IMPALA Algorithm on MuJoCo and Atari environments. We follow the original paper [IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures](https://arxiv.org/abs/1802.01561) by Espeholt et al. 2018.

## Examples Structure

Please note that we provide a single-node example and multi-node (distributed) examples launched with either Ray or submitit. All examples rely on the same utils file, but are otherwise independent. Each example contains the following files:

1. **Main Script:** The definition of the algorithm components and the training loop can be found in the main script (e.g. impala_single_node.py).

2. **Utils File:** A utility file is provided to contain various helper functions, generally to create the environment and the models (e.g. utils.py).

3. **Configuration File:** This file includes the default hyperparameters specified in the original paper. For the multi-node case, the file also includes the configuration of the Ray cluster. Users can modify these hyperparameters to customize their experiments (e.g. config_single_node.yaml).

## Running the Examples

You can execute the single-node IMPALA algorithm on Atari environments by running the following command:

```bash
python impala_single_node.py
```

You can execute the multi-node IMPALA algorithm on Atari environments by running the following command:

```bash
python impala_multi_node_ray.py
```

or

```bash
python impala_multi_node_submitit.py
```
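Both training scripts shown below import `make_env`, `make_ppo_models` and `eval_model` from the shared utils file (sota-implementations/impala/utils.py in the listing above), which is not part of this hunk. As a rough orientation only, here is a minimal sketch of what an Atari environment factory conventionally looks like in TorchRL; the transform choices, defaults and the end-of-life ("eol") wiring are assumptions, not the packaged implementation.

```python
# Illustrative sketch only: the packaged sota-implementations/impala/utils.py is authoritative.
from torchrl.envs import (
    CatFrames,
    DoubleToFloat,
    GrayScale,
    NoopResetEnv,
    Resize,
    RewardSum,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv, set_gym_backend


def make_env(env_name, device, gym_backend="gymnasium", is_test=False):
    """Build an Atari env with DeepMind-style preprocessing (assumed defaults)."""
    with set_gym_backend(gym_backend):
        env = GymEnv(env_name, frame_skip=4, device=device)
    env = TransformedEnv(env)
    env.append_transform(NoopResetEnv(noops=30, random=True))
    env.append_transform(ToTensorImage())  # uint8 HWC -> float CHW in [0, 1]
    env.append_transform(GrayScale())
    env.append_transform(Resize(84, 84))
    env.append_transform(CatFrames(N=4, dim=-3))  # frame stacking
    env.append_transform(RewardSum())  # exposes "episode_reward", read by the training loops
    env.append_transform(StepCounter(max_steps=4500))  # exposes "step_count"
    env.append_transform(DoubleToFloat())
    # The packaged utils additionally maps the Atari end-of-life signal to the "eol"
    # key that both scripts pass to loss_module.set_keys(done="eol", terminated="eol").
    # is_test would toggle evaluation-only settings (e.g. no reward clipping, video recording).
    return env
```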
sota-implementations/impala/impala_multi_node_ray.py
@@ -0,0 +1,292 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
This script reproduces the IMPALA Algorithm
results from Espeholt et al. 2018 on Atari Environments.
"""
from __future__ import annotations

import hydra
from torchrl._utils import logger as torchrl_logger


@hydra.main(config_path="", config_name="config_multi_node_ray", version_base="1.1")
def main(cfg: DictConfig):  # noqa: F821

    import time

    import torch.optim
    import tqdm

    from tensordict import TensorDict
    from torchrl.collectors import SyncDataCollector
    from torchrl.collectors.distributed import RayCollector
    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
    from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
    from torchrl.envs import ExplorationType, set_exploration_type
    from torchrl.objectives import A2CLoss
    from torchrl.objectives.value import VTrace
    from torchrl.record.loggers import generate_exp_name, get_logger
    from utils import eval_model, make_env, make_ppo_models

    device = cfg.local_device
    if not device:
        device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
    else:
        device = torch.device(device)

    # Correct for frame_skip
    frame_skip = 4
    total_frames = cfg.collector.total_frames // frame_skip
    frames_per_batch = cfg.collector.frames_per_batch // frame_skip
    test_interval = cfg.logger.test_interval // frame_skip

    # Extract other config parameters
    batch_size = cfg.loss.batch_size  # Number of rollouts per batch
    num_workers = (
        cfg.collector.num_workers
    )  # Number of parallel workers collecting rollouts
    lr = cfg.optim.lr
    anneal_lr = cfg.optim.anneal_lr
    sgd_updates = cfg.loss.sgd_updates
    max_grad_norm = cfg.optim.max_grad_norm
    num_test_episodes = cfg.logger.num_test_episodes
    total_network_updates = (
        total_frames // (frames_per_batch * batch_size)
    ) * cfg.loss.sgd_updates

    # Create models (check utils.py)
    actor, critic = make_ppo_models(cfg.env.env_name, cfg.env.backend)
    actor, critic = actor.to(device), critic.to(device)

    # Create collector
    ray_init_config = {
        "address": cfg.ray_init_config.address,
        "num_cpus": cfg.ray_init_config.num_cpus,
        "num_gpus": cfg.ray_init_config.num_gpus,
        "resources": cfg.ray_init_config.resources,
        "object_store_memory": cfg.ray_init_config.object_store_memory,
        "local_mode": cfg.ray_init_config.local_mode,
        "ignore_reinit_error": cfg.ray_init_config.ignore_reinit_error,
        "include_dashboard": cfg.ray_init_config.include_dashboard,
        "dashboard_host": cfg.ray_init_config.dashboard_host,
        "dashboard_port": cfg.ray_init_config.dashboard_port,
        "job_config": cfg.ray_init_config.job_config,
        "configure_logging": cfg.ray_init_config.configure_logging,
        "logging_level": cfg.ray_init_config.logging_level,
        "logging_format": cfg.ray_init_config.logging_format,
        "log_to_driver": cfg.ray_init_config.log_to_driver,
        "namespace": cfg.ray_init_config.namespace,
        "runtime_env": cfg.ray_init_config.runtime_env,
        "storage": cfg.ray_init_config.storage,
    }
    remote_config = {
        "num_cpus": cfg.remote_worker_resources.num_cpus,
        "num_gpus": cfg.remote_worker_resources.num_gpus
        if torch.cuda.device_count()
        else 0,
        "memory": cfg.remote_worker_resources.memory,
    }
    collector = RayCollector(
        create_env_fn=[make_env(cfg.env.env_name, device, gym_backend=cfg.env.backend)]
        * num_workers,
        policy=actor,
        collector_class=SyncDataCollector,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        max_frames_per_traj=-1,
        ray_init_config=ray_init_config,
        remote_configs=remote_config,
        sync=False,
        update_after_each_batch=True,
    )

    # Create data buffer
    sampler = SamplerWithoutReplacement()
    data_buffer = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(frames_per_batch * batch_size),
        sampler=sampler,
        batch_size=frames_per_batch * batch_size,
    )

    # Create loss and adv modules
    adv_module = VTrace(
        gamma=cfg.loss.gamma,
        value_network=critic,
        actor_network=actor,
        average_adv=False,
    )
    loss_module = A2CLoss(
        actor_network=actor,
        critic_network=critic,
        loss_critic_type=cfg.loss.loss_critic_type,
        entropy_coeff=cfg.loss.entropy_coeff,
        critic_coeff=cfg.loss.critic_coeff,
    )
    loss_module.set_keys(done="eol", terminated="eol")

    # Create optimizer
    optim = torch.optim.RMSprop(
        loss_module.parameters(),
        lr=cfg.optim.lr,
        weight_decay=cfg.optim.weight_decay,
        eps=cfg.optim.eps,
        alpha=cfg.optim.alpha,
    )

    # Create logger
    logger = None
    if cfg.logger.backend:
        exp_name = generate_exp_name(
            "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}"
        )
        logger = get_logger(
            cfg.logger.backend,
            logger_name="impala",
            experiment_name=exp_name,
            wandb_kwargs={
                "config": dict(cfg),
                "project": cfg.logger.project_name,
                "group": cfg.logger.group_name,
            },
        )

    # Create test environment
    test_env = make_env(
        cfg.env.env_name, device, gym_backend=cfg.env.backend, is_test=True
    )
    test_env.eval()

    # Main loop
    collected_frames = 0
    num_network_updates = 0
    pbar = tqdm.tqdm(total=total_frames)
    accumulator = []
    start_time = sampling_start = time.time()
    for i, data in enumerate(collector):

        metrics_to_log = {}
        sampling_time = time.time() - sampling_start
        frames_in_batch = data.numel()
        collected_frames += frames_in_batch * frame_skip
        pbar.update(data.numel())

        # Get training rewards and episode lengths
        episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
        if len(episode_rewards) > 0:
            episode_length = data["next", "step_count"][data["next", "terminated"]]
            metrics_to_log.update(
                {
                    "train/reward": episode_rewards.mean().item(),
                    "train/episode_length": episode_length.sum().item()
                    / len(episode_length),
                }
            )

        if len(accumulator) < batch_size:
            accumulator.append(data)
            if logger:
                for key, value in metrics_to_log.items():
                    logger.log_scalar(key, value, collected_frames)
            continue

        losses = TensorDict(batch_size=[sgd_updates])
        training_start = time.time()
        for j in range(sgd_updates):

            # Create a single batch of trajectories
            stacked_data = torch.stack(accumulator, dim=0).contiguous()
            stacked_data = stacked_data.to(device, non_blocking=True)

            # Compute advantage
            with torch.no_grad():
                stacked_data = adv_module(stacked_data)

            # Add to replay buffer
            for stacked_d in stacked_data:
                stacked_data_reshape = stacked_d.reshape(-1)
                data_buffer.extend(stacked_data_reshape)

            for batch in data_buffer:

                # Linearly decrease the learning rate and clip epsilon
                alpha = 1.0
                if anneal_lr:
                    alpha = 1 - (num_network_updates / total_network_updates)
                    for group in optim.param_groups:
                        group["lr"] = lr * alpha
                num_network_updates += 1

                # Get a data batch
                batch = batch.to(device, non_blocking=True)

                # Forward pass loss
                loss = loss_module(batch)
                losses[j] = loss.select(
                    "loss_critic", "loss_entropy", "loss_objective"
                ).detach()
                loss_sum = (
                    loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
                )

                # Backward pass
                loss_sum.backward()
                torch.nn.utils.clip_grad_norm_(
                    list(loss_module.parameters()), max_norm=max_grad_norm
                )

                # Update the networks
                optim.step()
                optim.zero_grad()

        # Get training losses and times
        training_time = time.time() - training_start
        losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
        for key, value in losses.items():
            metrics_to_log.update({f"train/{key}": value.item()})
        metrics_to_log.update(
            {
                "train/lr": alpha * lr,
                "train/sampling_time": sampling_time,
                "train/training_time": training_time,
            }
        )

        # Get test rewards
        with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
            if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
                i * frames_in_batch * frame_skip
            ) // test_interval:
                actor.eval()
                eval_start = time.time()
                test_reward = eval_model(
                    actor, test_env, num_episodes=num_test_episodes
                )
                eval_time = time.time() - eval_start
                metrics_to_log.update(
                    {
                        "eval/reward": test_reward,
                        "eval/time": eval_time,
                    }
                )
                actor.train()

        if logger:
            for key, value in metrics_to_log.items():
                logger.log_scalar(key, value, collected_frames)

        collector.update_policy_weights_()
        sampling_start = time.time()
        accumulator = []

    collector.shutdown()
    end_time = time.time()
    execution_time = end_time - start_time
    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
    main()
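The script above resolves every hyperparameter from its Hydra config, config_multi_node_ray.yaml, which is not included in this hunk. The keys it reads can be enumerated from the cfg accesses above; the sketch below lists them with placeholder values, expressed as an OmegaConf object. The values are illustrative assumptions, not the packaged defaults.

```python
# Placeholder values only; the packaged config_multi_node_ray.yaml is authoritative.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "local_device": "cuda:0",
        "env": {"env_name": "PongNoFrameskip-v4", "backend": "gymnasium"},
        "collector": {
            "frames_per_batch": 80_000,
            "total_frames": 200_000_000,
            "num_workers": 12,
        },
        "loss": {
            "batch_size": 32,
            "sgd_updates": 1,
            "gamma": 0.99,
            "loss_critic_type": "l2",
            "entropy_coeff": 0.01,
            "critic_coeff": 0.5,
        },
        "optim": {
            "lr": 6e-4,
            "anneal_lr": True,
            "weight_decay": 0.0,
            "eps": 1e-8,
            "alpha": 0.99,
            "max_grad_norm": 40.0,
        },
        "logger": {
            "backend": "wandb",
            "exp_name": "impala",
            "project_name": "torchrl-impala",
            "group_name": None,
            "test_interval": 200_000_000,
            "num_test_episodes": 3,
        },
        # Forwarded verbatim to ray.init(...) by RayCollector
        "ray_init_config": {
            "address": None, "num_cpus": None, "num_gpus": None, "resources": None,
            "object_store_memory": None, "local_mode": False, "ignore_reinit_error": False,
            "include_dashboard": None, "dashboard_host": "127.0.0.1", "dashboard_port": None,
            "job_config": None, "configure_logging": True, "logging_level": "info",
            "logging_format": None, "log_to_driver": True, "namespace": None,
            "runtime_env": None, "storage": None,
        },
        # Per-worker resource request passed to RayCollector as remote_configs
        "remote_worker_resources": {"num_cpus": 1, "num_gpus": 0.25, "memory": 2 * 1024**3},
    }
)
```

The submitit variant that follows reads the same env, loss, optim and logger keys, but replaces ray_init_config and remote_worker_resources with cfg.slurm_config.* and cfg.collector.backend.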
sota-implementations/impala/impala_multi_node_submitit.py
@@ -0,0 +1,284 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
This script reproduces the IMPALA Algorithm
results from Espeholt et al. 2018 on Atari Environments.
"""
from __future__ import annotations

import hydra
from torchrl._utils import logger as torchrl_logger


@hydra.main(
    config_path="", config_name="config_multi_node_submitit", version_base="1.1"
)
def main(cfg: DictConfig):  # noqa: F821

    import time

    import torch.optim
    import tqdm

    from tensordict import TensorDict
    from torchrl.collectors import SyncDataCollector
    from torchrl.collectors.distributed import DistributedDataCollector
    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
    from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
    from torchrl.envs import ExplorationType, set_exploration_type
    from torchrl.objectives import A2CLoss
    from torchrl.objectives.value import VTrace
    from torchrl.record.loggers import generate_exp_name, get_logger
    from utils import eval_model, make_env, make_ppo_models

    device = cfg.local_device
    if not device:
        device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
    else:
        device = torch.device(device)

    # Correct for frame_skip
    frame_skip = 4
    total_frames = cfg.collector.total_frames // frame_skip
    frames_per_batch = cfg.collector.frames_per_batch // frame_skip
    test_interval = cfg.logger.test_interval // frame_skip

    # Extract other config parameters
    batch_size = cfg.loss.batch_size  # Number of rollouts per batch
    num_workers = (
        cfg.collector.num_workers
    )  # Number of parallel workers collecting rollouts
    lr = cfg.optim.lr
    anneal_lr = cfg.optim.anneal_lr
    sgd_updates = cfg.loss.sgd_updates
    max_grad_norm = cfg.optim.max_grad_norm
    num_test_episodes = cfg.logger.num_test_episodes
    total_network_updates = (
        total_frames // (frames_per_batch * batch_size)
    ) * cfg.loss.sgd_updates

    # Create models (check utils.py)
    actor, critic = make_ppo_models(cfg.env.env_name, cfg.env.backend)
    actor, critic = actor.to(device), critic.to(device)

    slurm_kwargs = {
        "timeout_min": cfg.slurm_config.timeout_min,
        "slurm_partition": cfg.slurm_config.slurm_partition,
        "slurm_cpus_per_task": cfg.slurm_config.slurm_cpus_per_task,
        "slurm_gpus_per_node": cfg.slurm_config.slurm_gpus_per_node,
    }
    # Create collector
    device_str = "device" if num_workers <= 1 else "devices"
    if cfg.collector.backend == "nccl":
        collector_kwargs = {device_str: "cuda:0", f"storing_{device_str}": "cuda:0"}
    elif cfg.collector.backend == "gloo":
        collector_kwargs = {device_str: "cpu", f"storing_{device_str}": "cpu"}
    else:
        raise NotImplementedError(
            f"device assignment not implemented for backend {cfg.collector.backend}"
        )
    collector = DistributedDataCollector(
        create_env_fn=[make_env(cfg.env.env_name, device, gym_backend=cfg.env.backend)]
        * num_workers,
        policy=actor,
        num_workers_per_collector=1,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        collector_class=SyncDataCollector,
        collector_kwargs=collector_kwargs,
        slurm_kwargs=slurm_kwargs,
        storing_device="cuda:0" if cfg.collector.backend == "nccl" else "cpu",
        launcher="submitit",
        # update_after_each_batch=True,
        backend=cfg.collector.backend,
    )

    # Create data buffer
    sampler = SamplerWithoutReplacement()
    data_buffer = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(frames_per_batch * batch_size),
        sampler=sampler,
        batch_size=frames_per_batch * batch_size,
    )

    # Create loss and adv modules
    adv_module = VTrace(
        gamma=cfg.loss.gamma,
        value_network=critic,
        actor_network=actor,
        average_adv=False,
    )
    loss_module = A2CLoss(
        actor_network=actor,
        critic_network=critic,
        loss_critic_type=cfg.loss.loss_critic_type,
        entropy_coeff=cfg.loss.entropy_coeff,
        critic_coeff=cfg.loss.critic_coeff,
    )
    loss_module.set_keys(done="eol", terminated="eol")

    # Create optimizer
    optim = torch.optim.RMSprop(
        loss_module.parameters(),
        lr=cfg.optim.lr,
        weight_decay=cfg.optim.weight_decay,
        eps=cfg.optim.eps,
        alpha=cfg.optim.alpha,
    )

    # Create logger
    logger = None
    if cfg.logger.backend:
        exp_name = generate_exp_name(
            "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}"
        )
        logger = get_logger(
            cfg.logger.backend,
            logger_name="impala",
            experiment_name=exp_name,
            wandb_kwargs={
                "config": dict(cfg),
                "project": cfg.logger.project_name,
                "group": cfg.logger.group_name,
            },
        )

    # Create test environment
    test_env = make_env(
        cfg.env.env_name, device, gym_backend=cfg.env.backend, is_test=True
    )
    test_env.eval()

    # Main loop
    collected_frames = 0
    num_network_updates = 0
    pbar = tqdm.tqdm(total=total_frames)
    accumulator = []
    start_time = sampling_start = time.time()
    for i, data in enumerate(collector):

        metrics_to_log = {}
        sampling_time = time.time() - sampling_start
        frames_in_batch = data.numel()
        collected_frames += frames_in_batch * frame_skip
        pbar.update(data.numel())

        # Get training rewards and episode lengths
        episode_rewards = data["next", "episode_reward"][data["next", "done"]]
        if len(episode_rewards) > 0:
            episode_length = data["next", "step_count"][data["next", "done"]]
            metrics_to_log.update(
                {
                    "train/reward": episode_rewards.mean().item(),
                    "train/episode_length": episode_length.sum().item()
                    / len(episode_length),
                }
            )

        if len(accumulator) < batch_size:
            accumulator.append(data)
            if logger:
                for key, value in metrics_to_log.items():
                    logger.log_scalar(key, value, collected_frames)
            continue

        losses = TensorDict(batch_size=[sgd_updates])
        training_start = time.time()
        for j in range(sgd_updates):

            # Create a single batch of trajectories
            stacked_data = torch.stack(accumulator, dim=0).contiguous()
            stacked_data = stacked_data.to(device, non_blocking=True)

            # Compute advantage
            with torch.no_grad():
                stacked_data = adv_module(stacked_data)

            # Add to replay buffer
            for stacked_d in stacked_data:
                stacked_data_reshape = stacked_d.reshape(-1)
                data_buffer.extend(stacked_data_reshape)

            for batch in data_buffer:

                # Linearly decrease the learning rate and clip epsilon
                alpha = 1.0
                if anneal_lr:
                    alpha = 1 - (num_network_updates / total_network_updates)
                    for group in optim.param_groups:
                        group["lr"] = lr * alpha
                num_network_updates += 1

                # Get a data batch
                batch = batch.to(device)

                # Forward pass loss
                loss = loss_module(batch)
                losses[j] = loss.select(
                    "loss_critic", "loss_entropy", "loss_objective"
                ).detach()
                loss_sum = (
                    loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
                )

                # Backward pass
                loss_sum.backward()
                torch.nn.utils.clip_grad_norm_(
                    list(loss_module.parameters()), max_norm=max_grad_norm
                )

                # Update the networks
                optim.step()
                optim.zero_grad()

        # Get training losses and times
        training_time = time.time() - training_start
        losses = losses.apply(lambda x: x.float().mean(), batch_size=[])
        for key, value in losses.items():
            metrics_to_log.update({f"train/{key}": value.item()})
        metrics_to_log.update(
            {
                "train/lr": alpha * lr,
                "train/sampling_time": sampling_time,
                "train/training_time": training_time,
            }
        )

        # Get test rewards
        with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
            if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
                i * frames_in_batch * frame_skip
            ) // test_interval:
                actor.eval()
                eval_start = time.time()
                test_reward = eval_model(
                    actor, test_env, num_episodes=num_test_episodes
                )
                eval_time = time.time() - eval_start
                metrics_to_log.update(
                    {
                        "eval/reward": test_reward,
                        "eval/time": eval_time,
                    }
                )
                actor.train()

        if logger:
            for key, value in metrics_to_log.items():
                logger.log_scalar(key, value, collected_frames)

        collector.update_policy_weights_()
        sampling_start = time.time()
        accumulator = []

    collector.shutdown()
    end_time = time.time()
    execution_time = end_time - start_time
    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
    main()
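Both variants share the same bookkeeping: total_network_updates fixes the length of the linear learning-rate schedule, and evaluation fires whenever the cumulative (frame_skip-corrected) frame count crosses a test_interval boundary. The standalone sketch below mirrors those two computations with placeholder numbers; none of the values are the packaged defaults.

```python
# Placeholder numbers; the logic mirrors the two training loops above.
frame_skip = 4
frames_per_batch = 80_000 // frame_skip
batch_size = 32          # rollout batches accumulated before an update phase
sgd_updates = 1
total_frames = 200_000_000 // frame_skip
test_interval = 20_000_000 // frame_skip

# One update phase consumes batch_size collector batches and runs sgd_updates passes,
# so this is the total number of optimizer steps the lr schedule is spread over.
total_network_updates = (total_frames // (frames_per_batch * batch_size)) * sgd_updates


def annealed_lr(base_lr: float, num_network_updates: int) -> float:
    """Linear decay applied to every optimizer param group when anneal_lr is set."""
    alpha = 1 - (num_network_updates / total_network_updates)
    return base_lr * alpha


def should_evaluate(i: int, frames_in_batch: int) -> bool:
    """True when the cumulative frame count crosses a test_interval boundary,
    i.e. the same integer-division comparison used in the training loops."""
    previous = ((i - 1) * frames_in_batch * frame_skip) // test_interval
    current = (i * frames_in_batch * frame_skip) // test_interval
    return previous < current


assert annealed_lr(6e-4, 0) == 6e-4                      # full lr at the first update
assert annealed_lr(6e-4, total_network_updates) == 0.0   # fully decayed at the end
assert not should_evaluate(i=300, frames_in_batch=frames_per_batch)
assert should_evaluate(i=313, frames_in_batch=frames_per_batch)
```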