torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0

torchrl/trainers/algorithms/configs/objectives.py
@@ -0,0 +1,177 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+
+from torchrl.objectives import ClipPPOLoss, KLPENPPOLoss, PPOLoss, SACLoss
+from torchrl.objectives.sac import DiscreteSACLoss
+from torchrl.trainers.algorithms.configs.common import ConfigBase
+
+
+@dataclass
+class LossConfig(ConfigBase):
+    """A class to configure a loss.
+
+    Args:
+        loss_type: The type of loss to use.
+    """
+
+    _partial_: bool = False
+
+    def __post_init__(self) -> None:
+        """Post-initialization hook for loss configurations."""
+
+
+@dataclass
+class SACLossConfig(LossConfig):
+    """A class to configure a SAC loss."""
+
+    actor_network: Any = None
+    qvalue_network: Any = None
+    value_network: Any = None
+    discrete: bool = False
+    num_qvalue_nets: int = 2
+    loss_function: str = "smooth_l1"
+    alpha_init: float = 1.0
+    min_alpha: float | None = None
+    max_alpha: float | None = None
+    action_spec: Any = None
+    fixed_alpha: bool = False
+    target_entropy: str | float = "auto"
+    delay_actor: bool = False
+    delay_qvalue: bool = True
+    delay_value: bool = True
+    gamma: float | None = None
+    priority_key: str | None = None
+    separate_losses: bool = False
+    reduction: str | None = None
+    skip_done_states: bool = False
+    deactivate_vmap: bool = False
+    _target_: str = "torchrl.trainers.algorithms.configs.objectives._make_sac_loss"
+
+    def __post_init__(self) -> None:
+        """Post-initialization hook for SAC loss configurations."""
+        super().__post_init__()
+
+
+def _make_sac_loss(*args, **kwargs) -> SACLoss:
+    discrete_loss_type = kwargs.pop("discrete", False)
+
+    # Instantiate networks if they are config objects
+    actor_network = kwargs.get("actor_network")
+    qvalue_network = kwargs.get("qvalue_network")
+    value_network = kwargs.get("value_network")
+
+    if actor_network is not None and hasattr(actor_network, "_target_"):
+        kwargs["actor_network"] = actor_network()
+    if qvalue_network is not None and hasattr(qvalue_network, "_target_"):
+        kwargs["qvalue_network"] = qvalue_network()
+    if value_network is not None and hasattr(value_network, "_target_"):
+        kwargs["value_network"] = value_network()
+
+    if discrete_loss_type:
+        return DiscreteSACLoss(*args, **kwargs)
+    else:
+        return SACLoss(*args, **kwargs)
+
+
+@dataclass
+class PPOLossConfig(LossConfig):
+    """A class to configure a PPO loss."""
+
+    actor_network: Any = None
+    critic_network: Any = None
+    loss_type: str = "clip"
+    entropy_bonus: bool = True
+    samples_mc_entropy: int = 1
+    entropy_coeff: float | None = None
+    log_explained_variance: bool = True
+    critic_coeff: float = 0.25
+    loss_critic_type: str = "smooth_l1"
+    normalize_advantage: bool = True
+    normalize_advantage_exclude_dims: tuple = ()
+    gamma: float | None = None
+    separate_losses: bool = False
+    advantage_key: str | None = None
+    value_target_key: str | None = None
+    value_key: str | None = None
+    functional: bool = True
+    actor: Any = None
+    critic: Any = None
+    reduction: str | None = None
+    clip_value: float | None = None
+    device: Any = None
+    _target_: str = "torchrl.trainers.algorithms.configs.objectives._make_ppo_loss"
+
+    def __post_init__(self) -> None:
+        """Post-initialization hook for PPO loss configurations."""
+        super().__post_init__()
+
+
+def _make_ppo_loss(*args, **kwargs) -> PPOLoss:
+    loss_type = kwargs.pop("loss_type", "clip")
+    if loss_type == "clip":
+        return ClipPPOLoss(*args, **kwargs)
+    elif loss_type == "kl":
+        return KLPENPPOLoss(*args, **kwargs)
+    elif loss_type == "ppo":
+        return PPOLoss(*args, **kwargs)
+    else:
+        raise ValueError(f"Invalid loss type: {loss_type}")
+
+
+@dataclass
+class TargetNetUpdaterConfig:
+    """An abstract class to configure target net updaters."""
+
+    loss_module: Any
+    _partial_: bool = True
+
+
+@dataclass
+class SoftUpdateConfig(TargetNetUpdaterConfig):
+    """A class for soft update instantiation."""
+
+    _target_: str = "torchrl.objectives.utils.SoftUpdate"
+    eps: float | None = None  # noqa # type-ignore
+    tau: float | None = 0.001  # noqa # type-ignore
+
+
+@dataclass
+class HardUpdateConfig(TargetNetUpdaterConfig):
+    """A class for hard update instantiation."""
+
+    _target_: str = "torchrl.objectives.utils.HardUpdate."
+    value_network_update_interval: int = 1000
+
+
+@dataclass
+class GAEConfig(LossConfig):
+    """A class to configure a GAELoss."""
+
+    gamma: float | None = None
+    lmbda: float | None = None
+    value_network: Any = None
+    average_gae: bool = True
+    differentiable: bool = False
+    vectorized: bool | None = None
+    skip_existing: bool | None = None
+    advantage_key: str | None = None
+    value_target_key: str | None = None
+    value_key: str | None = None
+    shifted: bool = False
+    device: Any = None
+    time_dim: int | None = None
+    auto_reset_env: bool = False
+    deactivate_vmap: bool = False
+    _target_: str = "torchrl.objectives.value.GAE"
+    _partial_: bool = False
+
+    def __post_init__(self) -> None:
+        """Post-initialization hook for GAELoss configurations."""
+        super().__post_init__()
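
The `_target_` and `_partial_` fields above follow Hydra's structured-config convention: instantiating one of these dataclasses with `hydra.utils.instantiate` ends up calling the factory named in `_target_` with the remaining fields as keyword arguments. The sketch below shows what that dispatch boils down to by calling the factory directly; `policy` and `value_net` are placeholder actor and critic modules, not objects defined in this file, and the field values are illustrative only.

# Minimal sketch (assumes pre-built `policy` and `value_net` TensorDict modules):
from torchrl.objectives import ClipPPOLoss, KLPENPPOLoss
from torchrl.trainers.algorithms.configs.objectives import _make_ppo_loss

# loss_type="clip" (the default) dispatches to ClipPPOLoss ...
clip_loss = _make_ppo_loss(
    actor_network=policy,
    critic_network=value_net,
    loss_type="clip",
    entropy_coeff=0.01,
)
assert isinstance(clip_loss, ClipPPOLoss)

# ... while loss_type="kl" dispatches to KLPENPPOLoss.
kl_loss = _make_ppo_loss(actor_network=policy, critic_network=value_net, loss_type="kl")
assert isinstance(kl_loss, KLPENPPOLoss)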

torchrl/trainers/algorithms/configs/trainers.py
@@ -0,0 +1,340 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+from tensordict.nn import TensorDictModuleBase
+
+from torchrl.collectors import BaseCollector
+from torchrl.objectives.common import LossModule
+from torchrl.objectives.utils import TargetNetUpdater
+from torchrl.objectives.value.advantages import GAE
+from torchrl.trainers.algorithms.configs.common import ConfigBase
+from torchrl.trainers.algorithms.ppo import PPOTrainer
+from torchrl.trainers.algorithms.sac import SACTrainer
+
+
+@dataclass
+class TrainerConfig(ConfigBase):
+    """Base configuration class for trainers."""
+
+    def __post_init__(self) -> None:
+        """Post-initialization hook for trainer configurations."""
+
+
+@dataclass
+class SACTrainerConfig(TrainerConfig):
+    """Configuration class for SAC (Soft Actor Critic) trainer.
+
+    This class defines the configuration parameters for creating a SAC trainer,
+    including both required and optional fields with sensible defaults.
+    """
+
+    collector: Any
+    total_frames: int
+    optim_steps_per_batch: int | None
+    loss_module: Any
+    optimizer: Any
+    logger: Any
+    save_trainer_file: Any
+    replay_buffer: Any
+    frame_skip: int = 1
+    clip_grad_norm: bool = True
+    clip_norm: float | None = None
+    progress_bar: bool = True
+    seed: int | None = None
+    save_trainer_interval: int = 10000
+    log_interval: int = 10000
+    create_env_fn: Any = None
+    actor_network: Any = None
+    critic_network: Any = None
+    target_net_updater: Any = None
+    async_collection: bool = False
+    log_timings: bool = False
+
+    _target_: str = "torchrl.trainers.algorithms.configs.trainers._make_sac_trainer"
+
+    def __post_init__(self) -> None:
+        """Post-initialization hook for SAC trainer configuration."""
+        super().__post_init__()
+
+
+def _make_sac_trainer(*args, **kwargs) -> SACTrainer:
+    from torchrl.trainers.trainers import Logger
+
+    collector = kwargs.pop("collector")
+    total_frames = kwargs.pop("total_frames")
+    if total_frames is None:
+        total_frames = collector.total_frames
+    frame_skip = kwargs.pop("frame_skip", 1)
+    optim_steps_per_batch = kwargs.pop("optim_steps_per_batch", 1)
+    loss_module = kwargs.pop("loss_module")
+    optimizer = kwargs.pop("optimizer")
+    logger = kwargs.pop("logger")
+    clip_grad_norm = kwargs.pop("clip_grad_norm", True)
+    clip_norm = kwargs.pop("clip_norm")
+    progress_bar = kwargs.pop("progress_bar", True)
+    replay_buffer = kwargs.pop("replay_buffer")
+    save_trainer_interval = kwargs.pop("save_trainer_interval", 10000)
+    log_interval = kwargs.pop("log_interval", 10000)
+    save_trainer_file = kwargs.pop("save_trainer_file")
+    seed = kwargs.pop("seed")
+    actor_network = kwargs.pop("actor_network")
+    critic_network = kwargs.pop("critic_network")
+    kwargs.pop("create_env_fn")
+    target_net_updater = kwargs.pop("target_net_updater")
+    async_collection = kwargs.pop("async_collection", False)
+    log_timings = kwargs.pop("log_timings", False)
+
+    # Instantiate networks first
+    if actor_network is not None:
+        actor_network = actor_network()
+    if critic_network is not None:
+        critic_network = critic_network()
+
+    if not isinstance(collector, BaseCollector):
+        # then it's a partial config
+        if not async_collection:
+            collector = collector()
+        elif replay_buffer is not None:
+            collector = collector(replay_buffer=replay_buffer)
+    elif getattr(collector, "replay_buffer", None) is None:
+        if async_collection and (
+            collector.replay_buffer is None or replay_buffer is None
+        ):
+            raise ValueError(
+                "replay_buffer must be provided when async_collection is True"
+            )
+
+    if not isinstance(loss_module, LossModule):
+        # then it's a partial config
+        loss_module = loss_module(
+            actor_network=actor_network, critic_network=critic_network
+        )
+    if not isinstance(target_net_updater, TargetNetUpdater):
+        # target_net_updater must be a partial taking the loss as input
+        target_net_updater = target_net_updater(loss_module)
+    if not isinstance(optimizer, torch.optim.Optimizer):
+        # then it's a partial config
+        optimizer = optimizer(params=loss_module.parameters())
+
+    # Quick instance checks
+    if not isinstance(collector, BaseCollector):
+        raise ValueError(f"collector must be a BaseCollector, got {type(collector)}")
+    if not isinstance(loss_module, LossModule):
+        raise ValueError(f"loss_module must be a LossModule, got {type(loss_module)}")
+    if not isinstance(optimizer, torch.optim.Optimizer):
+        raise ValueError(
+            f"optimizer must be a torch.optim.Optimizer, got {type(optimizer)}"
+        )
+    if not isinstance(logger, Logger) and logger is not None:
+        raise ValueError(f"logger must be a Logger, got {type(logger)}")
+
+    return SACTrainer(
+        collector=collector,
+        total_frames=total_frames,
+        frame_skip=frame_skip,
+        optim_steps_per_batch=optim_steps_per_batch,
+        loss_module=loss_module,
+        optimizer=optimizer,
+        logger=logger,
+        clip_grad_norm=clip_grad_norm,
+        clip_norm=clip_norm,
+        progress_bar=progress_bar,
+        seed=seed,
+        save_trainer_interval=save_trainer_interval,
+        log_interval=log_interval,
+        save_trainer_file=save_trainer_file,
+        replay_buffer=replay_buffer,
+        target_net_updater=target_net_updater,
+        async_collection=async_collection,
+        log_timings=log_timings,
+    )
+
+
+@dataclass
+class PPOTrainerConfig(TrainerConfig):
+    """Configuration class for PPO (Proximal Policy Optimization) trainer.
+
+    This class defines the configuration parameters for creating a PPO trainer,
+    including both required and optional fields with sensible defaults.
+
+    Args:
+        collector: The data collector for gathering training data.
+        total_frames: Total number of frames to train for.
+        optim_steps_per_batch: Number of optimization steps per batch.
+        loss_module: The loss module for computing policy and value losses.
+        optimizer: The optimizer for training.
+        logger: Logger for tracking training metrics.
+        save_trainer_file: File path for saving trainer state.
+        replay_buffer: Replay buffer for storing data.
+        frame_skip: Frame skip value for the environment. Default: 1.
+        clip_grad_norm: Whether to clip gradient norms. Default: True.
+        clip_norm: Maximum gradient norm value.
+        progress_bar: Whether to show a progress bar. Default: True.
+        seed: Random seed for reproducibility.
+        save_trainer_interval: Interval for saving trainer state. Default: 10000.
+        log_interval: Interval for logging metrics. Default: 10000.
+        create_env_fn: Environment creation function.
+        actor_network: Actor network configuration.
+        critic_network: Critic network configuration.
+        num_epochs: Number of epochs per batch. Default: 4.
+        async_collection: Whether to use async collection. Default: False.
+        add_gae: Whether to add GAE computation. Default: True.
+        gae: Custom GAE module configuration.
+        weight_update_map: Mapping from collector destination paths to trainer source paths.
+            Required if collector has weight_sync_schemes configured.
+            Example: ``{"policy": "loss_module.actor_network", "replay_buffer.transforms[0]": "loss_module.critic_network"}``.
+        log_timings: Whether to automatically log timing information for all hooks.
+            If True, timing metrics will be logged to the logger (e.g., wandb, tensorboard)
+            with prefix "time/" (e.g., "time/hook/UpdateWeights"). Default: False.
+    """
+
+    collector: Any
+    total_frames: int
+    optim_steps_per_batch: int | None
+    loss_module: Any
+    optimizer: Any
+    logger: Any
+    save_trainer_file: Any
+    replay_buffer: Any
+    frame_skip: int = 1
+    clip_grad_norm: bool = True
+    clip_norm: float | None = None
+    progress_bar: bool = True
+    seed: int | None = None
+    save_trainer_interval: int = 10000
+    log_interval: int = 10000
+    create_env_fn: Any = None
+    actor_network: Any = None
+    critic_network: Any = None
+    num_epochs: int = 4
+    async_collection: bool = False
+    add_gae: bool = True
+    gae: Any = None
+    weight_update_map: dict[str, str] | None = None
+    log_timings: bool = False
+
+    _target_: str = "torchrl.trainers.algorithms.configs.trainers._make_ppo_trainer"
+
+    def __post_init__(self) -> None:
+        """Post-initialization hook for PPO trainer configuration."""
+        super().__post_init__()
+
+
+def _make_ppo_trainer(*args, **kwargs) -> PPOTrainer:
+    from torchrl.trainers.trainers import Logger
+
+    collector = kwargs.pop("collector")
+    total_frames = kwargs.pop("total_frames")
+    if total_frames is None:
+        total_frames = collector.total_frames
+    frame_skip = kwargs.pop("frame_skip", 1)
+    optim_steps_per_batch = kwargs.pop("optim_steps_per_batch", 1)
+    loss_module = kwargs.pop("loss_module")
+    optimizer = kwargs.pop("optimizer")
+    logger = kwargs.pop("logger")
+    clip_grad_norm = kwargs.pop("clip_grad_norm", True)
+    clip_norm = kwargs.pop("clip_norm")
+    progress_bar = kwargs.pop("progress_bar", True)
+    replay_buffer = kwargs.pop("replay_buffer")
+    save_trainer_interval = kwargs.pop("save_trainer_interval", 10000)
+    log_interval = kwargs.pop("log_interval", 10000)
+    save_trainer_file = kwargs.pop("save_trainer_file")
+    seed = kwargs.pop("seed")
+    actor_network = kwargs.pop("actor_network")
+    critic_network = kwargs.pop("critic_network")
+    add_gae = kwargs.pop("add_gae", True)
+    gae = kwargs.pop("gae")
+    create_env_fn = kwargs.pop("create_env_fn")
+    weight_update_map = kwargs.pop("weight_update_map", None)
+    log_timings = kwargs.pop("log_timings", False)
+
+    if create_env_fn is not None:
+        # could be referenced somewhere else, no need to raise an error
+        pass
+    num_epochs = kwargs.pop("num_epochs", 4)
+    async_collection = kwargs.pop("async_collection", False)
+
+    # Instantiate networks first
+    if actor_network is not None:
+        actor_network = actor_network()
+    if critic_network is not None:
+        critic_network = critic_network()
+    else:
+        critic_network = loss_module.critic_network
+
+    # Ensure GAE in replay buffer uses the same value network instance as loss module
+    # This fixes the issue where Hydra instantiates separate instances of value_model
+    if (
+        replay_buffer is not None
+        and hasattr(replay_buffer, "_transform")
+        and len(replay_buffer._transform) > 1
+        and hasattr(replay_buffer._transform[1], "module")
+        and hasattr(replay_buffer._transform[1].module, "value_network")
+    ):
+        replay_buffer._transform[1].module.value_network = critic_network
+
+    if not isinstance(collector, BaseCollector):
+        # then it's a partial config
+        if not async_collection:
+            collector = collector()
+        else:
+            collector = collector(replay_buffer=replay_buffer)
+    elif async_collection and getattr(collector, "replay_buffer", None) is None:
+        raise RuntimeError(
+            "replay_buffer must be provided when async_collection is True"
+        )
+    if not isinstance(loss_module, LossModule):
+        # then it's a partial config
+        loss_module = loss_module(
+            actor_network=actor_network, critic_network=critic_network
+        )
+    if not isinstance(optimizer, torch.optim.Optimizer):
+        # then it's a partial config
+        optimizer = optimizer(params=loss_module.parameters())
+
+    # Quick instance checks
+    if not isinstance(collector, BaseCollector):
+        raise ValueError(f"collector must be a BaseCollector, got {type(collector)}")
+    if not isinstance(loss_module, LossModule):
+        raise ValueError(f"loss_module must be a LossModule, got {type(loss_module)}")
+    if not isinstance(optimizer, torch.optim.Optimizer):
+        raise ValueError(
+            f"optimizer must be a torch.optim.Optimizer, got {type(optimizer)}"
+        )
+    if not isinstance(logger, Logger) and logger is not None:
+        raise ValueError(f"logger must be a Logger, got {type(logger)}")
+    # instantiate gae if it is a partial config
+    if not isinstance(gae, (GAE, TensorDictModuleBase)) and gae is not None:
+        gae = gae()
+
+    return PPOTrainer(
+        collector=collector,
+        total_frames=total_frames,
+        frame_skip=frame_skip,
+        optim_steps_per_batch=optim_steps_per_batch,
+        loss_module=loss_module,
+        optimizer=optimizer,
+        logger=logger,
+        clip_grad_norm=clip_grad_norm,
+        clip_norm=clip_norm,
+        progress_bar=progress_bar,
+        seed=seed,
+        save_trainer_interval=save_trainer_interval,
+        log_interval=log_interval,
+        save_trainer_file=save_trainer_file,
+        replay_buffer=replay_buffer,
+        num_epochs=num_epochs,
+        async_collection=async_collection,
+        add_gae=add_gae,
+        gae=gae,
+        weight_update_map=weight_update_map,
+        log_timings=log_timings,
+    )
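
Both trainer factories accept either fully built objects or partial configs for their sub-components: a partial loss is called with the instantiated actor and critic networks, a partial optimizer with `params=loss_module.parameters()`, and a partial collector with the replay buffer when `async_collection` is enabled. The sketch below mirrors that wiring with `functools.partial` stand-ins for the Hydra partial configs; `env_collector`, `buffer`, `policy`, and `value_net` are placeholders for an existing collector, replay buffer, and actor/critic modules, not objects from the wheel.

# Minimal sketch of the wiring performed by _make_ppo_trainer (placeholder objects assumed):
from functools import partial

import torch

from torchrl.objectives import ClipPPOLoss
from torchrl.trainers.algorithms.configs.trainers import _make_ppo_trainer

trainer = _make_ppo_trainer(
    collector=env_collector,                       # a BaseCollector instance
    total_frames=None,                             # falls back to collector.total_frames
    optim_steps_per_batch=1,
    loss_module=partial(ClipPPOLoss),              # called with actor_network=... / critic_network=...
    optimizer=partial(torch.optim.Adam, lr=3e-4),  # called with params=loss_module.parameters()
    logger=None,
    save_trainer_file=None,
    replay_buffer=buffer,
    clip_norm=None,
    seed=0,
    actor_network=lambda: policy,                  # zero-argument factories, like partial configs
    critic_network=lambda: value_net,
    gae=None,                                      # add_gae=True leaves GAE setup to the trainer (per the docstring)
    create_env_fn=None,
)
trainer.train()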