torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import functools
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from tensordict.nn import TensorDictModule, TensorDictSequential
|
|
11
|
+
|
|
12
|
+
from torch import nn, optim
|
|
13
|
+
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
|
|
14
|
+
from torchrl.data.replay_buffers import SamplerWithoutReplacement
|
|
15
|
+
from torchrl.envs import (
|
|
16
|
+
CatTensors,
|
|
17
|
+
Compose,
|
|
18
|
+
DMControlEnv,
|
|
19
|
+
DoubleToFloat,
|
|
20
|
+
EnvCreator,
|
|
21
|
+
InitTracker,
|
|
22
|
+
ParallelEnv,
|
|
23
|
+
RewardSum,
|
|
24
|
+
StepCounter,
|
|
25
|
+
TransformedEnv,
|
|
26
|
+
)
|
|
27
|
+
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
|
|
28
|
+
from torchrl.envs.utils import ExplorationType, set_exploration_type
|
|
29
|
+
from torchrl.modules import AdditiveGaussianModule, MLP, TanhModule, ValueOperator
|
|
30
|
+
|
|
31
|
+
from torchrl.objectives import SoftUpdate
|
|
32
|
+
from torchrl.objectives.td3_bc import TD3BCLoss
|
|
33
|
+
from torchrl.record import VideoRecorder
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# ====================================================================
|
|
37
|
+
# Environment utils
|
|
38
|
+
# -----------------
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def env_maker(cfg, device="cpu", from_pixels=False):
    """Build a single (untransformed) environment from the config.

    Args:
        cfg: config object; reads ``cfg.env.library``, ``cfg.env.name`` and,
            for dm_control, ``cfg.env.task``.
        device: device passed to the environment constructor (for both the
            gym and the dm_control backends).
        from_pixels: whether the environment should also render frames.
            Frames are returned alongside the state (``pixels_only=False``).

    Returns:
        A ``GymEnv`` for the ``"gym"``/``"gymnasium"`` libraries, or a
        ``TransformedEnv`` wrapping a ``DMControlEnv`` whose state entries are
        concatenated into a single ``"observation"`` entry.

    Raises:
        NotImplementedError: if ``cfg.env.library`` is not supported.
    """
    lib = cfg.env.library
    if lib in ("gym", "gymnasium"):
        with set_gym_backend(lib):
            return GymEnv(
                cfg.env.name,
                device=device,
                from_pixels=from_pixels,
                pixels_only=False,
            )
    elif lib == "dm_control":
        # Forward ``device`` so the dm_control branch honours it like the gym one.
        env = DMControlEnv(
            cfg.env.name,
            cfg.env.task,
            from_pixels=from_pixels,
            pixels_only=False,
            device=device,
        )
        # Only concatenate the state entries: the rendered frame ("pixels",
        # present when from_pixels=True) has an image shape that cannot be
        # concatenated with the flat state vectors.
        obs_keys = [key for key in env.observation_spec.keys() if key != "pixels"]
        return TransformedEnv(
            env, CatTensors(in_keys=obs_keys, out_key="observation")
        )
    else:
        raise NotImplementedError(f"Unknown lib {lib}.")
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def apply_env_transforms(env, max_episode_steps):
    """Wrap ``env`` in the transform stack used by this script.

    The composed transforms are, in order: ``StepCounter`` (with
    ``max_steps=max_episode_steps``), ``InitTracker``, ``DoubleToFloat`` and
    ``RewardSum``.

    Returns:
        The resulting ``TransformedEnv``.
    """
    transform_stack = Compose(
        StepCounter(max_steps=max_episode_steps),
        InitTracker(),
        DoubleToFloat(),
        RewardSum(),
    )
    return TransformedEnv(env, transform_stack)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def make_environment(cfg, logger=None):
    """Make the batched, transformed environment.

    ``cfg.logger.eval_envs`` workers, each built by ``env_maker(cfg=cfg)``, are
    run in a ``ParallelEnv`` (serial when a single worker is requested). The
    batch is seeded with ``cfg.env.seed`` and wrapped by
    ``apply_env_transforms`` using ``cfg.env.max_episode_steps``.

    Args:
        cfg: the experiment config.
        logger: accepted for interface compatibility; unused here.

    Returns:
        A single transformed environment.
    """
    env_factory = EnvCreator(functools.partial(env_maker, cfg=cfg))
    batched_env = ParallelEnv(
        cfg.logger.eval_envs, env_factory, serial_for_single=True
    )
    batched_env.set_seed(cfg.env.seed)
    return apply_env_transforms(batched_env, cfg.env.max_episode_steps)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# ====================================================================
