torchrl-0.11.0-cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,59 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import torch
+from torchrl._utils import _make_ordinal_device
+
+from torchrl.data.replay_buffers.replay_buffers import (
+    ReplayBuffer,
+    TensorDictReplayBuffer,
+)
+from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
+from torchrl.data.replay_buffers.storages import LazyMemmapStorage
+from torchrl.data.utils import DEVICE_TYPING
+
+
+def make_replay_buffer(
+    device: DEVICE_TYPING, cfg: DictConfig  # noqa: F821
+) -> ReplayBuffer:  # noqa: F821
+    """Builds a replay buffer using the config built from ReplayArgsConfig."""
+    device = _make_ordinal_device(torch.device(device))
+    if not cfg.prb:
+        sampler = RandomSampler()
+    else:
+        sampler = PrioritizedSampler(
+            max_capacity=cfg.buffer_size,
+            alpha=0.7,
+            beta=0.5,
+        )
+    buffer = TensorDictReplayBuffer(
+        storage=LazyMemmapStorage(
+            cfg.buffer_size,
+            scratch_dir=cfg.buffer_scratch_dir,
+            # device=device,  # when using prefetch, this can overload the GPU memory
+        ),
+        sampler=sampler,
+        pin_memory=device != torch.device("cpu"),
+        prefetch=cfg.buffer_prefetch,
+        batch_size=cfg.batch_size,
+    )
+    return buffer
+
+
+@dataclass
+class ReplayArgsConfig:
+    """Generic Replay Buffer config struct."""
+
+    buffer_size: int = 1000000
+    # buffer size, in number of frames stored. Default=1e6
+    prb: bool = False
+    # whether a Prioritized replay buffer should be used instead of a more basic circular one.
+    buffer_scratch_dir: str | None = None
+    # directory where the buffer data should be stored. If none is passed, they will be placed in /tmp/
+    buffer_prefetch: int = 10
+    # prefetching queue length for the replay buffer
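The 59-line hunk above corresponds to torchrl/trainers/helpers/replay_buffer.py in the file listing. As a quick orientation, the sketch below (not part of the diff) drives make_replay_buffer with a plain dataclass standing in for the DictConfig the signature expects; note that the helper also reads cfg.batch_size, which is declared on the trainer config rather than on ReplayArgsConfig, so the stand-in adds it explicitly.

from __future__ import annotations

from dataclasses import dataclass

from torchrl.trainers.helpers.replay_buffer import make_replay_buffer


@dataclass
class _BufferCfg:
    # Mirrors ReplayArgsConfig; batch_size is read by make_replay_buffer but
    # declared on TrainerConfig in the packaged helpers, so it is added here.
    buffer_size: int = 100_000
    prb: bool = True  # True -> PrioritizedSampler, False -> RandomSampler
    buffer_scratch_dir: str | None = None
    buffer_prefetch: int = 10
    batch_size: int = 256


buffer = make_replay_buffer("cpu", _BufferCfg())
# buffer.extend(collected_data); batch = buffer.sample()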
@@ -0,0 +1,301 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from dataclasses import dataclass
+from warnings import warn
+
+import torch
+from tensordict.nn import TensorDictModule, TensorDictModuleWrapper
+from torch import optim
+from torch.optim.lr_scheduler import CosineAnnealingLR
+
+from torchrl._utils import logger as torchrl_logger, VERBOSE
+from torchrl.collectors import BaseCollector
+from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
+from torchrl.envs.common import EnvBase
+from torchrl.envs.utils import ExplorationType
+from torchrl.modules import reset_noise
+from torchrl.objectives.common import LossModule
+from torchrl.objectives.utils import TargetNetUpdater
+from torchrl.record.loggers import Logger
+from torchrl.trainers.trainers import (
+    BatchSubSampler,
+    ClearCudaCache,
+    CountFramesLog,
+    LogScalar,
+    LogValidationReward,
+    ReplayBufferTrainer,
+    RewardNormalizer,
+    SelectKeys,
+    Trainer,
+    UpdateWeights,
+)
+
+OPTIMIZERS = {
+    "adam": optim.Adam,
+    "sgd": optim.SGD,
+    "adamax": optim.Adamax,
+}
+
+
+@dataclass
+class TrainerConfig:
+    """Trainer config struct."""
+
+    optim_steps_per_batch: int = 500
+    # Number of optimization steps in between two collection of data. See frames_per_batch below.
+    optimizer: str = "adam"
+    # Optimizer to be used.
+    lr_scheduler: str = "cosine"
+    # LR scheduler.
+    selected_keys: list | None = None
+    # a list of strings that indicate the data that should be kept from the data collector. Since storing and
+    # retrieving information from the replay buffer does not come for free, limiting the amount of data
+    # passed to it can improve the algorithm performance.
+    batch_size: int = 256
+    # batch size of the TensorDict retrieved from the replay buffer. Default=256.
+    log_interval: int = 10000
+    # logging interval, in terms of optimization steps. Default=10000.
+    lr: float = 3e-4
+    # Learning rate used for the optimizer. Default=3e-4.
+    weight_decay: float = 0.0
+    # Weight-decay to be used with the optimizer. Default=0.0.
+    clip_norm: float = 1000.0
+    # value at which the total gradient norm / single derivative should be clipped. Default=1000.0
+    clip_grad_norm: bool = False
+    # if called, the gradient will be clipped based on its L2 norm. Otherwise, single gradient values will be clipped to the desired threshold.
+    normalize_rewards_online: bool = False
+    # Computes the running statistics of the rewards and normalizes them before they are passed to the loss module.
+    normalize_rewards_online_scale: float = 1.0
+    # Final scale of the normalized rewards.
+    normalize_rewards_online_decay: float = 0.9999
+    # Decay of the reward moving averaging
+    sub_traj_len: int = -1
+    # length of the trajectories that sub-samples must have in online settings.
+
+
+def make_trainer(
+    collector: BaseCollector,
+    loss_module: LossModule,
+    recorder: EnvBase | None = None,
+    target_net_updater: TargetNetUpdater | None = None,
+    policy_exploration: None | (TensorDictModuleWrapper | TensorDictModule) = None,
+    replay_buffer: ReplayBuffer | None = None,
+    logger: Logger | None = None,
+    cfg: DictConfig = None,  # noqa: F821
+) -> Trainer:
+    """Creates a Trainer instance given its constituents.
+
+    Args:
+        collector (BaseCollector): A data collector to be used to collect data.
+        loss_module (LossModule): A TorchRL loss module
+        recorder (EnvBase, optional): a recorder environment. If None, the trainer will train the policy without
+            testing it.
+        target_net_updater (TargetNetUpdater, optional): A target network update object.
+        policy_exploration (TDModule or TensorDictModuleWrapper, optional): a policy to be used for recording and exploration
+            updates (should be synced with the learnt policy).
+        replay_buffer (ReplayBuffer, optional): a replay buffer to be used to collect data.
+        logger (Logger, optional): a Logger to be used for logging.
+        cfg (DictConfig, optional): a DictConfig containing the arguments of the script. If None, the default
+            arguments are used.
+
+    Returns:
+        A trainer built with the input objects. The optimizer is built by this helper function using the cfg provided.
+
+    Examples:
+        >>> import torch
+        >>> import tempfile
+        >>> from torchrl.trainers.loggers import TensorboardLogger
+        >>> from torchrl.trainers import Trainer
+        >>> from torchrl.envs import EnvCreator
+        >>> from torchrl.collectors import Collector
+        >>> from torchrl.data import TensorDictReplayBuffer
+        >>> from torchrl.envs.libs.gym import GymEnv
+        >>> from torchrl.modules import TensorDictModuleWrapper, SafeModule, ValueOperator, EGreedyWrapper
+        >>> from torchrl.objectives.common import LossModule
+        >>> from torchrl.objectives.utils import TargetNetUpdater
+        >>> from torchrl.objectives import DDPGLoss
+        >>> env_maker = EnvCreator(lambda: GymEnv("Pendulum-v0"))
+        >>> env_proof = env_maker()
+        >>> obs_spec = env_proof.observation_spec
+        >>> action_spec = env_proof.action_spec
+        >>> net = torch.nn.Linear(env_proof.observation_spec.shape[-1], action_spec.shape[-1])
+        >>> net_value = torch.nn.Linear(env_proof.observation_spec.shape[-1], 1)  # for the purpose of testing
+        >>> policy = SafeModule(action_spec, net, in_keys=["observation"], out_keys=["action"])
+        >>> value = ValueOperator(net_value, in_keys=["observation"], out_keys=["state_action_value"])
+        >>> collector = Collector(env_maker, policy, total_frames=100)
+        >>> loss_module = DDPGLoss(policy, value, gamma=0.99)
+        >>> recorder = env_proof
+        >>> target_net_updater = None
+        >>> policy_exploration = EGreedyWrapper(policy)
+        >>> replay_buffer = TensorDictReplayBuffer()
+        >>> dir = tempfile.gettempdir()
+        >>> logger = TensorboardLogger(exp_name=dir)
+        >>> trainer = make_trainer(collector, loss_module, recorder, target_net_updater, policy_exploration,
+        ...    replay_buffer, logger)
+        >>> print(trainer)
+
+    """
+    if cfg is None:
+        warn(
+            "Getting default cfg for the trainer. "
+            "This should be only used for debugging."
+        )
+        cfg = TrainerConfig()
+        cfg.frame_skip = 1
+        cfg.total_frames = 1000
+        cfg.record_frames = 10
+        cfg.record_interval = 10
+
+    optimizer_kwargs = {} if cfg.optimizer != "adam" else {"betas": (0.0, 0.9)}
+    optimizer = OPTIMIZERS[cfg.optimizer](
+        loss_module.parameters(),
+        lr=cfg.lr,
+        weight_decay=cfg.weight_decay,
+        **optimizer_kwargs,
+    )
+    device = next(loss_module.parameters()).device
+    if cfg.lr_scheduler == "cosine":
+        optim_scheduler = CosineAnnealingLR(
+            optimizer,
+            T_max=int(
+                cfg.total_frames / cfg.frames_per_batch * cfg.optim_steps_per_batch
+            ),
+        )
+    elif cfg.lr_scheduler == "":
+        optim_scheduler = None
+    else:
+        raise NotImplementedError(f"lr scheduler {cfg.lr_scheduler}")
+
+    if VERBOSE:
+        torchrl_logger.info(
+            f"collector = {collector}; \n"
+            f"loss_module = {loss_module}; \n"
+            f"recorder = {recorder}; \n"
+            f"target_net_updater = {target_net_updater}; \n"
+            f"policy_exploration = {policy_exploration}; \n"
+            f"replay_buffer = {replay_buffer}; \n"
+            f"logger = {logger}; \n"
+            f"cfg = {cfg}; \n"
+        )
+
+    if logger is not None:
+        # log hyperparams
+        logger.log_hparams(cfg)
+
+    trainer = Trainer(
+        collector=collector,
+        frame_skip=cfg.frame_skip,
+        total_frames=cfg.total_frames * cfg.frame_skip,
+        loss_module=loss_module,
+        optimizer=optimizer,
+        logger=logger,
+        optim_steps_per_batch=cfg.optim_steps_per_batch,
+        clip_grad_norm=cfg.clip_grad_norm,
+        clip_norm=cfg.clip_norm,
+    )
+
+    if torch.cuda.device_count() > 0:
+        trainer.register_op("pre_optim_steps", ClearCudaCache(1))
+
+    if hasattr(cfg, "noisy") and cfg.noisy:
+        trainer.register_op("pre_optim_steps", lambda: loss_module.apply(reset_noise))
+
+    if cfg.selected_keys:
+        trainer.register_op("batch_process", SelectKeys(cfg.selected_keys))
+    trainer.register_op("batch_process", lambda batch: batch.cpu())
+
+    if replay_buffer is not None:
+        # replay buffer is used 2 or 3 times: to register data, to sample
+        # data and to update priorities
+        rb_trainer = ReplayBufferTrainer(
+            replay_buffer,
+            cfg.batch_size,
+            flatten_tensordicts=False,
+            memmap=False,
+            device=device,
+        )
+
+        trainer.register_op("batch_process", rb_trainer.extend)
+        trainer.register_op("process_optim_batch", rb_trainer.sample)
+        trainer.register_op("post_loss", rb_trainer.update_priority)
+    else:
+        # trainer.register_op("batch_process", mask_batch)
+        trainer.register_op(
+            "process_optim_batch",
+            BatchSubSampler(batch_size=cfg.batch_size, sub_traj_len=cfg.sub_traj_len),
+        )
+        trainer.register_op("process_optim_batch", lambda batch: batch.to(device))
+
+    if optim_scheduler is not None:
+        trainer.register_op("post_optim", optim_scheduler.step)
+
+    if target_net_updater is not None:
+        trainer.register_op("post_optim", target_net_updater.step)
+
+    if cfg.normalize_rewards_online:
+        # if used the running statistics of the rewards are computed and the
+        # rewards used for training will be normalized based on these.
+        reward_normalizer = RewardNormalizer(
+            scale=cfg.normalize_rewards_online_scale,
+            decay=cfg.normalize_rewards_online_decay,
+        )
+        trainer.register_op("batch_process", reward_normalizer.update_reward_stats)
+        trainer.register_op("process_optim_batch", reward_normalizer.normalize_reward)
+
+    if policy_exploration is not None and hasattr(policy_exploration, "step"):
+        trainer.register_op(
+            "post_steps", policy_exploration.step, frames=cfg.frames_per_batch
+        )
+
+    trainer.register_op(
+        "post_steps_log", lambda *cfg: {"lr": optimizer.param_groups[0]["lr"]}
+    )
+
+    if recorder is not None:
+        # create recorder object
+        recorder_obj = LogValidationReward(
+            record_frames=cfg.record_frames,
+            frame_skip=cfg.frame_skip,
+            policy_exploration=policy_exploration,
+            environment=recorder,
+            record_interval=cfg.record_interval,
+            log_keys=cfg.recorder_log_keys,
+        )
+        # register recorder
+        trainer.register_op(
+            "post_steps_log",
+            recorder_obj,
+        )
+        # call recorder - could be removed
+        recorder_obj(None)
+        # create explorative recorder - could be optional
+        recorder_obj_explore = LogValidationReward(
+            record_frames=cfg.record_frames,
+            frame_skip=cfg.frame_skip,
+            policy_exploration=policy_exploration,
+            environment=recorder,
+            record_interval=cfg.record_interval,
+            exploration_type=ExplorationType.RANDOM,
+            suffix="exploration",
+            out_keys={("next", "reward"): "r_evaluation_exploration"},
+        )
+        # register recorder
+        trainer.register_op(
+            "post_steps_log",
+            recorder_obj_explore,
+        )
+        # call recorder - could be removed
+        recorder_obj_explore(None)
+
+    trainer.register_op(
+        "post_steps", UpdateWeights(collector, update_weights_interval=1)
+    )
+
+    trainer.register_op("pre_steps_log", LogScalar())
+    trainer.register_op("pre_steps_log", CountFramesLog(frame_skip=cfg.frame_skip))
+
+    return trainer
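The 301-line hunk above corresponds to torchrl/trainers/helpers/trainers.py in the file listing. One thing worth flagging: make_trainer also reads cfg.frame_skip, cfg.total_frames, cfg.frames_per_batch, cfg.record_frames, cfg.record_interval and cfg.recorder_log_keys, none of which are declared on TrainerConfig, so the cfg it receives is normally a merged config that also carries those fields (the debugging fallback above patches a few of them onto a TrainerConfig by hand). A minimal sketch of such a merged config, using an illustrative dataclass rather than the real script configs:

from __future__ import annotations

from dataclasses import dataclass, field

from torchrl.trainers.helpers.trainers import TrainerConfig, make_trainer


@dataclass
class _FullTrainerCfg(TrainerConfig):
    # Fields consumed by make_trainer but defined outside TrainerConfig;
    # the values here are illustrative, not package defaults.
    frame_skip: int = 1
    total_frames: int = 1_000_000
    frames_per_batch: int = 1_000
    record_frames: int = 1_000
    record_interval: int = 10
    recorder_log_keys: list = field(default_factory=lambda: ["reward"])


cfg = _FullTrainerCfg()
# trainer = make_trainer(collector, loss_module, recorder, target_net_updater,
#                        policy_exploration, replay_buffer, logger, cfg=cfg)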