torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,586 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import contextlib
+import functools
+import importlib.util
+import time
+
+import hydra
+import torch
+import torch.cuda
+import tqdm
+
+from dreamer_utils import (
+    _default_device,
+    DreamerProfiler,
+    dump_video,
+    log_metrics,
+    make_collector,
+    make_dreamer,
+    make_environments,
+    make_replay_buffer,
+    make_storage_transform,
+)
+from omegaconf import DictConfig
+
+# mixed precision training
+from torch.amp import GradScaler
+from torch.autograd.profiler import record_function
+from torch.nn.utils import clip_grad_norm_
+from torchrl._utils import compile_with_warmup, logger as torchrl_logger, timeit
+from torchrl.envs.llm.transforms import PolicyVersion
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives.dreamer import (
+    DreamerActorLoss,
+    DreamerModelLoss,
+    DreamerValueLoss,
+)
+from torchrl.record.loggers import generate_exp_name, get_logger
+
+
+@hydra.main(version_base="1.1", config_path="", config_name="config")
+def main(cfg: DictConfig):  # noqa: F821
+    # cfg = correct_for_frame_skip(cfg)
+
+    device = _default_device(cfg.networks.device)
+    assert device.type == "cuda", "Dreamer only supports CUDA devices"
+
+    # Early check for video dependencies before starting training
+    if cfg.logger.video:
+        missing_deps = []
+        if importlib.util.find_spec("moviepy") is None:
+            missing_deps.append("moviepy (pip install moviepy)")
+        if importlib.util.find_spec("torchvision") is None:
+            missing_deps.append("torchvision (pip install torchvision)")
+        if missing_deps:
+            raise ImportError(
+                f"Video logging requires: {', '.join(missing_deps)}\n"
+                "Alternatively, disable video logging with: logger.video=False"
+            )
+
+    # Create logger
+    exp_name = generate_exp_name("Dreamer", cfg.logger.exp_name)
+    logger = None
+    if cfg.logger.backend:
+        logger = get_logger(
+            logger_type=cfg.logger.backend,
+            logger_name="dreamer_logging",
+            experiment_name=exp_name,
+            wandb_kwargs={
+                "mode": cfg.logger.mode,
+                "project": cfg.logger.project,
+            },
+        )
+        # Log hyperparameters using wandb.config.update() with OmegaConf resolution
+        # This properly resolves interpolations like ${env.name} and uses the official wandb API
+        if hasattr(logger, "log_hparams"):
+            logger.log_hparams(cfg)
+
+    # make_environments returns (train_env_factory, test_env) for async collection
+    train_env_factory, test_env = make_environments(
+        cfg=cfg,
+        parallel_envs=cfg.env.n_parallel_envs,
+        logger=logger,
+    )
+
+    # Make dreamer components
+    action_key = "action"
+    value_key = "state_value"
+    (
+        world_model,
+        model_based_env,
+        model_based_env_eval,
+        actor_model,
+        value_model,
+        policy,
+    ) = make_dreamer(
+        cfg=cfg,
+        device=device,
+        action_key=action_key,
+        value_key=value_key,
+        use_decoder_in_env=cfg.logger.video,
+        logger=logger,
+    )
+    # Losses
+    world_model_loss = DreamerModelLoss(world_model)
+    # Adapt loss keys to gym backend
+    if cfg.env.backend == "gym":
+        world_model_loss.set_keys(pixels="observation", reco_pixels="reco_observation")
+
+    actor_loss = DreamerActorLoss(
+        actor_model,
+        value_model,
+        model_based_env,
+        imagination_horizon=cfg.optimization.imagination_horizon,
+        discount_loss=True,
+    )
+
+    actor_loss.make_value_estimator(
+        gamma=cfg.optimization.gamma, lmbda=cfg.optimization.lmbda
+    )
+    value_loss = DreamerValueLoss(
+        value_model, discount_loss=True, gamma=cfg.optimization.gamma
+    )
+
+    # Make replay buffer with minimal sample-time transforms
+    # Note: Buffer must be created BEFORE collector for true async collection
+    batch_size = cfg.replay_buffer.batch_size
+    batch_length = cfg.replay_buffer.batch_length
+    buffer_size = cfg.replay_buffer.buffer_size
+    scratch_dir = cfg.replay_buffer.scratch_dir
+    prefetch = cfg.replay_buffer.prefetch
+    profiling_enabled = cfg.profiling.enabled
+    replay_buffer = make_replay_buffer(
+        batch_size=batch_size,
+        batch_seq_len=batch_length,
+        buffer_size=buffer_size,
+        buffer_scratch_dir=scratch_dir,
+        device=device,
+        prefetch=prefetch if not profiling_enabled else None,
+        pixel_obs=cfg.env.from_pixels,
+        grayscale=cfg.env.grayscale,
+        image_size=cfg.env.image_size,
+    )
+
+    # Create storage transform for extend-time processing (applied once per frame)
+    storage_transform = make_storage_transform(
+        pixel_obs=cfg.env.from_pixels,
+        grayscale=cfg.env.grayscale,
+        image_size=cfg.env.image_size,
+    )
+
+    # Create policy version tracker for async collection
+    # This tracks policy versions so we can correlate collected data with policy updates
+    policy_version = PolicyVersion(version_type="int")
+
+    # Make async multi-collector with replay buffer for true async collection
+    # Device allocation: cuda:0 for training, cuda:1+ for collectors (if multi-GPU)
+    collector = make_collector(
+        cfg,
+        train_env_factory,
+        policy,
+        training_device=device,
+        replay_buffer=replay_buffer,
+        storage_transform=storage_transform,
+        track_policy_version=policy_version,
+    )
+
+    # Enable collector worker profiling if configured
+    if profiling_enabled and cfg.profiling.collector.enabled:
+        torchrl_logger.info(
+            f"Enabling collector profiling: workers={cfg.profiling.collector.workers}, "
+            f"num_rollouts={cfg.profiling.collector.num_rollouts}, "
+            f"warmup_rollouts={cfg.profiling.collector.warmup_rollouts}, "
+            f"init_random_frames_override={cfg.profiling.collector.init_random_frames_override}"
+        )
+        collector.enable_profile(
+            workers=list(cfg.profiling.collector.workers),
+            num_rollouts=cfg.profiling.collector.num_rollouts,
+            warmup_rollouts=cfg.profiling.collector.warmup_rollouts,
+            save_path=cfg.profiling.collector.trace_file,
+            activities=["cpu", "cuda"] if cfg.profiling.profile_cuda else ["cpu"],
+            record_shapes=cfg.profiling.record_shapes,
+            profile_memory=cfg.profiling.profile_memory,
+            with_stack=cfg.profiling.with_stack,
+            with_flops=cfg.profiling.with_flops,
+        )
+
+    # Training config
+    total_optim_steps = cfg.optimization.total_optim_steps
+    log_every = cfg.optimization.log_every
+    grad_clip = cfg.optimization.grad_clip
+    eval_every = cfg.logger.eval_every
+    eval_rollout_steps = cfg.logger.eval_rollout_steps
+
+    # Override total_optim_steps if profiling is enabled
+    if profiling_enabled:
+        total_optim_steps = cfg.profiling.total_optim_steps
+
+    # Training loop - progress bar tracks optimization steps
+    pbar = tqdm.tqdm(total=total_optim_steps, desc="Optim steps")
+
+    # Make optimizer (fused=True for faster GPU execution)
+    use_fused = device.type == "cuda"
+    world_model_opt = torch.optim.Adam(
+        world_model.parameters(), lr=cfg.optimization.world_model_lr, fused=use_fused
+    )
+    actor_opt = torch.optim.Adam(
+        actor_model.parameters(), lr=cfg.optimization.actor_lr, fused=use_fused
+    )
+    value_opt = torch.optim.Adam(
+        value_model.parameters(), lr=cfg.optimization.value_lr, fused=use_fused
+    )
+
+    # Grad scaler for mixed precision training https://pytorch.org/docs/stable/amp.html
+    # autocast can be: false, true (=bfloat16), float16, bfloat16
+    autocast_cfg = cfg.optimization.autocast
+    if autocast_cfg in (False, "false", "False"):
+        autocast_dtype = None
+    elif autocast_cfg in (True, "true", "True", "bfloat16"):
+        autocast_dtype = torch.bfloat16
+    elif autocast_cfg == "float16":
+        autocast_dtype = torch.float16
+    else:
+        raise ValueError(
+            f"Invalid autocast value: {autocast_cfg}. Use false, true, float16, or bfloat16."
+        )
+
+    if autocast_dtype is not None:
+        scaler1 = GradScaler()
+        scaler2 = GradScaler()
+        scaler3 = GradScaler()
+
+    # Enable TensorFloat32 for better performance on Ampere+ GPUs
+    if device.type == "cuda":
+        torch.set_float32_matmul_precision("high")
+
+    compile_cfg = cfg.optimization.compile
+    compile_enabled = compile_cfg.enabled
+    compile_losses = set(compile_cfg.losses)
+    if compile_enabled:
+        torch._dynamo.config.capture_scalar_outputs = True
+
+        compile_warmup = 3
+        torchrl_logger.info(f"Compiling loss modules with warmup={compile_warmup}")
+        backend = compile_cfg.backend
+        mode = compile_cfg.mode
+
+        # Note: We do NOT compile rssm_prior/rssm_posterior here because they are
+        # shared with the policy used in the collector. Compiling them would cause
+        # issues with the MultiCollector workers.
+        #
+        # Instead, we compile the loss modules themselves which wraps the forward pass.
+        # fullgraph=False allows graph breaks which can help with inductor issues.
+        # warmup=compile_warmup runs eagerly for first `compile_warmup` calls before compiling.
+        if "world_model" in compile_losses:
+            world_model_loss = compile_with_warmup(
+                world_model_loss,
+                backend=backend,
+                mode=mode,
+                fullgraph=False,
+                warmup=compile_warmup,
+            )
+        if "actor" in compile_losses:
+            actor_loss = compile_with_warmup(
+                actor_loss, backend=backend, mode=mode, warmup=compile_warmup
+            )
+        if "value" in compile_losses:
+            value_loss = compile_with_warmup(
+                value_loss, backend=backend, mode=mode, warmup=compile_warmup
+            )
+    else:
+        compile_warmup = 0
+
+    # Throughput tracking
+    t_log_start = time.time()
+
+    # Profiling setup (encapsulated in helper class)
+    profiler = DreamerProfiler(cfg, device, pbar, compile_warmup=compile_warmup)
+
+    # Start async collection - collector fills the buffer in background
+    torchrl_logger.info("Starting async collection...")
+    torchrl_logger.debug(f"Collector type: {type(collector).__name__}")
+    torchrl_logger.debug(f"Number of collector workers: {cfg.collector.num_collectors}")
+    collector.start()
+    torchrl_logger.debug("collector.start() completed")
+
+    # Wait for enough samples to start training
+    # The collector handles init_random_frames internally, but we also wait here
+    # to ensure the buffer has enough data before we start sampling.
+    # Use init_random_frames_override when collector profiling is enabled
+    if profiling_enabled and cfg.profiling.collector.enabled:
+        min_frames_to_start = cfg.profiling.collector.init_random_frames_override
+        torchrl_logger.info(
+            f"Collector profiling: overriding init_random_frames to {min_frames_to_start}"
+        )
+    else:
+        min_frames_to_start = cfg.collector.init_random_frames
+
+    # Always need at least batch_size frames to sample a batch
+    # (bug fix: init_random_frames_override=0 would hang on empty buffer)
+    min_frames_to_start = max(min_frames_to_start, batch_size)
+    torchrl_logger.info(
+        f"Waiting for {min_frames_to_start} initial frames before training..."
+    )
+    while replay_buffer.write_count < min_frames_to_start:
+        time.sleep(0.1)
+
+    torchrl_logger.info(
+        f"Collected {replay_buffer.write_count} frames (random frames phase complete: {min_frames_to_start} frames). "
+        f"Starting training..."
+    )
+    torchrl_logger.info(
+        "NOTE: From now on, collectors will use the policy instead of random actions. "
+        "Policy outputs keys like 'encoded_latents', 'loc', 'scale' that weren't present during random collection."
+    )
+
+    # Track frames for FPS calculation over logging interval
+    frames_at_log_start = replay_buffer.write_count
+
+    # Main training loop - iterate over optimization steps
+    for optim_step in range(total_optim_steps):
+        # Update progress bar every step
+        pbar.update(1)
+
+        # Debug logging every 100 steps
+        if optim_step % 100 == 0:
+            cuda_mem_allocated = torch.cuda.memory_allocated(device) / (1024**3)
+            cuda_mem_reserved = torch.cuda.memory_reserved(device) / (1024**3)
+            torchrl_logger.debug(
+                f"optim_step={optim_step}: "
+                f"buffer_count={replay_buffer.write_count}, "
+                f"cuda_allocated={cuda_mem_allocated:.2f}GB, "
+                f"cuda_reserved={cuda_mem_reserved:.2f}GB"
+            )
+
+        # sample from replay buffer
+        with timeit("train/sample"), record_function("## train/sample ##"):
+            sampled_tensordict = replay_buffer.sample().reshape(-1, batch_length)
+            if profiling_enabled:
+                torch.cuda.synchronize()
+
+        # update world model
+        with timeit("train/world_model-forward"), record_function(
+            "## world_model/forward ##"
+        ):
+            # Mark step begin for CUDAGraph to prevent tensor overwrite issues
+            torch.compiler.cudagraph_mark_step_begin()
+            with torch.autocast(
+                device_type=device.type,
+                dtype=autocast_dtype,
+            ) if autocast_dtype else contextlib.nullcontext():
+                assert (
+                    sampled_tensordict.device.type == "cuda"
+                ), "sampled_tensordict should be on CUDA"
+                model_loss_td, sampled_tensordict = world_model_loss(sampled_tensordict)
+                loss_world_model = (
+                    model_loss_td["loss_model_kl"]
+                    + model_loss_td["loss_model_reco"]
+                    + model_loss_td["loss_model_reward"]
+                )
+
+        with timeit("train/world_model-backward"), record_function(
+            "## world_model/backward ##"
+        ):
+            world_model_opt.zero_grad()
+            if autocast_dtype:
+                scaler1.scale(loss_world_model).backward()
+                scaler1.unscale_(world_model_opt)
+            else:
+                loss_world_model.backward()
+            torchrl_logger.debug("world_model_loss backward OK")
+            world_model_grad = clip_grad_norm_(world_model.parameters(), grad_clip)
+            if autocast_dtype:
+                scaler1.step(world_model_opt)
+                scaler1.update()
+            else:
+                world_model_opt.step()
+
+        # update actor network
+        with timeit("train/actor-forward"), record_function("## actor/forward ##"):
+            # Mark step begin for CUDAGraph to prevent tensor overwrite issues
+            torch.compiler.cudagraph_mark_step_begin()
+            with torch.autocast(
+                device_type=device.type, dtype=autocast_dtype
+            ) if autocast_dtype else contextlib.nullcontext():
+                actor_loss_td, sampled_tensordict = actor_loss(
+                    sampled_tensordict.reshape(-1)
+                )
+
+        with timeit("train/actor-backward"), record_function("## actor/backward ##"):
+            actor_opt.zero_grad()
+            if autocast_dtype:
+                scaler2.scale(actor_loss_td["loss_actor"]).backward()
+                scaler2.unscale_(actor_opt)
+            else:
+                actor_loss_td["loss_actor"].backward()
+            torchrl_logger.debug("actor_loss backward OK")
+            actor_model_grad = clip_grad_norm_(actor_model.parameters(), grad_clip)
+            if autocast_dtype:
+                scaler2.step(actor_opt)
+                scaler2.update()
+            else:
+                actor_opt.step()
+
+        # update value network
+        with timeit("train/value-forward"), record_function("## value/forward ##"):
+            # Mark step begin for CUDAGraph to prevent tensor overwrite issues
+            torch.compiler.cudagraph_mark_step_begin()
+            with torch.autocast(
+                device_type=device.type, dtype=autocast_dtype
+            ) if autocast_dtype else contextlib.nullcontext():
+                value_loss_td, sampled_tensordict = value_loss(sampled_tensordict)
+
+        with timeit("train/value-backward"), record_function("## value/backward ##"):
+            value_opt.zero_grad()
+            if autocast_dtype:
+                scaler3.scale(value_loss_td["loss_value"]).backward()
+                scaler3.unscale_(value_opt)
+            else:
+                value_loss_td["loss_value"].backward()
+            torchrl_logger.debug("value_loss backward OK")
+            critic_model_grad = clip_grad_norm_(value_model.parameters(), grad_clip)
+            if autocast_dtype:
+                scaler3.step(value_opt)
+                scaler3.update()
+            else:
+                value_opt.step()
+
+        # Step profiler (returns True if profiling complete)
+        if profiler.step():
+            break
+
+        # Check if profiling is complete and we should exit
+        if profiler.should_exit():
+            torchrl_logger.info("Profiling complete. Exiting training loop.")
+            break
+
+        # Log metrics periodically (every log_every optim steps)
+        if (optim_step + 1) % log_every == 0:
+            # Track collected frames from buffer write count
+            collected_frames = replay_buffer.write_count
+            frames_collected_this_interval = collected_frames - frames_at_log_start
+
+            # Compute throughput metrics
+            t_log_end = time.time()
+            log_interval_time = t_log_end - t_log_start
+
+            # SPS: Samples (batch elements) processed per second
+            total_samples = log_every * batch_size
+            sps = total_samples / log_interval_time if log_interval_time > 0 else 0
+
+            # UPS: Updates (gradient steps) per second
+            # 3 updates per optim step (world_model, actor, value)
+            total_updates = log_every * 3
+            ups = total_updates / log_interval_time if log_interval_time > 0 else 0
+
+            # FPS: Frames collected per second (measured from buffer over logging interval)
+            fps = (
+                frames_collected_this_interval / log_interval_time
+                if log_interval_time > 0
+                else 0
+            )
+
+            # OPS: Optim steps per second
+            ops = log_every / log_interval_time if log_interval_time > 0 else 0
+
+            # OPF: Optim steps per frame (ratio of training to collection)
+            opf = (optim_step + 1) / collected_frames if collected_frames > 0 else 0
+
+            # Update progress bar with throughput metrics
+            pbar.set_postfix(
+                fps=f"{fps:.1f}",
+                ops=f"{ops:.1f}",
+                opf=f"{opf:.2f}",
+                frames=collected_frames,
+            )
+
+            # Get reward stats from sampled data (since we don't iterate over collector directly)
+            sampled_reward = sampled_tensordict.get(("next", "reward"))
+            reward_mean = sampled_reward.mean().item()
+            reward_std = sampled_reward.std().item()
+
+            metrics = {
+                "loss_model_kl": model_loss_td["loss_model_kl"].item(),
+                "loss_model_reco": model_loss_td["loss_model_reco"].item(),
+                "loss_model_reward": model_loss_td["loss_model_reward"].item(),
+                "loss_actor": actor_loss_td["loss_actor"].item(),
+                "loss_value": value_loss_td["loss_value"].item(),
+                "world_model_grad": world_model_grad,
+                "actor_model_grad": actor_model_grad,
+                "critic_model_grad": critic_model_grad,
+                # Reward stats from sampled batch
+                "train/reward_mean": reward_mean,
+                "train/reward_std": reward_std,
+                # Throughput metrics
+                "throughput/fps": fps,  # Frames per second (collection)
+                "throughput/ops": ops,  # Optim steps per second
+                "throughput/opf": opf,  # Optim steps per frame
+                "throughput/sps": sps,  # Samples per second (training)
+                "throughput/ups": ups,  # Updates per second (gradient steps)
+                "throughput/log_interval_time": log_interval_time,
+                # Collection tracking (not a target, just for monitoring)
+                "collected_frames": collected_frames,
+                # Policy version tracking
+                "policy_version": policy_version.version,
+                # Detailed timing from timeit (some metrics may be empty when compiling)
+                **timeit.todict(prefix="time"),
+            }
+
+            if logger is not None:
+                log_metrics(logger, metrics, collected_frames)
+
+            # Reset timer and frame counter for next logging interval
+            t_log_start = time.time()
+            frames_at_log_start = collected_frames
+
+            # Update policy weights in collector (for async collection)
+            with timeit("train/weight_update") as weight_update_timer:
+                torchrl_logger.debug(
+                    f"optim_step={optim_step}: Starting weight update..."
+                )
+                policy[1].step(frames_collected_this_interval)
+                collector.update_policy_weights_()
+                # Increment policy version after weight update
+                collector.increment_version()
+            torchrl_logger.debug(
+                f"optim_step={optim_step}: Weight update completed in "
+                f"{weight_update_timer.elapsed():.3f}s, "
+                f"policy_version={policy_version.version}"
+            )
+
+        # Evaluation (every eval_every optimization steps)
+        if (optim_step + 1) % eval_every == 0:
+            # Real env
+            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
+                eval_rollout = test_env.rollout(
+                    eval_rollout_steps,
+                    policy,
+                    auto_cast_to_device=True,
+                    break_when_any_done=True,
+                )
+                test_env.apply(
+                    functools.partial(dump_video, step=replay_buffer.write_count)
+                )
+                eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
+                eval_metrics = {"eval/reward": eval_reward}
+                if logger is not None:
+                    log_metrics(logger, eval_metrics, replay_buffer.write_count)
+            # Simulated env
+            if model_based_env_eval is not None:
+                with set_exploration_type(
+                    ExplorationType.DETERMINISTIC
+                ), torch.no_grad():
+                    eval_rollout = model_based_env_eval.rollout(
+                        eval_rollout_steps,
+                        policy,
+                        auto_cast_to_device=True,
+                        break_when_any_done=True,
+                        auto_reset=False,
+                        tensordict=eval_rollout[..., 0]
+                        .exclude("next", "action")
+                        .to(device),
+                    )
+                    model_based_env_eval.apply(
+                        functools.partial(dump_video, step=replay_buffer.write_count)
+                    )
+                    eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
+                    eval_metrics = {"eval/simulated_reward": eval_reward}
+                    if logger is not None:
+                        log_metrics(logger, eval_metrics, replay_buffer.write_count)
+
+    if not test_env.is_closed:
+        test_env.close()
+    # Shutdown async collector (use async_shutdown since we used start())
+    collector.async_shutdown()
+
+    del test_env
+    del collector
+
+
+if __name__ == "__main__":
+    main()