|
|
90
|
+
# Replay buffer
|
|
91
|
+
# ---------------------------
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def make_offline_replay_buffer(rb_cfg, device):
    """Load the D4RL dataset ``rb_cfg.dataset`` as a replay buffer.

    Batches of ``rb_cfg.batch_size`` samples are drawn without replacement
    (an incomplete trailing batch is dropped). Each sampled batch is cast from
    double to float and then moved to ``device``.
    """
    # drop_last so that every batch has the same size (for compile)
    sampler = SamplerWithoutReplacement(drop_last=True)
    replay_buffer = D4RLExperienceReplay(
        dataset_id=rb_cfg.dataset,
        split_trajs=False,
        batch_size=rb_cfg.batch_size,
        sampler=sampler,
        prefetch=4,
        direct_download=True,
    )

    def _to_device(td):
        return td.to(device)

    for transform in (DoubleToFloat(), _to_device):
        replay_buffer.append_transform(transform)
    return replay_buffer
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
# ====================================================================
|
|
112
|
+
# Model
|
|
113
|
+
# -----
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def make_td3_agent(cfg, train_env, device):
    """Make the TD3 actor and Q-value networks.

    Args:
        cfg: config; reads ``cfg.network.hidden_sizes`` and
            ``cfg.network.activation``.
        train_env: environment whose ``action_spec_unbatched`` sizes the actor
            output and whose ``fake_tensordict()`` is used to initialise the
            networks.
        device: device on which the networks are created.

    Returns:
        ``(model, actor_model_explore)``, where ``model`` is
        ``nn.ModuleList([actor, qvalue])`` and ``actor_model_explore`` is the
        actor followed by an ``AdditiveGaussianModule``.
    """
    observation_keys = ["observation"]
    action_spec = train_env.action_spec_unbatched.to(device)
    activation = get_activation(cfg)

    # Actor: observation -> "param" (MLP) -> "action" (TanhModule bound to
    # action_spec).
    actor = TensorDictSequential(
        TensorDictModule(
            MLP(
                num_cells=cfg.network.hidden_sizes,
                out_features=action_spec.shape[-1],
                activation_class=activation,
                device=device,
            ),
            in_keys=observation_keys,
            out_keys=["param"],
        ),
        TanhModule(
            in_keys=["param"],
            out_keys=["action"],
            spec=action_spec,
        ),
    )

    # Critic: reads "action" and "observation", outputs a single value.
    qvalue = ValueOperator(
        in_keys=["action"] + observation_keys,
        module=MLP(
            num_cells=cfg.network.hidden_sizes,
            out_features=1,
            activation_class=activation,
            device=device,
        ),
    )

    model = nn.ModuleList([actor, qvalue])

    # Push a fake batch through every network. The MLPs above are built
    # without ``in_features``, so this presumably materialises their lazy
    # layers.
    with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
        fake_td = train_env.fake_tensordict().to(device)
        for net in model:
            net(fake_td)
        del fake_td

    # Exploration policy: actor + additive Gaussian noise.
    actor_model_explore = TensorDictSequential(
        model[0],
        AdditiveGaussianModule(
            sigma_init=1,
            sigma_end=1,
            mean=0,
            std=0.1,
            spec=action_spec,
            device=device,
        ),
    )
    return model, actor_model_explore
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
# ====================================================================
|
|
183
|
+
# TD3 Loss
|
|
184
|
+
# ---------
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def make_loss_module(cfg, model):
    """Make the TD3+BC loss module and its soft target-network updater.

    Args:
        cfg: config; reads ``loss_function``, ``policy_noise``, ``noise_clip``,
            ``alpha``, ``gamma`` and ``target_update_polyak``.
        model: sequence whose item 0 is the actor and item 1 the Q-value net.
            Assumes ``model[0][1]`` exposes ``spec`` (the TanhModule built by
            ``make_td3_agent``) — confirm if another model is passed.

    Returns:
        ``(loss_module, target_net_updater)``.
    """
    actor_network, qvalue_network = model[0], model[1]
    loss_module = TD3BCLoss(
        actor_network=actor_network,
        qvalue_network=qvalue_network,
        num_qvalue_nets=2,
        loss_function=cfg.loss_function,
        delay_actor=True,
        delay_qvalue=True,
        action_spec=actor_network[1].spec,
        policy_noise=cfg.policy_noise,
        noise_clip=cfg.noise_clip,
        alpha=cfg.alpha,
    )
    loss_module.make_value_estimator(gamma=cfg.gamma)

    target_net_updater = SoftUpdate(loss_module, eps=cfg.target_update_polyak)
    return loss_module, target_net_updater
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def make_optimizer(cfg, loss_module):
    """Create one Adam optimizer for the actor and one for the critic.

    Both optimizers use ``cfg.lr``, ``cfg.weight_decay`` and ``cfg.adam_eps``.
    Parameters are the nested leaves of ``loss_module.actor_network_params``
    and ``loss_module.qvalue_network_params`` respectively.

    Returns:
        ``(optimizer_actor, optimizer_critic)``.
    """
    critic_params = list(loss_module.qvalue_network_params.values(True, True))
    actor_params = list(loss_module.actor_network_params.values(True, True))

    adam_kwargs = dict(lr=cfg.lr, weight_decay=cfg.weight_decay, eps=cfg.adam_eps)
    optimizer_actor = optim.Adam(actor_params, **adam_kwargs)
    optimizer_critic = optim.Adam(critic_params, **adam_kwargs)
    return optimizer_actor, optimizer_critic
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
# ====================================================================
|
|
229
|
+
# General utils
|
|
230
|
+
# ---------
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def log_metrics(logger, metrics, step):
    """Log each ``name -> value`` entry of ``metrics`` as a scalar at ``step``."""
    for name, value in metrics.items():
        logger.log_scalar(name, value, step)
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def get_activation(cfg):
    """Return the activation class named by ``cfg.network.activation``.

    Supported names are ``"relu"``, ``"tanh"`` and ``"leaky_relu"``.

    Raises:
        NotImplementedError: if the configured name is not supported. The
            message names the offending value and the supported choices.
    """
    activations = {
        "relu": nn.ReLU,
        "tanh": nn.Tanh,
        "leaky_relu": nn.LeakyReLU,
    }
    name = cfg.network.activation
    try:
        return activations[name]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which previously also ended up
        # in the NotImplementedError branch.
        raise NotImplementedError(
            f"Unsupported activation {name!r}; expected one of {sorted(activations)}."
        ) from None
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def dump_video(module):
    """Call ``module.dump()`` when ``module`` is a ``VideoRecorder``; otherwise do nothing."""
    if not isinstance(module, VideoRecorder):
        return
    module.dump()
