torchrl-0.11.0-cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/sac/sac-async.py
@@ -0,0 +1,266 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""Async SAC Example.
+
+WARNING: This isn't a SOTA implementation but a rudimentary implementation of SAC where inference
+and training are entirely decoupled. It can achieve a 20x speedup if compile and cudagraph are used.
+Two GPUs are required for this script to run.
+The API is currently being perfected, and contributions are welcome (as usual!) - see the TODOs in this script.
+
+This is a simple self-contained example of a SAC training script.
+
+It supports state environments like MuJoCo.
+
+The helper functions are coded in the utils.py associated with this script.
+"""
+from __future__ import annotations
+
+import time
+
+import warnings
+from functools import partial
+
+import hydra
+import numpy as np
+import tensordict
+import torch
+import torch.cuda
+import tqdm
+from tensordict import TensorDict
+from tensordict.nn import CudaGraphModule
+from torchrl._utils import (
+    compile_with_warmup,
+    get_available_device,
+    logger as torchrl_logger,
+    timeit,
+)
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import group_optimizers
+from torchrl.record.loggers import generate_exp_name, get_logger
+from utils import (
+    dump_video,
+    log_metrics,
+    make_collector_async,
+    make_environment,
+    make_loss_module,
+    make_replay_buffer,
+    make_sac_agent,
+    make_sac_optimizer,
+    make_train_environment,
+)
+
+torch.set_float32_matmul_precision("high")
+tensordict.nn.functional_modules._exclude_td_from_pytree().set()
+
+
+@hydra.main(version_base="1.1", config_path="", config_name="config-async")
+def main(cfg: DictConfig):  # noqa: F821
+    device = (
+        torch.device(cfg.network.device)
+        if cfg.network.device
+        else get_available_device()
+    )
+
+    # Create logger
+    exp_name = generate_exp_name("SAC", cfg.logger.exp_name)
+    logger = None
+    if cfg.logger.backend:
+        logger = get_logger(
+            logger_type=cfg.logger.backend,
+            logger_name="async_sac_logging",
+            experiment_name=exp_name,
+            wandb_kwargs={
+                "mode": cfg.logger.mode,
+                "config": dict(cfg),
+                "project": cfg.logger.project_name,
+                "group": cfg.logger.group_name,
+            },
+        )
+
+    torch.manual_seed(cfg.env.seed)
+    np.random.seed(cfg.env.seed)
+
+    # Create environments
+    _, eval_env = make_environment(cfg, logger=logger)
+
+    # TODO: This should be simplified. We need to create the policy on cuda:1 directly because of the bounds
+    # of the TanhDistribution which cannot be sent to cuda:1 within the distribution construction (ie, the
+    # distribution kwargs need to have access to the low / high values on the right device for compile and
+    # cudagraph to work).
+    # Create agent
+    dummy_train_env = make_train_environment(cfg)
+    model, _ = make_sac_agent(cfg, dummy_train_env, eval_env, device)
+    _, exploration_policy = make_sac_agent(cfg, dummy_train_env, eval_env, "cuda:1")
+    dummy_train_env.close(raise_if_closed=False)
+    del dummy_train_env
+    exploration_policy.load_state_dict(model[0].state_dict())
+
+    # Create SAC loss
+    loss_module, target_net_updater = make_loss_module(cfg, model)
+
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+    compile_mode_collector = compile_mode  # "reduce-overhead"
+
+    # TODO: enabling prefetch for mp RBs would speed up sampling which is currently responsible for
+    # half of the compute time on the trainer side.
+    # Create replay buffer
+    replay_buffer = make_replay_buffer(
+        batch_size=cfg.optim.batch_size,
+        prb=cfg.replay_buffer.prb,
+        buffer_size=cfg.replay_buffer.size,
+        scratch_dir=cfg.replay_buffer.scratch_dir,
+        device=device,
+        shared=True,
+        prefetch=0,
+    )
+
+    # TODO: Simplify this - ideally we'd like to share the uninitialized lazy tensor storage and fetch it once
+    # it's initialized
+    replay_buffer.extend(make_train_environment(cfg).rollout(1).view(-1))
+    replay_buffer.empty()
+
+    # Create off-policy collector and start it
+    collector = make_collector_async(
+        cfg,
+        partial(make_train_environment, cfg),
+        exploration_policy,
+        compile_mode=compile_mode_collector,
+        replay_buffer=replay_buffer,
+    )
+
+    # Create optimizers
+    (
+        optimizer_actor,
+        optimizer_critic,
+        optimizer_alpha,
+    ) = make_sac_optimizer(cfg, loss_module)
+    optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha)
+    del optimizer_actor, optimizer_critic, optimizer_alpha
+
+    def update(sampled_tensordict):
+        # Compute loss
+        loss_td = loss_module(sampled_tensordict)
+
+        actor_loss = loss_td["loss_actor"]
+        q_loss = loss_td["loss_qvalue"]
+        alpha_loss = loss_td["loss_alpha"]
+
+        (actor_loss + q_loss + alpha_loss).sum().backward()
+        optimizer.step()
+
+        # Update qnet_target params
+        target_net_updater.step()
+
+        optimizer.zero_grad(set_to_none=True)
+        return loss_td.detach()
+
+    if cfg.compile.compile:
+        update = compile_with_warmup(update, mode=compile_mode, warmup=2)
+
+    cfg.compile.cudagraphs
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=10)
+
+    # Main loop
+    init_random_frames = cfg.collector.init_random_frames
+
+    prb = cfg.replay_buffer.prb
+    update_freq = cfg.collector.update_freq
+
+    eval_rollout_steps = cfg.env.max_episode_steps
+    log_freq = cfg.logger.log_freq
+
+    # TODO: customize this
+    num_updates = 1000
+    total_iter = 1000
+    pbar = tqdm.tqdm(total=total_iter * num_updates)
+    params = TensorDict.from_module(model[0]).data
+
+    # Wait till we have enough data to start training
+    while replay_buffer.write_count <= init_random_frames:
+        time.sleep(0.01)
+
+    losses = []
+    for i in range(total_iter * num_updates):
+        timeit.printevery(
+            num_prints=total_iter * num_updates // log_freq,
+            total_count=total_iter * num_updates,
+            erase=True,
+        )
+
+        if (i % update_freq) == 0:
+            # Update weights of the inference policy
+            torchrl_logger.info("Updating weights")
+            collector.update_policy_weights_(params)
+
+        pbar.update(1)
+
+        # Optimization steps
+        with timeit("train"):
+            with timeit("train - rb - sample"):
+                # Sample from replay buffer
+                sampled_tensordict = replay_buffer.sample()
+
+            with timeit("train - update"):
+                torch.compiler.cudagraph_mark_step_begin()
+                loss_td = update(sampled_tensordict).clone()
+            losses.append(loss_td.select("loss_actor", "loss_qvalue", "loss_alpha"))
+
+            # Update priority
+            if prb:
+                replay_buffer.update_priority(sampled_tensordict)
+
+        # Logging
+        if (i % log_freq) == (log_freq - 1):
+            torchrl_logger.info("Logging")
+            collected_frames = replay_buffer.write_count
+            metrics_to_log = {}
+            if collected_frames >= init_random_frames:
+                losses_m = torch.stack(losses).mean()
+                losses = []
+                metrics_to_log["train/q_loss"] = losses_m.get("loss_qvalue")
+                metrics_to_log["train/actor_loss"] = losses_m.get("loss_actor")
+                metrics_to_log["train/alpha_loss"] = losses_m.get("loss_alpha")
+                metrics_to_log["train/alpha"] = loss_td["alpha"]
+                metrics_to_log["train/entropy"] = loss_td["entropy"]
+            metrics_to_log["train/collected_frames"] = int(collected_frames)
+
+            # Evaluation
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("eval"):
+                eval_rollout = eval_env.rollout(
+                    eval_rollout_steps,
+                    model[0],
+                    auto_cast_to_device=True,
+                    break_when_any_done=True,
+                )
+                eval_env.apply(dump_video)
+                eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
+                metrics_to_log["eval/reward"] = eval_reward
+            torchrl_logger.info(f"Logs: {metrics_to_log}")
+            if logger is not None:
+                metrics_to_log.update(timeit.todict(prefix="time"))
+                metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+                log_metrics(logger, metrics_to_log, collected_frames)
+
+    collector.shutdown()
+    if not eval_env.is_closed:
+        eval_env.close()
+
+
+if __name__ == "__main__":
+    main()
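The `update` closure above is compiled with `compile_with_warmup` and, when `cfg.compile.cudagraphs` is set, wrapped in `CudaGraphModule` so that the whole forward/backward/optimizer step can be replayed as a CUDA graph. The snippet below is a minimal sketch of that compile-then-capture pattern, not code from the package: the toy `Linear` model, tensor shapes, learning rate and loss are hypothetical stand-ins for the SAC loss and optimizers that the script's `utils.py` builds, and the CUDA-graph branch only activates when a GPU is available.

# Hedged sketch of the compile + CUDA-graph update pattern used in sac-async.py.
# The model, shapes and loss below are illustrative stand-ins, not package code.
import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule  # same wrapper the script uses

use_cudagraphs = torch.cuda.is_available()  # CudaGraphModule needs a CUDA device
device = torch.device("cuda") if use_cudagraphs else torch.device("cpu")

net = torch.nn.Linear(8, 1, device=device)  # hypothetical stand-in for the loss module
# capturable=True lets the Adam step be recorded inside a CUDA graph.
optim = torch.optim.Adam(net.parameters(), lr=3e-4, capturable=use_cudagraphs)


def update(td: TensorDict) -> TensorDict:
    # Forward, backward and optimizer step in a single call, as in the script.
    loss = (net(td["observation"]) - td["target"]).pow(2).mean()
    loss.backward()
    optim.step()
    optim.zero_grad(set_to_none=True)
    return TensorDict({"loss": loss.detach()}, batch_size=[])


# Mirror the script's mode selection: "reduce-overhead" for plain compilation,
# "default" when the update will also be captured by CudaGraphModule.
compile_mode = "default" if use_cudagraphs else "reduce-overhead"
update = torch.compile(update, mode=compile_mode)
if use_cudagraphs:
    update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)

for _ in range(20):
    batch = TensorDict(
        {"observation": torch.randn(32, 8), "target": torch.randn(32, 1)},
        batch_size=[32],
        device=device,
    )
    # Mark the step boundary before each compiled/captured update, as the script does.
    torch.compiler.cudagraph_mark_step_begin()
    loss_td = update(batch).clone()  # clone() copies results out of graph-owned buffers

The switch to compile mode "default" under CUDA graphs presumably avoids stacking Inductor's own graph capture on top of `CudaGraphModule`; the script applies the same rule when choosing `compile_mode`.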
sota-implementations/sac/sac.py
@@ -0,0 +1,239 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""SAC Example.
+
+This is a simple self-contained example of a SAC training script.
+
+It supports state environments like MuJoCo.
+
+The helper functions are coded in the utils.py associated with this script.
+"""
+from __future__ import annotations
+
+import warnings
+
+import hydra
+import numpy as np
+import torch
+import torch.cuda
+import tqdm
+from tensordict import TensorDict
+from tensordict.nn import CudaGraphModule
+from torchrl._utils import compile_with_warmup, get_available_device, timeit
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import group_optimizers
+from torchrl.record.loggers import generate_exp_name, get_logger
+from utils import (
+    dump_video,
+    log_metrics,
+    make_collector,
+    make_environment,
+    make_loss_module,
+    make_replay_buffer,
+    make_sac_agent,
+    make_sac_optimizer,
+)
+
+torch.set_float32_matmul_precision("high")
+
+
+@hydra.main(version_base="1.1", config_path="", config_name="config")
+def main(cfg: DictConfig):  # noqa: F821
+    device = (
+        torch.device(cfg.network.device)
+        if cfg.network.device
+        else get_available_device()
+    )
+
+    # Create logger
+    exp_name = generate_exp_name("SAC", cfg.logger.exp_name)
+    logger = None
+    if cfg.logger.backend:
+        logger = get_logger(
+            logger_type=cfg.logger.backend,
+            logger_name="sac_logging",
+            experiment_name=exp_name,
+            wandb_kwargs={
+                "mode": cfg.logger.mode,
+                "config": dict(cfg),
+                "project": cfg.logger.project_name,
+                "group": cfg.logger.group_name,
+            },
+        )
+
+    torch.manual_seed(cfg.env.seed)
+    np.random.seed(cfg.env.seed)
+
+    # Create environments
+    train_env, eval_env = make_environment(cfg, logger=logger)
+
+    # Create agent
+    model, exploration_policy = make_sac_agent(cfg, train_env, eval_env, device)
+
+    # Create SAC loss
+    loss_module, target_net_updater = make_loss_module(cfg, model)
+
+    compile_mode = None
+    if cfg.compile.compile:
+        compile_mode = cfg.compile.compile_mode
+        if compile_mode in ("", None):
+            if cfg.compile.cudagraphs:
+                compile_mode = "default"
+            else:
+                compile_mode = "reduce-overhead"
+
+    # Create off-policy collector
+    collector = make_collector(
+        cfg, train_env, exploration_policy, compile_mode=compile_mode
+    )
+
+    # Create replay buffer
+    replay_buffer = make_replay_buffer(
+        batch_size=cfg.optim.batch_size,
+        prb=cfg.replay_buffer.prb,
+        buffer_size=cfg.replay_buffer.size,
+        scratch_dir=cfg.replay_buffer.scratch_dir,
+        device=device,
+    )
+
+    # Create optimizers
+    (
+        optimizer_actor,
+        optimizer_critic,
+        optimizer_alpha,
+    ) = make_sac_optimizer(cfg, loss_module)
+    optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha)
+    del optimizer_actor, optimizer_critic, optimizer_alpha
+
+    def update(sampled_tensordict):
+        # Compute loss
+        loss_td = loss_module(sampled_tensordict)
+
+        actor_loss = loss_td["loss_actor"]
+        q_loss = loss_td["loss_qvalue"]
+        alpha_loss = loss_td["loss_alpha"]
+
+        (actor_loss + q_loss + alpha_loss).sum().backward()
+        optimizer.step()
+        optimizer.zero_grad(set_to_none=True)
+
+        # Update qnet_target params
+        target_net_updater.step()
+        return loss_td.detach()
+
+    if cfg.compile.compile:
+        update = compile_with_warmup(update, mode=compile_mode, warmup=1)
+
+    if cfg.compile.cudagraphs:
+        warnings.warn(
+            "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+            category=UserWarning,
+        )
+        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
+
+    # Main loop
+    collected_frames = 0
+    pbar = tqdm.tqdm(total=cfg.collector.total_frames)
+
+    init_random_frames = cfg.collector.init_random_frames
+    num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
+    prb = cfg.replay_buffer.prb
+    eval_iter = cfg.logger.eval_iter
+    frames_per_batch = cfg.collector.frames_per_batch
+    eval_rollout_steps = cfg.env.max_episode_steps
+
+    collector_iter = iter(collector)
+    total_iter = len(collector)
+
+    for i in range(total_iter):
+        timeit.printevery(num_prints=1000, total_count=total_iter, erase=True)
+
+        with timeit("collect"):
+            tensordict = next(collector_iter)
+
+        # Update weights of the inference policy
+        collector.update_policy_weights_()
+
+        current_frames = tensordict.numel()
+        pbar.update(current_frames)
+
+        with timeit("rb - extend"):
+            # Add to replay buffer
+            tensordict = tensordict.reshape(-1)
+            replay_buffer.extend(tensordict)
+
+        collected_frames += current_frames
+
+        # Optimization steps
+        with timeit("train"):
+            if collected_frames >= init_random_frames:
+                losses = TensorDict(batch_size=[num_updates])
+                for i in range(num_updates):
+                    with timeit("rb - sample"):
+                        # Sample from replay buffer
+                        sampled_tensordict = replay_buffer.sample()
+
+                    with timeit("update"):
+                        torch.compiler.cudagraph_mark_step_begin()
+                        loss_td = update(sampled_tensordict).clone()
+                    losses[i] = loss_td.select(
+                        "loss_actor", "loss_qvalue", "loss_alpha"
+                    )
+
+                    # Update priority
+                    if prb:
+                        replay_buffer.update_priority(sampled_tensordict)
+
+        episode_end = (
+            tensordict["next", "done"]
+            if tensordict["next", "done"].any()
+            else tensordict["next", "truncated"]
+        )
+        episode_rewards = tensordict["next", "episode_reward"][episode_end]
+
+        # Logging
+        metrics_to_log = {}
+        if len(episode_rewards) > 0:
+            episode_length = tensordict["next", "step_count"][episode_end]
+            metrics_to_log["train/reward"] = episode_rewards
+            metrics_to_log["train/episode_length"] = episode_length.sum() / len(
+                episode_length
+            )
+        if collected_frames >= init_random_frames:
+            losses = losses.mean()
+            metrics_to_log["train/q_loss"] = losses.get("loss_qvalue")
+            metrics_to_log["train/actor_loss"] = losses.get("loss_actor")
+            metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha")
+            metrics_to_log["train/alpha"] = loss_td["alpha"]
+            metrics_to_log["train/entropy"] = loss_td["entropy"]
+
+        # Evaluation
+        if abs(collected_frames % eval_iter) < frames_per_batch:
+            with set_exploration_type(
+                ExplorationType.DETERMINISTIC
+            ), torch.no_grad(), timeit("eval"):
+                eval_rollout = eval_env.rollout(
+                    eval_rollout_steps,
+                    model[0],
+                    auto_cast_to_device=True,
+                    break_when_any_done=True,
+                )
+                eval_env.apply(dump_video)
+                eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
+                metrics_to_log["eval/reward"] = eval_reward
+        if logger is not None:
+            metrics_to_log.update(timeit.todict(prefix="time"))
+            metrics_to_log["time/speed"] = pbar.format_dict["rate"]
+            log_metrics(logger, metrics_to_log, collected_frames)
+
+    collector.shutdown()
+    if not eval_env.is_closed:
+        eval_env.close()
+    if not train_env.is_closed:
+        train_env.close()
+
+
+if __name__ == "__main__":
+    main()
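Both scripts delegate loss construction to a `make_loss_module` helper in their local `utils.py`, which is not part of this diff. As a rough guide to what such a helper wires together, here is a hedged, self-contained sketch built only from documented TorchRL components (`SACLoss`, `SoftUpdate`, `ProbabilisticActor`, `ValueOperator`); the network widths, observation/action dimensions and `tau` below are illustrative assumptions, not the package's configuration. It produces the same `loss_actor`/`loss_qvalue`/`loss_alpha` keys that the training loops above read.

# Hedged sketch of the SAC loss wiring these scripts rely on; shapes and
# hyperparameters are illustrative assumptions, not package values.
import torch
from tensordict import TensorDict
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torch import nn
from torchrl.data import Bounded
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import SACLoss, SoftUpdate

obs_dim, act_dim, batch = 17, 6, 32  # illustrative MuJoCo-like shapes

# Actor: MLP -> (loc, scale) -> TanhNormal policy over a bounded action space.
actor_net = nn.Sequential(
    MLP(in_features=obs_dim, out_features=2 * act_dim, num_cells=[256, 256]),
    NormalParamExtractor(),
)
actor = ProbabilisticActor(
    TensorDictModule(actor_net, in_keys=["observation"], out_keys=["loc", "scale"]),
    in_keys=["loc", "scale"],
    spec=Bounded(low=-1.0, high=1.0, shape=(act_dim,)),
    distribution_class=TanhNormal,
    return_log_prob=True,
)

# Critic: Q(s, a) network; SACLoss maintains two Q-value copies of it by default.
qvalue = ValueOperator(
    MLP(in_features=obs_dim + act_dim, out_features=1, num_cells=[256, 256]),
    in_keys=["observation", "action"],
)

loss_module = SACLoss(actor_network=actor, qvalue_network=qvalue)
loss_module.make_value_estimator(gamma=0.99)
target_net_updater = SoftUpdate(loss_module, tau=0.005)  # Polyak update of target Q-nets

# A fake replay-buffer sample carrying the keys SACLoss expects.
sampled_tensordict = TensorDict(
    {
        "observation": torch.randn(batch, obs_dim),
        "action": torch.rand(batch, act_dim) * 2 - 1,
        ("next", "observation"): torch.randn(batch, obs_dim),
        ("next", "reward"): torch.randn(batch, 1),
        ("next", "done"): torch.zeros(batch, 1, dtype=torch.bool),
        ("next", "terminated"): torch.zeros(batch, 1, dtype=torch.bool),
    },
    batch_size=[batch],
)

loss_td = loss_module(sampled_tensordict)
# Same loss keys the training loops above sum and backpropagate:
total = loss_td["loss_actor"] + loss_td["loss_qvalue"] + loss_td["loss_alpha"]
total.backward()
target_net_updater.step()

In this sketch `SoftUpdate.step()` plays the role of the `target_net_updater.step()` call inside each script's `update` function, and `loss_td` also carries the `alpha` and `entropy` entries that the logging blocks read.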