torchrl-0.11.0-cp314-cp314-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/trainers/algorithms/ppo.py
@@ -0,0 +1,373 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import pathlib
+import warnings
+
+from collections.abc import Callable
+
+from functools import partial
+
+from tensordict import TensorDict, TensorDictBase
+from torch import optim
+
+from torchrl.collectors import BaseCollector
+
+from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
+from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+from torchrl.objectives.common import LossModule
+from torchrl.objectives.value.advantages import GAE
+from torchrl.record.loggers import Logger
+from torchrl.trainers.trainers import (
+    LogScalar,
+    ReplayBufferTrainer,
+    Trainer,
+    UpdateWeights,
+)
+
+try:
+    pass
+
+    _has_tqdm = True
+except ImportError:
+    _has_tqdm = False
+
+try:
+    pass
+
+    _has_ts = True
+except ImportError:
+    _has_ts = False
+
+
+class PPOTrainer(Trainer):
+    """PPO (Proximal Policy Optimization) trainer implementation.
+
+    .. warning::
+        This is an experimental/prototype feature. The API may change in future versions.
+        Please report any issues or feedback to help improve this implementation.
+
+    This trainer implements the PPO algorithm for training reinforcement learning agents.
+    It extends the base Trainer class with PPO-specific functionality including
+    policy optimization, value function learning, and entropy regularization.
+
+    PPO typically uses multiple epochs of optimization on the same batch of data.
+    This trainer defaults to 4 epochs, which is a common choice for PPO implementations.
+
+    The trainer includes comprehensive logging capabilities for monitoring training progress:
+    - Training rewards (mean, std, max, total)
+    - Action statistics (norms)
+    - Episode completion rates
+    - Observation statistics (optional)
+
+    Logging can be configured via constructor parameters to enable/disable specific metrics.
+
+    Args:
+        collector (BaseCollector): The data collector for gathering training data.
+        total_frames (int): Total number of frames to train for.
+        frame_skip (int): Frame skip value for the environment.
+        optim_steps_per_batch (int): Number of optimization steps per batch.
+        loss_module (LossModule): The loss module for computing policy and value losses.
+        optimizer (optim.Optimizer, optional): The optimizer for training.
+        logger (Logger, optional): Logger for tracking training metrics.
+        clip_grad_norm (bool, optional): Whether to clip gradient norms. Default: True.
+        clip_norm (float, optional): Maximum gradient norm value.
+        progress_bar (bool, optional): Whether to show a progress bar. Default: True.
+        seed (int, optional): Random seed for reproducibility.
+        save_trainer_interval (int, optional): Interval for saving trainer state. Default: 10000.
+        log_interval (int, optional): Interval for logging metrics. Default: 10000.
+        save_trainer_file (str | pathlib.Path, optional): File path for saving trainer state.
+        num_epochs (int, optional): Number of epochs per batch. Default: 4.
+        replay_buffer (ReplayBuffer, optional): Replay buffer for storing data.
+        batch_size (int, optional): Batch size for optimization.
+        gamma (float, optional): Discount factor for GAE. Default: 0.9.
+        lmbda (float, optional): Lambda parameter for GAE. Default: 0.99.
+        enable_logging (bool, optional): Whether to enable logging. Default: True.
+        log_rewards (bool, optional): Whether to log rewards. Default: True.
+        log_actions (bool, optional): Whether to log actions. Default: True.
+        log_observations (bool, optional): Whether to log observations. Default: False.
+        async_collection (bool, optional): Whether to use async collection. Default: False.
+        add_gae (bool, optional): Whether to add GAE computation. Default: True.
+        gae (Callable, optional): Custom GAE module. If None and add_gae is True, a default GAE will be created.
+        weight_update_map (dict[str, str], optional): Mapping from collector destination paths (keys in
+            collector's weight_sync_schemes) to trainer source paths. Required if collector has
+            weight_sync_schemes configured. Example: {"policy": "loss_module.actor_network",
+            "replay_buffer.transforms[0]": "loss_module.critic_network"}
+        log_timings (bool, optional): If True, automatically register a LogTiming hook to log
+            timing information for all hooks to the logger (e.g., wandb, tensorboard).
+            Timing metrics will be logged with prefix "time/" (e.g., "time/hook/UpdateWeights").
+            Default is False.
+
+    Examples:
+        >>> # Basic usage with manual configuration
+        >>> from torchrl.trainers.algorithms.ppo import PPOTrainer
+        >>> from torchrl.trainers.algorithms.configs import PPOTrainerConfig
+        >>> from hydra import instantiate
+        >>> config = PPOTrainerConfig(...)  # Configure with required parameters
+        >>> trainer = instantiate(config)
+        >>> trainer.train()
+
+    .. note::
+        This trainer requires a configurable environment setup. See the
+        :class:`~torchrl.trainers.algorithms.configs` module for configuration options.
+
+    .. warning::
+        This is an experimental feature. The API may change in future versions.
+        We welcome feedback and contributions to help improve this implementation!
+    """
+
+    def __init__(
+        self,
+        *,
+        collector: BaseCollector,
+        total_frames: int,
+        frame_skip: int,
+        optim_steps_per_batch: int,
+        loss_module: LossModule | Callable[[TensorDictBase], TensorDictBase],
+        optimizer: optim.Optimizer | None = None,
+        logger: Logger | None = None,
+        clip_grad_norm: bool = True,
+        clip_norm: float | None = None,
+        progress_bar: bool = True,
+        seed: int | None = None,
+        save_trainer_interval: int = 10000,
+        log_interval: int = 10000,
+        save_trainer_file: str | pathlib.Path | None = None,
+        num_epochs: int = 4,
+        replay_buffer: ReplayBuffer | None = None,
+        batch_size: int | None = None,
+        gamma: float = 0.9,
+        lmbda: float = 0.99,
+        enable_logging: bool = True,
+        log_rewards: bool = True,
+        log_actions: bool = True,
+        log_observations: bool = False,
+        async_collection: bool = False,
+        add_gae: bool = True,
+        gae: Callable[[TensorDictBase], TensorDictBase] | None = None,
+        weight_update_map: dict[str, str] | None = None,
+        log_timings: bool = False,
+    ) -> None:
+        warnings.warn(
+            "PPOTrainer is an experimental/prototype feature. The API may change in future versions. "
+            "Please report any issues or feedback to help improve this implementation.",
+            UserWarning,
+            stacklevel=2,
+        )
+        super().__init__(
+            collector=collector,
+            total_frames=total_frames,
+            frame_skip=frame_skip,
+            optim_steps_per_batch=optim_steps_per_batch,
+            loss_module=loss_module,
+            optimizer=optimizer,
+            logger=logger,
+            clip_grad_norm=clip_grad_norm,
+            clip_norm=clip_norm,
+            progress_bar=progress_bar,
+            seed=seed,
+            save_trainer_interval=save_trainer_interval,
+            log_interval=log_interval,
+            save_trainer_file=save_trainer_file,
+            num_epochs=num_epochs,
+            async_collection=async_collection,
+            log_timings=log_timings,
+        )
+        self.replay_buffer = replay_buffer
+        self.async_collection = async_collection
+
+        if add_gae and gae is None:
+            gae = GAE(
+                gamma=gamma,
+                lmbda=lmbda,
+                value_network=self.loss_module.critic_network,
+                average_gae=True,
+            )
+            self.register_op("pre_epoch", gae)
+        elif not add_gae and gae is not None:
+            raise ValueError("gae must not be provided if add_gae is False")
+
+        if (
+            not self.async_collection
+            and replay_buffer is not None
+            and not isinstance(replay_buffer.sampler, SamplerWithoutReplacement)
+        ):
+            warnings.warn(
+                "Sampler is not a SamplerWithoutReplacement, which is required for PPO."
+            )
+
+        if replay_buffer is not None:
+            rb_trainer = ReplayBufferTrainer(
+                replay_buffer,
+                batch_size=None,
+                flatten_tensordicts=True,
+                memmap=False,
+                device=getattr(replay_buffer.storage, "device", "cpu"),
+                iterate=True,
+            )
+
+            if not self.async_collection:
+                # rb has been extended by the collector
+                self.register_op("pre_epoch", rb_trainer.extend)
+            self.register_op("process_optim_batch", rb_trainer.sample)
+            self.register_op("post_loss", rb_trainer.update_priority)
+
+        # Set up weight updates
+        # Validate weight_update_map if collector has weight_sync_schemes
+        if (
+            hasattr(self.collector, "_weight_sync_schemes")
+            and self.collector._weight_sync_schemes
+        ):
+            if weight_update_map is None:
+                raise ValueError(
+                    "Collector has weight_sync_schemes configured, but weight_update_map was not provided. "
+                    f"Please provide a mapping for all destinations: {list(self.collector._weight_sync_schemes.keys())}"
+                )
+
+            # Validate that all scheme destinations are covered in the map
+            scheme_destinations = set(self.collector._weight_sync_schemes.keys())
+            map_destinations = set(weight_update_map.keys())
+
+            if scheme_destinations != map_destinations:
+                missing = scheme_destinations - map_destinations
+                extra = map_destinations - scheme_destinations
+                error_msg = "weight_update_map does not match collector's weight_sync_schemes.\n"
+                if missing:
+                    error_msg += f"  Missing destinations: {missing}\n"
+                if extra:
+                    error_msg += f"  Extra destinations: {extra}\n"
+                raise ValueError(error_msg)
+
+            # Use the weight_update_map approach
+            update_weights = UpdateWeights(
+                self.collector,
+                1,
+                weight_update_map=weight_update_map,
+                trainer=self,
+            )
+        else:
+            # Fall back to legacy approach for backward compatibility
+            if weight_update_map is not None:
+                warnings.warn(
+                    "weight_update_map was provided but collector has no weight_sync_schemes. "
+                    "Ignoring weight_update_map and using legacy policy_weights_getter.",
+                    UserWarning,
+                    stacklevel=2,
+                )
+
+            policy_weights_getter = partial(
+                TensorDict.from_module, self.loss_module.actor_network
+            )
+            update_weights = UpdateWeights(
+                self.collector, 1, policy_weights_getter=policy_weights_getter
+            )
+
+        self.register_op("post_steps", update_weights)
+
+        # Store logging configuration
+        self.enable_logging = enable_logging
+        self.log_rewards = log_rewards
+        self.log_actions = log_actions
+        self.log_observations = log_observations
+
+        # Set up comprehensive logging for PPO training
+        if self.enable_logging:
+            self._setup_ppo_logging()
+
+    def _setup_ppo_logging(self):
+        """Set up logging hooks for PPO-specific metrics.
+
+        This method configures logging for common PPO metrics including:
+        - Training rewards (mean and std)
+        - Action statistics (norms, entropy)
+        - Episode completion rates
+        - Value function statistics
+        - Advantage statistics
+        """
+        # Always log done states as percentage (episode completion rate)
+        log_done_percentage = LogScalar(
+            key=("next", "done"),
+            logname="done_percentage",
+            log_pbar=True,
+            include_std=False,  # No std for binary values
+            reduction="mean",
+        )
+        if not self.async_collection:
+            self.register_op("pre_steps_log", log_done_percentage)
+        else:
+            self.register_op("post_optim_log", log_done_percentage)
+
+        # Log rewards if enabled
+        if self.log_rewards:
+            # 1. Log training rewards (most important metric for PPO)
+            log_rewards = LogScalar(
+                key=("next", "reward"),
+                logname="r_training",
+                log_pbar=True,  # Show in progress bar
+                include_std=True,
+                reduction="mean",
+            )
+            if not self.async_collection:
+                self.register_op("pre_steps_log", log_rewards)
+            else:
+                self.register_op("post_optim_log", log_rewards)
+
+            # 2. Log maximum reward in batch (for monitoring best performance)
+            log_max_reward = LogScalar(
+                key=("next", "reward"),
+                logname="r_max",
+                log_pbar=False,
+                include_std=False,
+                reduction="max",
+            )
+            if not self.async_collection:
+                self.register_op("pre_steps_log", log_max_reward)
+            else:
+                self.register_op("post_optim_log", log_max_reward)
+
+            # 3. Log total reward in batch (for monitoring cumulative performance)
+            log_total_reward = LogScalar(
+                key=("next", "reward"),
+                logname="r_total",
+                log_pbar=False,
+                include_std=False,
+                reduction="sum",
+            )
+            if not self.async_collection:
+                self.register_op("pre_steps_log", log_total_reward)
+            else:
+                self.register_op("post_optim_log", log_total_reward)
+
+        # Log actions if enabled
+        if self.log_actions:
+            # 4. Log action norms (useful for monitoring policy behavior)
+            log_action_norm = LogScalar(
+                key="action",
+                logname="action_norm",
+                log_pbar=False,
+                include_std=True,
+                reduction="mean",
+            )
+            if not self.async_collection:
+                self.register_op("pre_steps_log", log_action_norm)
+            else:
+                self.register_op("post_optim_log", log_action_norm)
+
+        # Log observations if enabled
+        if self.log_observations:
+            # 5. Log observation statistics (for monitoring state distributions)
+            log_obs_norm = LogScalar(
+                key="observation",
+                logname="obs_norm",
+                log_pbar=False,
+                include_std=True,
+                reduction="mean",
+            )
+            if not self.async_collection:
+                self.register_op("pre_steps_log", log_obs_norm)
+            else:
+                self.register_op("post_optim_log", log_obs_norm)
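
The `weight_update_map` contract enforced in `PPOTrainer.__init__` above can be checked in isolation: the map's keys must match the destinations declared by the collector's weight-sync schemes exactly, and each value names the trainer-side path that holds the source weights. The standalone sketch below replays that validation with plain Python sets; the destination and source strings are the illustrative ones from the docstring above, not a fixed API.

```python
# Standalone sketch of the weight_update_map validation performed by PPOTrainer.
# The destination/source strings below are the illustrative ones from the
# docstring above; a real collector would expose them via weight_sync_schemes.
scheme_destinations = {"policy", "replay_buffer.transforms[0]"}

weight_update_map = {
    "policy": "loss_module.actor_network",
    "replay_buffer.transforms[0]": "loss_module.critic_network",
}

missing = scheme_destinations - set(weight_update_map)
extra = set(weight_update_map) - scheme_destinations
if missing or extra:
    # PPOTrainer raises a ValueError listing the mismatched destinations.
    raise ValueError(f"Missing destinations: {missing}, extra destinations: {extra}")
print("weight_update_map covers every weight-sync destination")
```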
torchrl/trainers/algorithms/sac.py
@@ -0,0 +1,308 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import pathlib
+import warnings
+
+from collections.abc import Callable
+
+from functools import partial
+
+from tensordict import TensorDict, TensorDictBase
+from torch import optim
+
+from torchrl.collectors import BaseCollector
+
+from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
+from torchrl.objectives.common import LossModule
+from torchrl.objectives.utils import TargetNetUpdater
+from torchrl.record.loggers import Logger
+from torchrl.trainers.trainers import (
+    LogScalar,
+    ReplayBufferTrainer,
+    TargetNetUpdaterHook,
+    Trainer,
+    UpdateWeights,
+    UTDRHook,
+)
+
+
+class SACTrainer(Trainer):
+    """A trainer class for Soft Actor-Critic (SAC) algorithm.
+
+    This trainer implements the SAC algorithm, an off-policy actor-critic method that
+    optimizes a stochastic policy in an off-policy way, forming a bridge between
+    stochastic policy optimization and DDPG-style approaches. SAC incorporates the
+    entropy measure of the policy into the reward to encourage exploration.
+
+    The trainer handles:
+    - Replay buffer management for off-policy learning
+    - Target network updates with configurable update frequency
+    - Policy weight updates to the data collector
+    - Comprehensive logging of training metrics
+    - Gradient clipping and optimization steps
+
+    Args:
+        collector (BaseCollector): The data collector used to gather environment interactions.
+        total_frames (int): Total number of frames to collect during training.
+        frame_skip (int): Number of frames to skip between policy updates.
+        optim_steps_per_batch (int): Number of optimization steps per collected batch.
+        loss_module (LossModule | Callable): The SAC loss module or a callable that computes losses.
+        optimizer (optim.Optimizer, optional): The optimizer for training. If None, must be configured elsewhere.
+        logger (Logger, optional): Logger for recording training metrics. Defaults to None.
+        clip_grad_norm (bool, optional): Whether to clip gradient norms. Defaults to True.
+        clip_norm (float, optional): Maximum gradient norm for clipping. Defaults to None.
+        progress_bar (bool, optional): Whether to show a progress bar during training. Defaults to True.
+        seed (int, optional): Random seed for reproducibility. Defaults to None.
+        save_trainer_interval (int, optional): Interval for saving trainer state. Defaults to 10000.
+        log_interval (int, optional): Interval for logging metrics. Defaults to 10000.
+        save_trainer_file (str | pathlib.Path, optional): File path for saving trainer state. Defaults to None.
+        replay_buffer (ReplayBuffer, optional): Replay buffer for storing and sampling experiences. Defaults to None.
+        batch_size (int, optional): Batch size for sampling from replay buffer. Defaults to None.
+        enable_logging (bool, optional): Whether to enable metric logging. Defaults to True.
+        log_rewards (bool, optional): Whether to log reward statistics. Defaults to True.
+        log_actions (bool, optional): Whether to log action statistics. Defaults to True.
+        log_observations (bool, optional): Whether to log observation statistics. Defaults to False.
+        target_net_updater (TargetNetUpdater, optional): Target network updater for soft updates. Defaults to None.
+
+    Example:
+        >>> from torchrl.collectors import Collector
+        >>> from torchrl.objectives import SACLoss
+        >>> from torchrl.data import ReplayBuffer, LazyTensorStorage
+        >>> from torch import optim
+        >>>
+        >>> # Set up collector, loss, and replay buffer
+        >>> collector = Collector(env, policy, frames_per_batch=1000)
+        >>> loss_module = SACLoss(actor_network, qvalue_network)
+        >>> optimizer = optim.Adam(loss_module.parameters(), lr=3e-4)
+        >>> replay_buffer = ReplayBuffer(storage=LazyTensorStorage(100000))
+        >>>
+        >>> # Create and run trainer
+        >>> trainer = SACTrainer(
+        ...     collector=collector,
+        ...     total_frames=1000000,
+        ...     frame_skip=1,
+        ...     optim_steps_per_batch=100,
+        ...     loss_module=loss_module,
+        ...     optimizer=optimizer,
+        ...     replay_buffer=replay_buffer,
+        ... )
+        >>> trainer.train()
+
+    Note:
+        This is an experimental/prototype feature. The API may change in future versions.
+        SAC is particularly effective for continuous control tasks and environments where
+        exploration is crucial due to its entropy regularization.
+
+    """
+
+    def __init__(
+        self,
+        *,
+        collector: BaseCollector,
+        total_frames: int,
+        frame_skip: int,
+        optim_steps_per_batch: int,
+        loss_module: LossModule | Callable[[TensorDictBase], TensorDictBase],
+        optimizer: optim.Optimizer | None = None,
+        logger: Logger | None = None,
+        clip_grad_norm: bool = True,
+        clip_norm: float | None = None,
+        progress_bar: bool = True,
+        seed: int | None = None,
+        save_trainer_interval: int = 10000,
+        log_interval: int = 10000,
+        save_trainer_file: str | pathlib.Path | None = None,
+        replay_buffer: ReplayBuffer | None = None,
+        batch_size: int | None = None,
+        enable_logging: bool = True,
+        log_rewards: bool = True,
+        log_actions: bool = True,
+        log_observations: bool = False,
+        target_net_updater: TargetNetUpdater | None = None,
+        async_collection: bool = False,
+        log_timings: bool = False,
+    ) -> None:
+        warnings.warn(
+            "SACTrainer is an experimental/prototype feature. The API may change in future versions. "
+            "Please report any issues or feedback to help improve this implementation.",
+            UserWarning,
+            stacklevel=2,
+        )
+        # try to get the action spec
+        self._pass_action_spec_from_collector_to_loss(collector, loss_module)
+
+        super().__init__(
+            collector=collector,
+            total_frames=total_frames,
+            frame_skip=frame_skip,
+            optim_steps_per_batch=optim_steps_per_batch,
+            loss_module=loss_module,
+            optimizer=optimizer,
+            logger=logger,
+            clip_grad_norm=clip_grad_norm,
+            clip_norm=clip_norm,
+            progress_bar=progress_bar,
+            seed=seed,
+            save_trainer_interval=save_trainer_interval,
+            log_interval=log_interval,
+            save_trainer_file=save_trainer_file,
+            async_collection=async_collection,
+            log_timings=log_timings,
+        )
+        self.replay_buffer = replay_buffer
+        self.async_collection = async_collection
+
+        # Note: SAC can use any sampler type, unlike PPO which requires SamplerWithoutReplacement
+
+        if replay_buffer is not None:
+            rb_trainer = ReplayBufferTrainer(
+                replay_buffer,
+                batch_size=None,
+                flatten_tensordicts=True,
+                memmap=False,
+                device=getattr(replay_buffer.storage, "device", "cpu"),
+                iterate=True,
+            )
+            if not self.async_collection:
+                self.register_op("pre_epoch", rb_trainer.extend)
+            self.register_op("process_optim_batch", rb_trainer.sample)
+            self.register_op("post_loss", rb_trainer.update_priority)
+        self.register_op("post_optim", TargetNetUpdaterHook(target_net_updater))
+
+        policy_weights_getter = partial(
+            TensorDict.from_module, self.loss_module.actor_network
+        )
+        update_weights = UpdateWeights(
+            self.collector, 1, policy_weights_getter=policy_weights_getter
+        )
+        self.register_op("post_steps", update_weights)
+
+        # Store logging configuration
+        self.enable_logging = enable_logging
+        self.log_rewards = log_rewards
+        self.log_actions = log_actions
+        self.log_observations = log_observations
+
+        # Set up comprehensive logging for SAC training
+        if self.enable_logging:
+            self._setup_sac_logging()
+
+    def _pass_action_spec_from_collector_to_loss(
+        self, collector: BaseCollector, loss: LossModule
+    ):
+        """Pass the action specification from the collector's environment to the loss module.
+
+        This method extracts the action specification from the collector's environment
+        and assigns it to the loss module if the loss module doesn't already have one.
+        This is necessary for SAC loss computation which requires knowledge of the
+        action space bounds for proper entropy calculation and action clipping.
+
+        Args:
+            collector (BaseCollector): The data collector containing the environment.
+            loss (LossModule): The loss module that needs the action specification.
+        """
+        if hasattr(loss, "_action_spec") and loss._action_spec is None:
+            action_spec = collector.getattr_env("full_action_spec_unbatched").cpu()
+            loss._action_spec = action_spec
+
+    def _setup_sac_logging(self):
+        """Set up logging hooks for SAC-specific metrics.
+
+        This method configures logging for common SAC metrics including:
+        - Training rewards (mean, max, total, and std)
+        - Action statistics (action norms)
+        - Episode completion rates (done percentage)
+        - Observation statistics (when enabled)
+        - Q-value and policy loss metrics (handled by loss module)
+        """
+        # Always log done states as percentage (episode completion rate)
+        log_done_percentage = LogScalar(
+            key=("next", "done"),
+            logname="done_percentage",
+            log_pbar=True,
+            include_std=False,  # No std for binary values
+            reduction="mean",
+        )
+        if not self.async_collection:
+            self.register_op("pre_steps_log", log_done_percentage)
+        else:
+            self.register_op("post_optim_log", log_done_percentage)
+
+        # Log rewards if enabled
+        if self.log_rewards:
+            # 1. Log training rewards (most important metric for SAC)
+            log_rewards = LogScalar(
+                key=("next", "reward"),
+                logname="r_training",
+                log_pbar=True,  # Show in progress bar
+                include_std=True,
+                reduction="mean",
+            )
+            if not self.async_collection:
+                self.register_op("pre_steps_log", log_rewards)
+            else:
+                # In the async case, use the batch passed to the optimizer
+                self.register_op("post_optim_log", log_rewards)
+
+            # 2. Log maximum reward in batch (for monitoring best performance)
+            log_max_reward = LogScalar(
+                key=("next", "reward"),
+                logname="r_max",
+                log_pbar=False,
+                include_std=False,
+                reduction="max",
+            )
+            if not self.async_collection:
+                self.register_op("pre_steps_log", log_max_reward)
+            else:
+                self.register_op("post_optim_log", log_max_reward)
+
+            # 3. Log total reward in batch (for monitoring cumulative performance)
+            log_total_reward = LogScalar(
+                key=("next", "reward_sum"),
+                logname="r_total",
+                log_pbar=False,
+                include_std=False,
+                reduction="max",
+            )
+            if not self.async_collection:
+                self.register_op("pre_steps_log", log_total_reward)
+            else:
+                self.register_op("post_optim_log", log_total_reward)
+
+        # Log actions if enabled
+        if self.log_actions:
+            # 4. Log action norms (useful for monitoring policy behavior)
+            log_action_norm = LogScalar(
+                key="action",
+                logname="action_norm",
+                log_pbar=False,
+                include_std=True,
+                reduction="mean",
+            )
+            if not self.async_collection:
+                self.register_op("pre_steps_log", log_action_norm)
+            else:
+                self.register_op("post_optim_log", log_action_norm)
+
+        # Log observations if enabled
+        if self.log_observations:
+            # 5. Log observation statistics (for monitoring state distributions)
+            log_obs_norm = LogScalar(
+                key="observation",
+                logname="obs_norm",
+                log_pbar=False,
+                include_std=True,
+                reduction="mean",
+            )
+            if not self.async_collection:
+                self.register_op("pre_steps_log", log_obs_norm)
+            else:
+                self.register_op("post_optim_log", log_obs_norm)
+
+        self.register_op("pre_steps_log", UTDRHook(self))