|
torchrl/__init__.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
import warnings
|
|
6
|
+
import weakref
|
|
7
|
+
from warnings import warn
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
# Silence noisy dependency warning triggered at import time on older torch stacks.
|
|
12
|
+
# (Emitted by tensordict when registering pytree nodes.)
|
|
13
|
+
warnings.filterwarnings(
    "ignore",
    category=UserWarning,
    message=r"torch\.utils\._pytree\._register_pytree_node is deprecated\.",
)

# The imports below deliberately come after the filter above (hence the E402
# suppressions) so the filter is already installed while they execute.
from tensordict import set_lazy_legacy  # noqa: E402

from torch import multiprocessing as mp  # noqa: E402
from torch.distributions.transforms import (  # noqa: E402
    _InverseTransform,
    ComposeTransform,
)

# PyTorch-internal API-usage hook, tagged with this library's name.
torch._C._log_api_usage_once("torchrl")

# Turn tensordict's lazy-legacy mode off for the whole process.
set_lazy_legacy(False).set()

from ._extension import _init_extension  # noqa: E402

# Resolve the package version: installed distribution metadata first, then a
# ``._version`` module, then a ``.version`` module; otherwise None.
__version__ = None  # type: ignore
try:
    try:
        from importlib.metadata import version as _dist_version
    except ImportError:  # pragma: no cover
        # Fall back to the ``importlib_metadata`` backport.
        from importlib_metadata import version as _dist_version  # type: ignore

    __version__ = _dist_version("torchrl")
except Exception:
    try:
        from ._version import __version__
    except Exception:
        try:
            from .version import __version__
        except Exception:
            __version__ = None  # type: ignore

# Prefer ``torch.compiler.is_dynamo_compiling``; fall back to
# ``torch._dynamo.is_compiling`` when the former cannot be imported.
try:
    from torch.compiler import is_dynamo_compiling
except ImportError:
    from torch._dynamo import is_compiling as is_dynamo_compiling

# Emits a warning (does not raise) if the C++ extension cannot be found.
_init_extension()

from torchrl._utils import (  # noqa: E402
    _get_default_mp_start_method,
    auto_unwrap_transformed_env,
    compile_with_warmup,
    get_ray_default_runtime_env,
    implement_for,
    logger,
    merge_ray_runtime_env,
    set_auto_unwrap_transformed_env,
    set_profiling_enabled,
    timeit,
)

# No-op self-assignment; presumably marks ``logger`` as an intentional
# re-export for linters.
logger = logger

# TorchRL's multiprocessing default.
_preferred_start_method = _get_default_mp_start_method()
if _preferred_start_method == "spawn":
    try:
        mp.set_start_method("spawn")
    except RuntimeError as err:
        # The start method was already set elsewhere: only warn when it is not
        # "spawn". Any other RuntimeError is ignored here (best effort).
        if str(err).startswith("context has already been set"):
            mp_start_method = mp.get_start_method()
            if mp_start_method != "spawn":
                warn(
                    f"failed to set start method to spawn, "
                    f"and current start method for mp is {mp_start_method}."
                )

# Filter warnings in subprocesses: True by default given the multiple optional
# deps of the library. This can be turned off via `torchrl.filter_warnings_subprocess = False`.
filter_warnings_subprocess = True

# Number of torch threads recorded at import time.
_THREAD_POOL_INIT = torch.get_num_threads()
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
# monkey-patch dist transforms until https://github.com/pytorch/pytorch/pull/135001/ finds a home
@property
def _inv(self):
    """Patched version of Transform.inv.

    Returns the inverse :class:`Transform` of this transform. An
    :class:`_InverseTransform` is created on first access, or again once the
    previously cached one has been garbage-collected.

    This should satisfy ``t.inv.inv is t``. Outside of dynamo compilation the
    new inverse is cached through a weak reference; while compiling, the cache
    is left untouched.
    """
    cached = self._inv() if self._inv is not None else None
    if cached is not None:
        return cached
    inverse = _InverseTransform(self)
    if not is_dynamo_compiling():
        self._inv = weakref.ref(inverse)
    return inverse


torch.distributions.transforms.Transform.inv = _inv
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@property
def _inv(self):
    """Patched version of ComposeTransform.inv.

    Returns the composition of the inverses of ``self.parts`` in reverse
    order, such that ``c.inv.inv is c``.

    Outside of dynamo compilation, the result is cached, and the two
    transforms point at each other through weak references. While compiling,
    nothing is cached. The inverse instead gets a closure returning ``self``,
    because weakrefs can cause a graph break.
    """
    existing = None if self._inv is None else self._inv()
    if existing is not None:
        return existing
    inverse = ComposeTransform([part.inv for part in reversed(self.parts)])
    if is_dynamo_compiling():
        # We need inv.inv to be equal to self, but weakref can cause a graph break
        inverse._inv = lambda out=self: out
    else:
        self._inv = weakref.ref(inverse)
        inverse._inv = weakref.ref(self)
    return inverse


ComposeTransform.inv = _inv
|
|
133
|
+
|
|
134
|
+
# Public names exported by ``from torchrl import *`` (each listed once).
__all__ = [
    "auto_unwrap_transformed_env",
    "compile_with_warmup",
    "get_ray_default_runtime_env",
    "implement_for",
    "merge_ray_runtime_env",
    "set_auto_unwrap_transformed_env",
    "timeit",
    "logger",
]
|
torchrl/_extension.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import importlib.util
|
|
8
|
+
import warnings
|
|
9
|
+
|
|
10
|
+
from packaging.version import parse
|
|
11
|
+
|
|
12
|
+
# Resolve the installed torchrl distribution version from package metadata.
# This is best-effort: any exception raised while importing the metadata API or
# looking up the "torchrl" distribution leaves ``__version__`` set to None.
__version__ = None  # type: ignore
try:
    try:
        from importlib.metadata import version as _dist_version
    except ImportError:  # pragma: no cover
        # Fallback for interpreters whose stdlib lacks ``importlib.metadata``.
        from importlib_metadata import version as _dist_version  # type: ignore

    __version__ = _dist_version("torchrl")
except Exception:
    # Deliberately broad so that version-lookup failures never propagate.
    __version__ = None  # type: ignore
|
|
22
|
+
|
|
23
|
+
# PyTorch version this package was built with, as reported in EXTENSION_WARNING.
# It is read from the sibling ``version`` module (presumably generated at build
# time — TODO confirm); "unknown" is used when that module cannot be imported.
try:
    from .version import pytorch_version
except ImportError:
    pytorch_version = "unknown"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def is_module_available(*modules: str) -> bool:
    """Return whether every module in ``modules`` can be located without importing it.

    This is generally safer than a try/except block around ``import X``. It avoids
    third party libraries breaking assumptions of some of our tests, e.g., setting
    multiprocessing start method when imported (see librosa/#747, torchvision/#544).

    Note that for dotted names (``"pkg.sub"``), :func:`importlib.util.find_spec`
    imports the parent package ``pkg`` in order to locate the submodule; only the
    leaf module itself is left unimported.

    Args:
        *modules: fully qualified module names to look up.

    Returns:
        ``True`` if every name resolves to a module spec (vacuously ``True`` when no
        names are given), ``False`` as soon as one does not. A dotted name whose
        parent package does not exist yields ``False`` instead of raising.
    """
    for name in modules:
        try:
            if importlib.util.find_spec(name) is None:
                return False
        except ModuleNotFoundError:
            # find_spec imports the parent of a dotted name and raises when that
            # parent is missing; the submodule is then unavailable as well.
            return False
    return True
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _init_extension():
    """Warn when the compiled ``torchrl._torchrl`` extension module cannot be found.

    The check goes through :func:`is_module_available`; the function always
    returns ``None`` and only has the side effect of possibly emitting a warning.
    """
    if is_module_available("torchrl._torchrl"):
        return
    warnings.warn("torchrl C++ extension is not available.")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _is_nightly(version):
    """Heuristically decide whether ``version`` denotes a nightly/local build.

    Args:
        version: a version string understood by ``packaging.version.parse``, or
            ``None`` when the installed version could not be determined.

    Returns:
        ``True`` when ``version`` is ``None`` or carries a local version label
        (the ``+...`` segment), ``False`` otherwise.
    """
    # An unknown version is treated the same way as a nightly build.
    return version is None or parse(version).local is not None
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# Warning text about failing to import the torchrl C++ binaries. Builds flagged
# by ``_is_nightly`` get local-build debugging tips; other builds get PyTorch
# compatibility advice. Both variants embed ``pytorch_version``.
EXTENSION_WARNING = (
    (
        "Failed to import torchrl C++ binaries. Some modules (eg, prioritized replay buffers) may not work with your installation. "
        "You seem to be using the nightly version of TorchRL. If this is a local install, there might be an issue with "
        "the local installation. Here are some tips to debug this:\n"
        " - make sure ninja and cmake were installed\n"
        " - make sure you ran `python setup.py clean && python setup.py develop` and that no error was raised\n"
        " - make sure the version of PyTorch you are using matches the one that was present in your virtual env during "
        f"setup. This package was built with PyTorch {pytorch_version}. You can deactivate this warning by setting the environment variable `RL_WARNINGS=0`."
    )
    if _is_nightly(__version__)
    else (
        "Failed to import torchrl C++ binaries. Some modules (eg, prioritized replay buffers) may not work with your installation. "
        "This is likely due to a discrepancy between your package version and the PyTorch version. "
        "TorchRL does not tightly pin PyTorch versions to give users freedom, but the trade-off is that C++ extensions like "
        "prioritized replay buffers can only be used with the PyTorch version they were built against. "
        f"This package was built with PyTorch {pytorch_version}. "
        "Workarounds include: (1) upgrading/downgrading PyTorch or TorchRL to compatible versions, "
        "or (2) making a local install using `pip install git+https://github.com/pytorch/rl.git@<version>`. "
        "You can deactivate this warning by setting the environment variable `RL_WARNINGS=0`."
    )
)
|
|
Binary file
|