torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,1308 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: torchrl
|
|
3
|
+
Version: 0.11.0
|
|
4
|
+
Summary: A modular, primitive-first, python-first PyTorch library for Reinforcement Learning
|
|
5
|
+
Author-email: torchrl contributors <vmoens@fb.com>
|
|
6
|
+
Maintainer-email: torchrl contributors <vmoens@fb.com>
|
|
7
|
+
Project-URL: Homepage, https://github.com/pytorch/rl
|
|
8
|
+
Project-URL: Documentation, https://pytorch.org/rl
|
|
9
|
+
Project-URL: Repository, https://github.com/pytorch/rl
|
|
10
|
+
Project-URL: Bug Tracker, https://github.com/pytorch/rl/issues
|
|
11
|
+
Project-URL: twitter, https://x.com/torchrl1
|
|
12
|
+
Project-URL: linkedin, https://www.linkedin.com/company/torchrl
|
|
13
|
+
Project-URL: discord, https://discord.gg/cZs26Qq3Dd
|
|
14
|
+
Project-URL: benchmark, https://docs.pytorch.org/rl/dev/bench/
|
|
15
|
+
Keywords: reinforcement-learning,pytorch,rl,machine-learning
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
21
|
+
Classifier: Operating System :: OS Independent
|
|
22
|
+
Classifier: Development Status :: 4 - Beta
|
|
23
|
+
Classifier: Intended Audience :: Developers
|
|
24
|
+
Classifier: Intended Audience :: Science/Research
|
|
25
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
26
|
+
Requires-Python: >=3.10
|
|
27
|
+
Description-Content-Type: text/markdown
|
|
28
|
+
License-File: LICENSE
|
|
29
|
+
Requires-Dist: torch>=2.1.0
|
|
30
|
+
Requires-Dist: pyvers
|
|
31
|
+
Requires-Dist: numpy
|
|
32
|
+
Requires-Dist: packaging
|
|
33
|
+
Requires-Dist: cloudpickle
|
|
34
|
+
Requires-Dist: tensordict<0.12.0,>=0.11.0
|
|
35
|
+
Provides-Extra: atari
|
|
36
|
+
Requires-Dist: gymnasium[atari]; extra == "atari"
|
|
37
|
+
Provides-Extra: dm-control
|
|
38
|
+
Requires-Dist: dm_control; extra == "dm-control"
|
|
39
|
+
Provides-Extra: replay-buffer
|
|
40
|
+
Requires-Dist: torch>=2.7.0; extra == "replay-buffer"
|
|
41
|
+
Provides-Extra: gym-continuous
|
|
42
|
+
Requires-Dist: gymnasium<1.0; extra == "gym-continuous"
|
|
43
|
+
Requires-Dist: mujoco; extra == "gym-continuous"
|
|
44
|
+
Provides-Extra: rendering
|
|
45
|
+
Requires-Dist: moviepy<2.0.0; extra == "rendering"
|
|
46
|
+
Provides-Extra: tests
|
|
47
|
+
Requires-Dist: pytest; extra == "tests"
|
|
48
|
+
Requires-Dist: pyyaml; extra == "tests"
|
|
49
|
+
Requires-Dist: pytest-instafail; extra == "tests"
|
|
50
|
+
Requires-Dist: scipy; extra == "tests"
|
|
51
|
+
Requires-Dist: psutil; extra == "tests"
|
|
52
|
+
Requires-Dist: pytest-mock; extra == "tests"
|
|
53
|
+
Requires-Dist: pytest-cov; extra == "tests"
|
|
54
|
+
Requires-Dist: pytest-asyncio; extra == "tests"
|
|
55
|
+
Requires-Dist: pytest-benchmark; extra == "tests"
|
|
56
|
+
Requires-Dist: pytest-rerunfailures; extra == "tests"
|
|
57
|
+
Requires-Dist: pytest-error-for-skips; extra == "tests"
|
|
58
|
+
Requires-Dist: pytest-timeout; extra == "tests"
|
|
59
|
+
Requires-Dist: pytest-forked; extra == "tests"
|
|
60
|
+
Requires-Dist: pytest-random-order; extra == "tests"
|
|
61
|
+
Requires-Dist: pytest-repeat; extra == "tests"
|
|
62
|
+
Requires-Dist: pytest-isolate; extra == "tests"
|
|
63
|
+
Provides-Extra: utils
|
|
64
|
+
Requires-Dist: tensorboard; extra == "utils"
|
|
65
|
+
Requires-Dist: wandb; extra == "utils"
|
|
66
|
+
Requires-Dist: tqdm; extra == "utils"
|
|
67
|
+
Requires-Dist: hydra-core>=1.1; extra == "utils"
|
|
68
|
+
Requires-Dist: hydra-submitit-launcher; extra == "utils"
|
|
69
|
+
Provides-Extra: checkpointing
|
|
70
|
+
Requires-Dist: torchsnapshot; extra == "checkpointing"
|
|
71
|
+
Provides-Extra: offline-data
|
|
72
|
+
Requires-Dist: huggingface_hub; extra == "offline-data"
|
|
73
|
+
Requires-Dist: minari; extra == "offline-data"
|
|
74
|
+
Requires-Dist: requests; extra == "offline-data"
|
|
75
|
+
Requires-Dist: tqdm; extra == "offline-data"
|
|
76
|
+
Requires-Dist: torchvision; extra == "offline-data"
|
|
77
|
+
Requires-Dist: scikit-learn; extra == "offline-data"
|
|
78
|
+
Requires-Dist: pandas; extra == "offline-data"
|
|
79
|
+
Requires-Dist: h5py; extra == "offline-data"
|
|
80
|
+
Requires-Dist: pillow; extra == "offline-data"
|
|
81
|
+
Provides-Extra: marl
|
|
82
|
+
Requires-Dist: vmas>=1.2.10; extra == "marl"
|
|
83
|
+
Requires-Dist: pettingzoo>=1.24.1; extra == "marl"
|
|
84
|
+
Requires-Dist: dm-meltingpot; python_version >= "3.11" and extra == "marl"
|
|
85
|
+
Provides-Extra: open-spiel
|
|
86
|
+
Requires-Dist: open_spiel>=1.5; extra == "open-spiel"
|
|
87
|
+
Provides-Extra: brax
|
|
88
|
+
Requires-Dist: jax>=0.7.0; python_version >= "3.11" and extra == "brax"
|
|
89
|
+
Requires-Dist: brax; python_version >= "3.11" and extra == "brax"
|
|
90
|
+
Provides-Extra: procgen
|
|
91
|
+
Requires-Dist: procgen; extra == "procgen"
|
|
92
|
+
Provides-Extra: llm
|
|
93
|
+
Requires-Dist: transformers; extra == "llm"
|
|
94
|
+
Requires-Dist: vllm; extra == "llm"
|
|
95
|
+
Requires-Dist: playwright; extra == "llm"
|
|
96
|
+
Requires-Dist: datasets; extra == "llm"
|
|
97
|
+
Requires-Dist: langdetect; extra == "llm"
|
|
98
|
+
Requires-Dist: nltk; extra == "llm"
|
|
99
|
+
Requires-Dist: immutabledict; extra == "llm"
|
|
100
|
+
Requires-Dist: accelerate; extra == "llm"
|
|
101
|
+
Requires-Dist: sentencepiece; extra == "llm"
|
|
102
|
+
Requires-Dist: protobuf; extra == "llm"
|
|
103
|
+
Requires-Dist: einops; extra == "llm"
|
|
104
|
+
Requires-Dist: safetensors; extra == "llm"
|
|
105
|
+
Provides-Extra: grpo
|
|
106
|
+
Requires-Dist: datasets; extra == "grpo"
|
|
107
|
+
Requires-Dist: peft; extra == "grpo"
|
|
108
|
+
Requires-Dist: wandb; extra == "grpo"
|
|
109
|
+
Requires-Dist: vllm; extra == "grpo"
|
|
110
|
+
Requires-Dist: transformers; extra == "grpo"
|
|
111
|
+
Requires-Dist: accelerate; extra == "grpo"
|
|
112
|
+
Requires-Dist: ray; extra == "grpo"
|
|
113
|
+
Requires-Dist: tqdm; extra == "grpo"
|
|
114
|
+
Requires-Dist: flash-attn; extra == "grpo"
|
|
115
|
+
Requires-Dist: bitsandbytes; extra == "grpo"
|
|
116
|
+
Requires-Dist: xformers; extra == "grpo"
|
|
117
|
+
Requires-Dist: nltk; extra == "grpo"
|
|
118
|
+
Requires-Dist: langdetect; extra == "grpo"
|
|
119
|
+
Requires-Dist: immutabledict; extra == "grpo"
|
|
120
|
+
Provides-Extra: dev
|
|
121
|
+
Requires-Dist: pre-commit; extra == "dev"
|
|
122
|
+
Requires-Dist: autoflake; extra == "dev"
|
|
123
|
+
Dynamic: license-file
|
|
124
|
+
|
|
125
|
+
[Unit tests](https://github.com/pytorch/rl/actions/workflows/test-linux.yml)
|
|
126
|
+
[Documentation](https://pytorch.org/rl/)
|
|
127
|
+
[Benchmarks](https://pytorch.github.io/rl/dev/bench/)
|
|
128
|
+
[Code coverage](https://codecov.io/gh/pytorch/rl)
|
|
129
|
+
[Follow on Twitter](https://twitter.com/torchrl1)
|
|
130
|
+
[Python version](https://www.python.org/downloads/)
|
|
131
|
+
[License](https://github.com/pytorch/rl/blob/main/LICENSE)
|
|
132
|
+
<a href="https://pypi.org/project/torchrl"><img src="https://img.shields.io/pypi/v/torchrl" alt="pypi version"></a>
|
|
133
|
+
<a href="https://pypi.org/project/torchrl-nightly"><img src="https://img.shields.io/pypi/v/torchrl-nightly?label=nightly" alt="pypi nightly version"></a>
|
|
134
|
+
[Downloads](https://pepy.tech/project/torchrl)
|
|
135
|
+
[Downloads (nightly)](https://pepy.tech/project/torchrl-nightly)
|
|
136
|
+
[Discord](https://discord.gg/cZs26Qq3Dd)
|
|
137
|
+
|
|
138
|
+
# TorchRL
|
|
139
|
+
|
|
140
|
+
<p align="center">
|
|
141
|
+
<img src="docs/source/_static/img/icon.png" width="200" >
|
|
142
|
+
</p>
|
|
143
|
+
|
|
144
|
+
[**What's New**](#-whats-new) | [**LLM API**](#llm-api---complete-framework-for-language-model-fine-tuning) | [**Getting Started**](#getting-started) | [**Documentation**](#documentation-and-knowledge-base) | [**TensorDict**](#writing-simplified-and-portable-rl-codebase-with-tensordict) |
|
|
145
|
+
[**Features**](#features) | [**Examples, tutorials and demos**](#examples-tutorials-and-demos) | [**Citation**](#citation) | [**Installation**](#installation) |
|
|
146
|
+
[**Asking a question**](#asking-a-question) | [**Contributing**](#contributing)
|
|
147
|
+
|
|
148
|
+
**TorchRL** is an open-source Reinforcement Learning (RL) library for PyTorch.
|
|
149
|
+
|
|
150
|
+
## 🚀 What's New
|
|
151
|
+
|
|
152
|
+
### 🚀 **Command-Line Training Interface** - Train RL Agents Without Writing Code! (Experimental)
|
|
153
|
+
|
|
154
|
+
TorchRL now provides a **powerful command-line interface** that lets you train state-of-the-art RL agents with simple bash commands! No Python scripting required - just run training with customizable parameters:
|
|
155
|
+
|
|
156
|
+
- 🎯 **One-Command Training**: `python sota-implementations/ppo_trainer/train.py`
|
|
157
|
+
- ⚙️ **Full Customization**: Override any parameter via command line: `trainer.total_frames=2000000 optimizer.lr=0.0003`
|
|
158
|
+
- 🌍 **Multi-Environment Support**: Switch between Gym, Brax, DM Control, and more with `env=gym training_env.create_env_fn.base_env.env_name=HalfCheetah-v4`
|
|
159
|
+
- 📊 **Built-in Logging**: TensorBoard, Weights & Biases, CSV logging out of the box
|
|
160
|
+
- 🔧 **Hydra-Powered**: Leverages Hydra's powerful configuration system for maximum flexibility
|
|
161
|
+
- 🏃♂️ **Production Ready**: Same robust training pipeline as our SOTA implementations
|
|
162
|
+
|
|
163
|
+
**Perfect for**: Researchers, practitioners, and anyone who wants to train RL agents without diving into implementation details.
|
|
164
|
+
|
|
165
|
+
⚠️ **Note**: This is an experimental feature. The API may change in future versions. We welcome feedback and contributions to help improve this implementation!
|
|
166
|
+
|
|
167
|
+
📋 **Prerequisites**: The training interface requires Hydra for configuration management. Install with:
|
|
168
|
+
```bash
|
|
169
|
+
pip install "torchrl[utils]"
|
|
170
|
+
# or manually:
|
|
171
|
+
pip install hydra-core omegaconf
|
|
172
|
+
```
|
|
173
|
+
|
|
174
|
+
Check out the [complete CLI documentation](https://github.com/pytorch/rl/tree/main/sota-implementations/ppo_trainer) to get started!
|
|
175
|
+
|
|
176
|
+
### 🚀 **vLLM Revamp** - Major Enhancement to LLM Infrastructure (v0.10)
|
|
177
|
+
|
|
178
|
+
This release introduces a comprehensive revamp of TorchRL's vLLM integration, delivering significant improvements in performance, scalability, and usability for large language model inference and training workflows:
|
|
179
|
+
|
|
180
|
+
- 🔥 **AsyncVLLM Service**: Production-ready distributed vLLM inference with multi-replica scaling and automatic Ray actor management
|
|
181
|
+
- ⚖️ **Multiple Load Balancing Strategies**: Routing strategies including prefix-aware, request-based, and KV-cache load balancing for optimal performance
|
|
182
|
+
- 🏗️ **Unified vLLM Architecture**: New `RLvLLMEngine` interface standardizing all vLLM backends with simplified `vLLMUpdaterV2` for seamless weight updates
|
|
183
|
+
- 🌐 **Distributed Data Loading**: New `RayDataLoadingPrimer` for shared, distributed data loading across multiple environments
|
|
184
|
+
- 📈 **Enhanced Performance**: Native vLLM batching, concurrent request processing, and optimized resource allocation via Ray placement groups
|
|
185
|
+
|
|
186
|
+
```python
|
|
187
|
+
# Simple AsyncVLLM usage - production ready!
|
|
188
|
+
from torchrl.modules.llm import AsyncVLLM, vLLMWrapper
|
|
189
|
+
|
|
190
|
+
# Create distributed vLLM service with load balancing
|
|
191
|
+
service = AsyncVLLM.from_pretrained(
|
|
192
|
+
"Qwen/Qwen2.5-7B",
|
|
193
|
+
num_devices=2, # Tensor parallel across 2 GPUs
|
|
194
|
+
num_replicas=4, # 4 replicas for high throughput
|
|
195
|
+
max_model_len=4096
|
|
196
|
+
)
|
|
197
|
+
|
|
198
|
+
# Use with TorchRL's LLM wrappers
|
|
199
|
+
wrapper = vLLMWrapper(service, input_mode="history")
|
|
200
|
+
|
|
201
|
+
# Simplified weight updates
|
|
202
|
+
from torchrl.collectors.llm import vLLMUpdaterV2
|
|
203
|
+
updater = vLLMUpdaterV2(service) # Auto-configures from engine
|
|
204
|
+
```
|
|
205
|
+
|
|
206
|
+
This revamp positions TorchRL as the leading platform for scalable LLM inference and training, providing production-ready tools for both research and deployment scenarios.
|
|
207
|
+
|
|
208
|
+
### 🧪 PPOTrainer (Experimental) - High-Level Training Interface
|
|
209
|
+
|
|
210
|
+
TorchRL now includes an **experimental PPOTrainer** that provides a complete, configurable PPO training solution! This prototype feature combines TorchRL's modular components into a cohesive training system with sensible defaults:
|
|
211
|
+
|
|
212
|
+
- 🎯 **Complete Training Pipeline**: Handles environment setup, data collection, loss computation, and optimization automatically
|
|
213
|
+
- ⚙️ **Extensive Configuration**: Comprehensive Hydra-based config system for easy experimentation and hyperparameter tuning
|
|
214
|
+
- 📊 **Built-in Logging**: Automatic tracking of rewards, actions, episode completion rates, and training statistics
|
|
215
|
+
- 🔧 **Modular Design**: Built on existing TorchRL components (collectors, losses, replay buffers) for maximum flexibility
|
|
216
|
+
- 📝 **Minimal Code**: Complete SOTA implementation in [just ~20 lines](sota-implementations/ppo_trainer/train.py)!
|
|
217
|
+
|
|
218
|
+
**Working Example**: See [`sota-implementations/ppo_trainer/`](sota-implementations/ppo_trainer/) for a complete, working PPO implementation that trains on Pendulum-v1 with full Hydra configuration support.
|
|
219
|
+
|
|
220
|
+
**Prerequisites**: Requires Hydra for configuration management: `pip install "torchrl[utils]"`
|
|
221
|
+
|
|
222
|
+
<details>
|
|
223
|
+
<summary>Complete Training Script (sota-implementations/ppo_trainer/train.py)</summary>
|
|
224
|
+
|
|
225
|
+
```python
|
|
226
|
+
import hydra
|
|
227
|
+
from torchrl.trainers.algorithms.configs import *
|
|
228
|
+
|
|
229
|
+
@hydra.main(config_path="config", config_name="config", version_base="1.1")
|
|
230
|
+
def main(cfg):
|
|
231
|
+
trainer = hydra.utils.instantiate(cfg.trainer)
|
|
232
|
+
trainer.train()
|
|
233
|
+
|
|
234
|
+
if __name__ == "__main__":
|
|
235
|
+
main()
|
|
236
|
+
```
|
|
237
|
+
*Complete PPO training in ~20 lines with full configurability.*
|
|
238
|
+
|
|
239
|
+
</details>
|
|
240
|
+
|
|
241
|
+
<details>
|
|
242
|
+
<summary>API Usage Examples</summary>
|
|
243
|
+
|
|
244
|
+
```bash
|
|
245
|
+
# Basic usage - train PPO on Pendulum-v1 with default settings
|
|
246
|
+
python sota-implementations/ppo_trainer/train.py
|
|
247
|
+
|
|
248
|
+
# Custom configuration with command-line overrides
|
|
249
|
+
python sota-implementations/ppo_trainer/train.py \
|
|
250
|
+
trainer.total_frames=2000000 \
|
|
251
|
+
training_env.create_env_fn.base_env.env_name=HalfCheetah-v4 \
|
|
252
|
+
networks.policy_network.num_cells=[256,256] \
|
|
253
|
+
optimizer.lr=0.0003
|
|
254
|
+
|
|
255
|
+
# Use different environment and logger
|
|
256
|
+
python sota-implementations/ppo_trainer/train.py \
|
|
257
|
+
env=gym \
|
|
258
|
+
training_env.create_env_fn.base_env.env_name=Walker2d-v4 \
|
|
259
|
+
logger=tensorboard
|
|
260
|
+
|
|
261
|
+
# See all available options
|
|
262
|
+
python sota-implementations/ppo_trainer/train.py --help
|
|
263
|
+
```
|
|
264
|
+
|
|
265
|
+
</details>
|
|
266
|
+
|
|
267
|
+
**Future Plans**: Additional algorithm trainers (SAC, TD3, DQN) and full integration of all TorchRL components within the configuration system are planned for upcoming releases.
|
|
268
|
+
|
|
269
|
+
## LLM API - Complete Framework for Language Model Fine-tuning
|
|
270
|
+
|
|
271
|
+
TorchRL includes a comprehensive **LLM API** for post-training and fine-tuning of language models! This framework provides everything you need for RLHF, supervised fine-tuning, and tool-augmented training:
|
|
272
|
+
|
|
273
|
+
- 🤖 **Unified LLM Wrappers**: Seamless integration with Hugging Face models and vLLM inference engines
|
|
274
|
+
- 💬 **Conversation Management**: Advanced [`History`](torchrl/data/llm/history.py) class for multi-turn dialogue with automatic chat template detection
|
|
275
|
+
- 🛠️ **Tool Integration**: [Built-in support](torchrl/envs/llm/transforms/) for Python code execution, function calling, and custom tool transforms
|
|
276
|
+
- 🎯 **Specialized Objectives**: [GRPO](torchrl/objectives/llm/grpo.py) (Group Relative Policy Optimization) and [SFT](torchrl/objectives/llm/sft.py) loss functions optimized for language models
|
|
277
|
+
- ⚡ **High-Performance Collectors**: [Async data collection](torchrl/collectors/llm/) with distributed training support
|
|
278
|
+
- 🔄 **Flexible Environments**: Transform-based architecture for reward computation, data loading, and conversation augmentation
|
|
279
|
+
|
|
280
|
+
The LLM API follows TorchRL's modular design principles, allowing you to mix and match components for your specific use case. Check out the [complete documentation](https://pytorch.org/rl/main/reference/llms.html) and [GRPO implementation example](https://github.com/pytorch/rl/tree/main/sota-implementations/grpo) to get started!
|
|
281
|
+
|
|
282
|
+
<details>
|
|
283
|
+
<summary>Quick LLM API Example</summary>
|
|
284
|
+
|
|
285
|
+
```python
|
|
286
|
+
from torchrl.envs.llm import ChatEnv
|
|
287
|
+
from torchrl.modules.llm import TransformersWrapper
|
|
288
|
+
from torchrl.objectives.llm import GRPOLoss
|
|
289
|
+
from torchrl.collectors.llm import LLMCollector
|
|
290
|
+
|
|
291
|
+
# Create environment with Python tool execution
|
|
292
|
+
env = ChatEnv(
|
|
293
|
+
tokenizer=tokenizer,
|
|
294
|
+
system_prompt="You are an assistant that can execute Python code.",
|
|
295
|
+
batch_size=[1]
|
|
296
|
+
).append_transform(PythonInterpreter())
|
|
297
|
+
|
|
298
|
+
# Wrap your language model
|
|
299
|
+
llm = TransformersWrapper(
|
|
300
|
+
model=model,
|
|
301
|
+
tokenizer=tokenizer,
|
|
302
|
+
input_mode="history"
|
|
303
|
+
)
|
|
304
|
+
|
|
305
|
+
# Set up GRPO training
|
|
306
|
+
loss_fn = GRPOLoss(llm, critic, gamma=0.99)
|
|
307
|
+
collector = LLMCollector(env, llm, frames_per_batch=100)
|
|
308
|
+
|
|
309
|
+
# Training loop
|
|
310
|
+
for data in collector:
|
|
311
|
+
loss = loss_fn(data)
|
|
312
|
+
loss.backward()
|
|
313
|
+
optimizer.step()
|
|
314
|
+
```
|
|
315
|
+
|
|
316
|
+
</details>
|
|
317
|
+
|
|
318
|
+
## Key features
|
|
319
|
+
|
|
320
|
+
- 🐍 **Python-first**: Designed with Python as the primary language for ease of use and flexibility
|
|
321
|
+
- ⏱️ **Efficient**: Optimized for performance to support demanding RL research applications
|
|
322
|
+
- 🧮 **Modular, customizable, extensible**: Highly modular architecture allows for easy swapping, transformation, or creation of new components
|
|
323
|
+
- 📚 **Documented**: Thorough documentation ensures that users can quickly understand and utilize the library
|
|
324
|
+
- ✅ **Tested**: Rigorously tested to ensure reliability and stability
|
|
325
|
+
- ⚙️ **Reusable functionals**: Provides a set of highly reusable functions for cost functions, returns, and data processing
|
|
326
|
+
|
|
327
|
+
### Design Principles
|
|
328
|
+
|
|
329
|
+
- 🔥 **Aligns with PyTorch ecosystem**: Follows the structure and conventions of popular PyTorch libraries
|
|
330
|
+
(e.g., dataset pillar, transforms, models, data utilities)
|
|
331
|
+
- ➖ **Minimal dependencies**: Only requires Python standard library, NumPy, and PyTorch; optional dependencies for
|
|
332
|
+
common environment libraries (e.g., OpenAI Gym) and datasets (D4RL, OpenX...)
|
|
333
|
+
|
|
334
|
+
Read the [full paper](https://arxiv.org/abs/2306.00577) for a more curated description of the library.
|
|
335
|
+
|
|
336
|
+
## Getting started
|
|
337
|
+
|
|
338
|
+
Check our [Getting Started tutorials](https://pytorch.org/rl/stable/index.html#getting-started) to quickly ramp up with the basic
|
|
339
|
+
features of the library!
|
|
340
|
+
|
|
341
|
+
<p align="center">
|
|
342
|
+
<img src="docs/ppo.png" width="800" >
|
|
343
|
+
</p>
|
|
344
|
+
|
|
345
|
+
## Documentation and knowledge base
|
|
346
|
+
|
|
347
|
+
The TorchRL documentation can be found [here](https://pytorch.org/rl).
|
|
348
|
+
It contains tutorials and the API reference.
|
|
349
|
+
|
|
350
|
+
TorchRL also provides an RL knowledge base to help you debug your code, or simply
|
|
351
|
+
learn the basics of RL. Check it out [here](https://pytorch.org/rl/stable/reference/knowledge_base.html).
|
|
352
|
+
|
|
353
|
+
We have some introductory videos for you to get to know the library better, check them out:
|
|
354
|
+
|
|
355
|
+
- [TalkRL podcast](https://www.talkrl.com/episodes/vincent-moens-on-torchrl)
|
|
356
|
+
- [TorchRL intro at PyTorch day 2022](https://youtu.be/cIKMhZoykEE)
|
|
357
|
+
- [PyTorch 2.0 Q&A: TorchRL](https://www.youtube.com/live/myEfUoYrbts?feature=share)
|
|
358
|
+
|
|
359
|
+
## Spotlight publications
|
|
360
|
+
|
|
361
|
+
TorchRL being domain-agnostic, you can use it across many different fields. Here are a few examples:
|
|
362
|
+
|
|
363
|
+
- [ACEGEN](https://pubs.acs.org/doi/10.1021/acs.jcim.4c00895): Reinforcement Learning of Generative Chemical Agents
|
|
364
|
+
for Drug Discovery
|
|
365
|
+
- [BenchMARL](https://www.jmlr.org/papers/v25/23-1612.html): Benchmarking Multi-Agent Reinforcement Learning
|
|
366
|
+
- [BricksRL](https://arxiv.org/abs/2406.17490): A Platform for Democratizing Robotics and Reinforcement Learning
|
|
367
|
+
Research and Education with LEGO
|
|
368
|
+
- [OmniDrones](https://ieeexplore.ieee.org/abstract/document/10409589): An Efficient and Flexible Platform for Reinforcement Learning in Drone Control
|
|
369
|
+
- [RL4CO](https://arxiv.org/abs/2306.17100): an Extensive Reinforcement Learning for Combinatorial Optimization Benchmark
|
|
370
|
+
- [Robohive](https://proceedings.neurips.cc/paper_files/paper/2023/file/8a84a4341c375b8441b36836bb343d4e-Paper-Datasets_and_Benchmarks.pdf): A unified framework for robot learning
|
|
371
|
+
|
|
372
|
+
## Writing simplified and portable RL codebase with `TensorDict`
|
|
373
|
+
|
|
374
|
+
RL algorithms are very heterogeneous, and it can be hard to recycle a codebase
|
|
375
|
+
across settings (e.g. from online to offline, from state-based to pixel-based
|
|
376
|
+
learning).
|
|
377
|
+
TorchRL solves this problem through [`TensorDict`](https://github.com/pytorch/tensordict/),
|
|
378
|
+
a convenient data structure<sup>(1)</sup> that can be used to streamline one's
|
|
379
|
+
RL codebase.
|
|
380
|
+
With this tool, one can write a *complete PPO training script in less than 100
|
|
381
|
+
lines of code*!
|
|
382
|
+
|
|
383
|
+
<details>
|
|
384
|
+
<summary>Code</summary>
|
|
385
|
+
|
|
386
|
+
```python
|
|
387
|
+
import torch
|
|
388
|
+
from tensordict.nn import TensorDictModule
|
|
389
|
+
from tensordict.nn.distributions import NormalParamExtractor
|
|
390
|
+
from torch import nn
|
|
391
|
+
|
|
392
|
+
from torchrl.collectors import SyncDataCollector
|
|
393
|
+
from torchrl.data.replay_buffers import TensorDictReplayBuffer, \
|
|
394
|
+
LazyTensorStorage, SamplerWithoutReplacement
|
|
395
|
+
from torchrl.envs.libs.gym import GymEnv
|
|
396
|
+
from torchrl.modules import ProbabilisticActor, ValueOperator, TanhNormal
|
|
397
|
+
from torchrl.objectives import ClipPPOLoss
|
|
398
|
+
from torchrl.objectives.value import GAE
|
|
399
|
+
|
|
400
|
+
env = GymEnv("Pendulum-v1")
|
|
401
|
+
model = TensorDictModule(
|
|
402
|
+
nn.Sequential(
|
|
403
|
+
nn.Linear(3, 128), nn.Tanh(),
|
|
404
|
+
nn.Linear(128, 128), nn.Tanh(),
|
|
405
|
+
nn.Linear(128, 128), nn.Tanh(),
|
|
406
|
+
nn.Linear(128, 2),
|
|
407
|
+
NormalParamExtractor()
|
|
408
|
+
),
|
|
409
|
+
in_keys=["observation"],
|
|
410
|
+
out_keys=["loc", "scale"]
|
|
411
|
+
)
|
|
412
|
+
critic = ValueOperator(
|
|
413
|
+
nn.Sequential(
|
|
414
|
+
nn.Linear(3, 128), nn.Tanh(),
|
|
415
|
+
nn.Linear(128, 128), nn.Tanh(),
|
|
416
|
+
nn.Linear(128, 128), nn.Tanh(),
|
|
417
|
+
nn.Linear(128, 1),
|
|
418
|
+
),
|
|
419
|
+
in_keys=["observation"],
|
|
420
|
+
)
|
|
421
|
+
actor = ProbabilisticActor(
|
|
422
|
+
model,
|
|
423
|
+
in_keys=["loc", "scale"],
|
|
424
|
+
distribution_class=TanhNormal,
|
|
425
|
+
distribution_kwargs={"low": -1.0, "high": 1.0},
|
|
426
|
+
return_log_prob=True
|
|
427
|
+
)
|
|
428
|
+
buffer = TensorDictReplayBuffer(
|
|
429
|
+
storage=LazyTensorStorage(1000),
|
|
430
|
+
sampler=SamplerWithoutReplacement(),
|
|
431
|
+
batch_size=50,
|
|
432
|
+
)
|
|
433
|
+
collector = SyncDataCollector(
|
|
434
|
+
env,
|
|
435
|
+
actor,
|
|
436
|
+
frames_per_batch=1000,
|
|
437
|
+
total_frames=1_000_000,
|
|
438
|
+
)
|
|
439
|
+
loss_fn = ClipPPOLoss(actor, critic)
|
|
440
|
+
adv_fn = GAE(value_network=critic, average_gae=True, gamma=0.99, lmbda=0.95)
|
|
441
|
+
optim = torch.optim.Adam(loss_fn.parameters(), lr=2e-4)
|
|
442
|
+
|
|
443
|
+
for data in collector: # collect data
|
|
444
|
+
for epoch in range(10):
|
|
445
|
+
adv_fn(data) # compute advantage
|
|
446
|
+
buffer.extend(data)
|
|
447
|
+
for sample in buffer: # consume data
|
|
448
|
+
loss_vals = loss_fn(sample)
|
|
449
|
+
loss_val = sum(
|
|
450
|
+
value for key, value in loss_vals.items() if
|
|
451
|
+
key.startswith("loss")
|
|
452
|
+
)
|
|
453
|
+
loss_val.backward()
|
|
454
|
+
optim.step()
|
|
455
|
+
optim.zero_grad()
|
|
456
|
+
print(f"avg reward: {data['next', 'reward'].mean().item(): 4.4f}")
|
|
457
|
+
```
|
|
458
|
+
</details>
|
|
459
|
+
|
|
460
|
+
Here is an example of how the [environment API](https://pytorch.org/rl/stable/reference/envs.html)
|
|
461
|
+
relies on tensordict to carry data from one function to another during a rollout
|
|
462
|
+
execution:
|
|
463
|
+

|
|
464
|
+
|
|
465
|
+
`TensorDict` makes it easy to re-use pieces of code across environments, models and
|
|
466
|
+
algorithms.
|
|
467
|
+
<details>
|
|
468
|
+
<summary>Code</summary>
|
|
469
|
+
|
|
470
|
+
For instance, here's how to code a rollout in TorchRL:
|
|
471
|
+
|
|
472
|
+
```diff
|
|
473
|
+
- obs, done = env.reset()
|
|
474
|
+
+ tensordict = env.reset()
|
|
475
|
+
policy = SafeModule(
|
|
476
|
+
model,
|
|
477
|
+
in_keys=["observation_pixels", "observation_vector"],
|
|
478
|
+
out_keys=["action"],
|
|
479
|
+
)
|
|
480
|
+
out = []
|
|
481
|
+
for i in range(n_steps):
|
|
482
|
+
- action, log_prob = policy(obs)
|
|
483
|
+
- next_obs, reward, done, info = env.step(action)
|
|
484
|
+
- out.append((obs, next_obs, action, log_prob, reward, done))
|
|
485
|
+
- obs = next_obs
|
|
486
|
+
+ tensordict = policy(tensordict)
|
|
487
|
+
+ tensordict = env.step(tensordict)
|
|
488
|
+
+ out.append(tensordict)
|
|
489
|
+
+ tensordict = step_mdp(tensordict) # renames next_observation_* keys to observation_*
|
|
490
|
+
- obs, next_obs, action, log_prob, reward, done = [torch.stack(vals, 0) for vals in zip(*out)]
|
|
491
|
+
+ out = torch.stack(out, 0) # TensorDict supports multiple tensor operations
|
|
492
|
+
```
|
|
493
|
+
</details>
|
|
494
|
+
|
|
495
|
+
Using this, TorchRL abstracts away the input / output signatures of the modules, env,
|
|
496
|
+
collectors, replay buffers and losses of the library, allowing all primitives
|
|
497
|
+
to be easily recycled across settings.
|
|
498
|
+
|
|
499
|
+
<details>
|
|
500
|
+
<summary>Code</summary>
|
|
501
|
+
|
|
502
|
+
Here's another example of an off-policy training loop in TorchRL (assuming
|
|
503
|
+
that a data collector, a replay buffer, a loss and an optimizer have been instantiated):
|
|
504
|
+
|
|
505
|
+
```diff
|
|
506
|
+
- for i, (obs, next_obs, action, hidden_state, reward, done) in enumerate(collector):
|
|
507
|
+
+ for i, tensordict in enumerate(collector):
|
|
508
|
+
- replay_buffer.add((obs, next_obs, action, hidden_state, reward, done))
|
|
509
|
+
+ replay_buffer.add(tensordict)
|
|
510
|
+
for j in range(num_optim_steps):
|
|
511
|
+
- obs, next_obs, action, hidden_state, reward, done = replay_buffer.sample(batch_size)
|
|
512
|
+
- loss = loss_fn(obs, next_obs, action, hidden_state, reward, done)
|
|
513
|
+
+ tensordict = replay_buffer.sample(batch_size)
|
|
514
|
+
+ loss = loss_fn(tensordict)
|
|
515
|
+
loss.backward()
|
|
516
|
+
optim.step()
|
|
517
|
+
optim.zero_grad()
|
|
518
|
+
```
|
|
519
|
+
This training loop can be re-used across algorithms as it makes a minimal number of assumptions about the structure of the data.
|
|
520
|
+
</details>
|
|
521
|
+
|
|
522
|
+
TensorDict supports multiple tensor operations on its device and shape
|
|
523
|
+
(the shape of TensorDict, or its batch size, is the common arbitrary N first dimensions of all its contained tensors):
|
|
524
|
+
|
|
525
|
+
<details>
|
|
526
|
+
<summary>Code</summary>
|
|
527
|
+
|
|
528
|
+
```python
|
|
529
|
+
# stack and cat
|
|
530
|
+
tensordict = torch.stack(list_of_tensordicts, 0)
|
|
531
|
+
tensordict = torch.cat(list_of_tensordicts, 0)
|
|
532
|
+
# reshape
|
|
533
|
+
tensordict = tensordict.view(-1)
|
|
534
|
+
tensordict = tensordict.permute(0, 2, 1)
|
|
535
|
+
tensordict = tensordict.unsqueeze(-1)
|
|
536
|
+
tensordict = tensordict.squeeze(-1)
|
|
537
|
+
# indexing
|
|
538
|
+
tensordict = tensordict[:2]
|
|
539
|
+
tensordict[:, 2] = sub_tensordict
|
|
540
|
+
# device and memory location
|
|
541
|
+
tensordict.cuda()
|
|
542
|
+
tensordict.to("cuda:1")
|
|
543
|
+
tensordict.share_memory_()
|
|
544
|
+
```
|
|
545
|
+
</details>
|
|
546
|
+
|
|
547
|
+
TensorDict comes with a dedicated [`tensordict.nn`](https://pytorch.github.io/tensordict/reference/nn.html)
|
|
548
|
+
module that contains everything you might need to write your model with it.
|
|
549
|
+
And it is `functorch` and `torch.compile` compatible!
|
|
550
|
+
|
|
551
|
+
<details>
|
|
552
|
+
<summary>Code</summary>
|
|
553
|
+
|
|
554
|
+
```diff
|
|
555
|
+
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
|
|
556
|
+
+ td_module = SafeModule(transformer_model, in_keys=["src", "tgt"], out_keys=["out"])
|
|
557
|
+
src = torch.rand((10, 32, 512))
|
|
558
|
+
tgt = torch.rand((20, 32, 512))
|
|
559
|
+
+ tensordict = TensorDict({"src": src, "tgt": tgt}, batch_size=[])  # src and tgt share no leading dimension
|
|
560
|
+
- out = transformer_model(src, tgt)
|
|
561
|
+
+ td_module(tensordict)
|
|
562
|
+
+ out = tensordict["out"]
|
|
563
|
+
```
|
|
564
|
+
|
|
565
|
+
The `TensorDictSequential` class makes it possible to chain sequences of `nn.Module` instances in a highly modular way.
|
|
566
|
+
For instance, here is an implementation of a transformer using the encoder and decoder blocks:
|
|
567
|
+
```python
|
|
568
|
+
encoder_module = TransformerEncoder(...)
|
|
569
|
+
encoder = TensorDictModule(encoder_module, in_keys=["src", "src_mask"], out_keys=["memory"])
|
|
570
|
+
decoder_module = TransformerDecoder(...)
|
|
571
|
+
decoder = TensorDictModule(decoder_module, in_keys=["tgt", "memory"], out_keys=["output"])
|
|
572
|
+
transformer = TensorDictSequential(encoder, decoder)
|
|
573
|
+
assert transformer.in_keys == ["src", "src_mask", "tgt"]
|
|
574
|
+
assert transformer.out_keys == ["memory", "output"]
|
|
575
|
+
```
|
|
576
|
+
|
|
577
|
+
`TensorDictSequential` also lets you isolate subgraphs by querying a set of desired input / output keys:
|
|
578
|
+
```python
|
|
579
|
+
transformer.select_subsequence(out_keys=["memory"]) # returns the encoder
|
|
580
|
+
transformer.select_subsequence(in_keys=["tgt", "memory"]) # returns the decoder
|
|
581
|
+
```
|
|
582
|
+
</details>
|
|
583
|
+
|
|
584
|
+
Check [TensorDict tutorials](https://pytorch.github.io/tensordict/) to
|
|
585
|
+
learn more!
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
## Features
|
|
589
|
+
|
|
590
|
+
- A common [interface for environments](https://github.com/pytorch/rl/blob/main/torchrl/envs)
|
|
591
|
+
which supports common libraries (OpenAI gym, deepmind control lab, etc.)<sup>(1)</sup> and state-less execution
|
|
592
|
+
(e.g. Model-based environments).
|
|
593
|
+
The [batched environments](https://github.com/pytorch/rl/blob/main/torchrl/envs/batched_envs.py) containers allow parallel execution<sup>(2)</sup>.
|
|
594
|
+
A common, PyTorch-first [tensor-specification class](https://github.com/pytorch/rl/blob/main/torchrl/data/tensor_specs.py) is also provided.
|
|
595
|
+
TorchRL's environments API is simple but stringent and specific. Check the
|
|
596
|
+
[documentation](https://pytorch.org/rl/stable/reference/envs.html)
|
|
597
|
+
and [tutorial](https://pytorch.org/rl/stable/tutorials/pendulum.html) to learn more!
|
|
598
|
+
<details>
|
|
599
|
+
<summary>Code</summary>
|
|
600
|
+
|
|
601
|
+
```python
|
|
602
|
+
env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
|
|
603
|
+
env_parallel = ParallelEnv(4, env_make) # creates 4 envs in parallel
|
|
604
|
+
tensordict = env_parallel.rollout(max_steps=20, policy=None) # random rollout (no policy given)
|
|
605
|
+
assert tensordict.shape == [4, 20] # 4 envs, 20 steps rollout
|
|
606
|
+
env_parallel.action_spec.is_in(tensordict["action"]) # spec check returns True
|
|
607
|
+
```
|
|
608
|
+
</details>
|
|
609
|
+
|
|
610
|
+
- multiprocess and distributed [data collectors](https://github.com/pytorch/rl/blob/main/torchrl/collectors/collectors.py)<sup>(2)</sup>
|
|
611
|
+
that work synchronously or asynchronously.
|
|
612
|
+
Through the use of TensorDict, TorchRL's training loops are made very similar
|
|
613
|
+
to regular training loops in supervised
|
|
614
|
+
learning (although the "dataloader" -- read data collector -- is modified on-the-fly):
|
|
615
|
+
<details>
|
|
616
|
+
<summary>Code</summary>
|
|
617
|
+
|
|
618
|
+
```python
|
|
619
|
+
env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
|
|
620
|
+
collector = MultiaSyncDataCollector(
|
|
621
|
+
[env_make, env_make],
|
|
622
|
+
policy=policy,
|
|
623
|
+
devices=["cuda:0", "cuda:0"],
|
|
624
|
+
total_frames=10000,
|
|
625
|
+
frames_per_batch=50,
|
|
626
|
+
...
|
|
627
|
+
)
|
|
628
|
+
for i, tensordict_data in enumerate(collector):
|
|
629
|
+
loss = loss_module(tensordict_data)
|
|
630
|
+
loss.backward()
|
|
631
|
+
optim.step()
|
|
632
|
+
optim.zero_grad()
|
|
633
|
+
collector.update_policy_weights_()
|
|
634
|
+
```
|
|
635
|
+
</details>
|
|
636
|
+
|
|
637
|
+
Check our [distributed collector examples](https://github.com/pytorch/rl/blob/main/examples/distributed/collectors) to
|
|
638
|
+
learn more about ultra-fast data collection with TorchRL.
|
|
639
|
+
|
|
640
|
+
- efficient<sup>(2)</sup> and generic<sup>(1)</sup> [replay buffers](https://github.com/pytorch/rl/blob/main/torchrl/data/replay_buffers/replay_buffers.py) with modularized storage:
|
|
641
|
+
<details>
|
|
642
|
+
<summary>Code</summary>
|
|
643
|
+
|
|
644
|
+
```python
|
|
645
|
+
storage = LazyMemmapStorage( # memory-mapped (physical) storage
|
|
646
|
+
cfg.buffer_size,
|
|
647
|
+
scratch_dir="/tmp/"
|
|
648
|
+
)
|
|
649
|
+
buffer = TensorDictPrioritizedReplayBuffer(
|
|
650
|
+
alpha=0.7,
|
|
651
|
+
beta=0.5,
|
|
652
|
+
collate_fn=lambda x: x,
|
|
653
|
+
pin_memory=device != torch.device("cpu"),
|
|
654
|
+
prefetch=10, # multi-threaded sampling
|
|
655
|
+
storage=storage
|
|
656
|
+
)
|
|
657
|
+
```
|
|
658
|
+
</details>
|
|
659
|
+
|
|
660
|
+
Replay buffers are also offered as wrappers around common datasets for *offline RL*:
|
|
661
|
+
<details>
|
|
662
|
+
<summary>Code</summary>
|
|
663
|
+
|
|
664
|
+
```python
|
|
665
|
+
from torchrl.data.replay_buffers import SamplerWithoutReplacement
|
|
666
|
+
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
|
|
667
|
+
data = D4RLExperienceReplay(
|
|
668
|
+
"maze2d-open-v0",
|
|
669
|
+
split_trajs=True,
|
|
670
|
+
batch_size=128,
|
|
671
|
+
sampler=SamplerWithoutReplacement(drop_last=True),
|
|
672
|
+
)
|
|
673
|
+
for sample in data: # or alternatively sample = data.sample()
|
|
674
|
+
fun(sample)
|
|
675
|
+
```
|
|
676
|
+
</details>
|
|
677
|
+
|
|
678
|
+
|
|
679
|
+
- cross-library [environment transforms](https://github.com/pytorch/rl/blob/main/torchrl/envs/transforms/transforms.py)<sup>(1)</sup>,
|
|
680
|
+
executed on device and in a vectorized fashion<sup>(2)</sup>,
|
|
681
|
+
which process and prepare the data coming out of the environments to be used by the agent:
|
|
682
|
+
<details>
|
|
683
|
+
<summary>Code</summary>
|
|
684
|
+
|
|
685
|
+
```python
|
|
686
|
+
env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
|
|
687
|
+
env_base = ParallelEnv(4, env_make, device="cuda:0") # creates 4 envs in parallel
|
|
688
|
+
env = TransformedEnv(
|
|
689
|
+
env_base,
|
|
690
|
+
Compose(
|
|
691
|
+
ToTensorImage(),
|
|
692
|
+
ObservationNorm(loc=0.5, scale=1.0)), # executes the transforms once and on device
|
|
693
|
+
)
|
|
694
|
+
tensordict = env.reset()
|
|
695
|
+
assert tensordict.device == torch.device("cuda:0")
|
|
696
|
+
```
|
|
697
|
+
Other transforms include: reward scaling (`RewardScaling`), shape operations (concatenation of tensors, unsqueezing etc.), concatenation of
|
|
698
|
+
successive observations (`CatFrames`), resizing (`Resize`) and many more.
|
|
699
|
+
|
|
700
|
+
Unlike other libraries, the transforms are stacked as a list (and not wrapped in each other), which makes it
|
|
701
|
+
easy to add and remove them at will:
|
|
702
|
+
```python
|
|
703
|
+
env.insert_transform(0, NoopResetEnv()) # inserts the NoopResetEnv transform at the index 0
|
|
704
|
+
```
|
|
705
|
+
Nevertheless, transforms can access and execute operations on the parent environment:
|
|
706
|
+
```python
|
|
707
|
+
transform = env.transform[1] # gathers the second transform of the list
|
|
708
|
+
parent_env = transform.parent # returns the base environment of the second transform, i.e. the base env + the first transform
|
|
709
|
+
```
|
|
710
|
+
</details>
|
|
711
|
+
|
|
712
|
+
- various tools for distributed learning (e.g. [memory mapped tensors](https://github.com/pytorch/tensordict/blob/main/tensordict/memmap.py))<sup>(2)</sup>;
|
|
713
|
+
- various [architectures](https://github.com/pytorch/rl/blob/main/torchrl/modules/models/) and models (e.g. [actor-critic](https://github.com/pytorch/rl/blob/main/torchrl/modules/tensordict_module/actors.py))<sup>(1)</sup>:
|
|
714
|
+
<details>
|
|
715
|
+
<summary>Code</summary>
|
|
716
|
+
|
|
717
|
+
```python
|
|
718
|
+
# create an nn.Module
|
|
719
|
+
common_module = ConvNet(
|
|
720
|
+
bias_last_layer=True,
|
|
721
|
+
depth=None,
|
|
722
|
+
num_cells=[32, 64, 64],
|
|
723
|
+
kernel_sizes=[8, 4, 3],
|
|
724
|
+
strides=[4, 2, 1],
|
|
725
|
+
)
|
|
726
|
+
# Wrap it in a SafeModule, indicating what key to read in and where to
|
|
727
|
+
# write out the output
|
|
728
|
+
common_module = SafeModule(
|
|
729
|
+
common_module,
|
|
730
|
+
in_keys=["pixels"],
|
|
731
|
+
out_keys=["hidden"],
|
|
732
|
+
)
|
|
733
|
+
# Wrap the policy module in NormalParamsWrapper, such that the output
|
|
734
|
+
# tensor is split in loc and scale, and scale is mapped onto a positive space
|
|
735
|
+
policy_module = SafeModule(
|
|
736
|
+
NormalParamsWrapper(
|
|
737
|
+
MLP(num_cells=[64, 64], out_features=32, activation=nn.ELU)
|
|
738
|
+
),
|
|
739
|
+
in_keys=["hidden"],
|
|
740
|
+
out_keys=["loc", "scale"],
|
|
741
|
+
)
|
|
742
|
+
# Use a SafeProbabilisticTensorDictSequential to combine the SafeModule with a
|
|
743
|
+
# SafeProbabilisticModule, indicating how to build the
|
|
744
|
+
# torch.distributions.Distribution object and what to do with it
|
|
745
|
+
policy_module = SafeProbabilisticTensorDictSequential( # stochastic policy
|
|
746
|
+
policy_module,
|
|
747
|
+
SafeProbabilisticModule(
|
|
748
|
+
in_keys=["loc", "scale"],
|
|
749
|
+
out_keys="action",
|
|
750
|
+
distribution_class=TanhNormal,
|
|
751
|
+
),
|
|
752
|
+
)
|
|
753
|
+
value_module = MLP(
|
|
754
|
+
num_cells=[64, 64],
|
|
755
|
+
out_features=1,
|
|
756
|
+
activation=nn.ELU,
|
|
757
|
+
)
|
|
758
|
+
# Wrap the policy and value function in a common module
|
|
759
|
+
actor_value = ActorValueOperator(common_module, policy_module, value_module)
|
|
760
|
+
# standalone policy from this
|
|
761
|
+
standalone_policy = actor_value.get_policy_operator()
|
|
762
|
+
```
|
|
763
|
+
</details>
|
|
764
|
+
|
|
765
|
+
- exploration [wrappers](https://github.com/pytorch/rl/blob/main/torchrl/modules/tensordict_module/exploration.py) and
|
|
766
|
+
[modules](https://github.com/pytorch/rl/blob/main/torchrl/modules/models/exploration.py) to easily swap between exploration and exploitation<sup>(1)</sup>:
|
|
767
|
+
<details>
|
|
768
|
+
<summary>Code</summary>
|
|
769
|
+
|
|
770
|
+
```python
|
|
771
|
+
policy_explore = EGreedyWrapper(policy)
|
|
772
|
+
with set_exploration_type(ExplorationType.RANDOM):
|
|
773
|
+
tensordict = policy_explore(tensordict) # will use eps-greedy
|
|
774
|
+
with set_exploration_type(ExplorationType.DETERMINISTIC):
|
|
775
|
+
tensordict = policy_explore(tensordict) # will not use eps-greedy
|
|
776
|
+
```
|
|
777
|
+
</details>
|
|
778
|
+
|
|
779
|
+
- A series of efficient [loss modules](https://github.com/pytorch/rl/tree/main/torchrl/objectives)
|
|
780
|
+
and highly vectorized
|
|
781
|
+
[functional return and advantage](https://github.com/pytorch/rl/blob/main/torchrl/objectives/value/functional.py)
|
|
782
|
+
computation.
|
|
783
|
+
|
|
784
|
+
<details>
|
|
785
|
+
<summary>Code</summary>
|
|
786
|
+
|
|
787
|
+
### Loss modules
|
|
788
|
+
```python
|
|
789
|
+
from torchrl.objectives import DQNLoss
|
|
790
|
+
loss_module = DQNLoss(value_network=value_network, gamma=0.99)
|
|
791
|
+
tensordict = replay_buffer.sample(batch_size)
|
|
792
|
+
loss = loss_module(tensordict)
|
|
793
|
+
```
|
|
794
|
+
|
|
795
|
+
### Advantage computation
|
|
796
|
+
```python
|
|
797
|
+
from torchrl.objectives.value.functional import vec_td_lambda_return_estimate
|
|
798
|
+
advantage = vec_td_lambda_return_estimate(gamma, lmbda, next_state_value, reward, done, terminated)
|
|
799
|
+
```
|
|
800
|
+
|
|
801
|
+
</details>
|
|
802
|
+
|
|
803
|
+
- a generic [trainer class](https://github.com/pytorch/rl/blob/main/torchrl/trainers/trainers.py)<sup>(1)</sup> that
|
|
804
|
+
executes the aforementioned training loop. Through a hooking mechanism,
|
|
805
|
+
it also supports any logging or data transformation operation at any given
|
|
806
|
+
time.
|
|
807
|
+
|
|
808
|
+
- various [recipes](https://github.com/pytorch/rl/blob/main/torchrl/trainers/helpers/models.py) to build models that
|
|
809
|
+
correspond to the environment being deployed.
|
|
810
|
+
|
|
811
|
+
- **LLM API**: Complete framework for language model fine-tuning with unified wrappers for Hugging Face and vLLM backends,
|
|
812
|
+
conversation management with automatic chat template detection, tool integration (Python execution, function calling),
|
|
813
|
+
specialized objectives (GRPO, SFT), and high-performance async collectors. Perfect for RLHF, supervised fine-tuning,
|
|
814
|
+
and tool-augmented training scenarios.
|
|
815
|
+
<details>
|
|
816
|
+
<summary>Code</summary>
|
|
817
|
+
|
|
818
|
+
```python
|
|
819
|
+
from torchrl.envs.llm import ChatEnv
|
|
820
|
+
from torchrl.modules.llm import TransformersWrapper
|
|
821
|
+
from torchrl.envs.llm.transforms import PythonInterpreter
|
|
822
|
+
|
|
823
|
+
# Create environment with tool execution
|
|
824
|
+
env = ChatEnv(
|
|
825
|
+
tokenizer=tokenizer,
|
|
826
|
+
system_prompt="You can execute Python code.",
|
|
827
|
+
batch_size=[1]
|
|
828
|
+
).append_transform(PythonInterpreter())
|
|
829
|
+
|
|
830
|
+
# Wrap language model for training
|
|
831
|
+
llm = TransformersWrapper(
|
|
832
|
+
model=model,
|
|
833
|
+
tokenizer=tokenizer,
|
|
834
|
+
input_mode="history"
|
|
835
|
+
)
|
|
836
|
+
|
|
837
|
+
# Multi-turn conversation with tool use
|
|
838
|
+
obs = env.reset(TensorDict({"query": "Calculate 2+2"}, batch_size=[1]))
|
|
839
|
+
llm_output = llm(obs) # Generates response
|
|
840
|
+
obs = env.step(llm_output) # Environment processes response
|
|
841
|
+
```
|
|
842
|
+
</details>
|
|
843
|
+
|
|
844
|
+
If you feel a feature is missing from the library, please submit an issue!
|
|
845
|
+
If you would like to contribute to new features, check our [call for contributions](https://github.com/pytorch/rl/issues/509) and our [contribution](https://github.com/pytorch/rl/blob/main/CONTRIBUTING.md) page.
|
|
846
|
+
|
|
847
|
+
|
|
848
|
+
## Examples, tutorials and demos
|
|
849
|
+
|
|
850
|
+
A series of [State-of-the-Art implementations](https://github.com/pytorch/rl/blob/main/sota-implementations/) are provided for illustrative purposes:
|
|
851
|
+
|
|
852
|
+
<table>
|
|
853
|
+
<tr>
|
|
854
|
+
<td><strong>Algorithm</strong>
|
|
855
|
+
</td>
|
|
856
|
+
<td><strong>Compile Support**</strong>
|
|
857
|
+
</td>
|
|
858
|
+
<td><strong>Tensordict-free API</strong>
|
|
859
|
+
</td>
|
|
860
|
+
<td><strong>Modular Losses</strong>
|
|
861
|
+
</td>
|
|
862
|
+
<td><strong>Continuous and Discrete</strong>
|
|
863
|
+
</td>
|
|
864
|
+
</tr>
|
|
865
|
+
<tr>
|
|
866
|
+
<td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/dqn">DQN</a>
|
|
867
|
+
</td>
|
|
868
|
+
<td> 1.9x
|
|
869
|
+
</td>
|
|
870
|
+
<td> +
|
|
871
|
+
</td>
|
|
872
|
+
<td> NA
|
|
873
|
+
</td>
|
|
874
|
+
<td> + (through <a href="https://pytorch.org/rl/stable/reference/generated/torchrl.envs.transforms.ActionDiscretizer.html?highlight=actiondiscretizer">ActionDiscretizer</a> transform)
|
|
875
|
+
</td>
|
|
876
|
+
</tr>
|
|
877
|
+
<tr>
|
|
878
|
+
<td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/ddpg/ddpg.py">DDPG</a>
|
|
879
|
+
</td>
|
|
880
|
+
<td> 1.87x
|
|
881
|
+
</td>
|
|
882
|
+
<td> +
|
|
883
|
+
</td>
|
|
884
|
+
<td> +
|
|
885
|
+
</td>
|
|
886
|
+
<td> - (continuous only)
|
|
887
|
+
</td>
|
|
888
|
+
</tr>
|
|
889
|
+
<tr>
|
|
890
|
+
<td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/iql/">IQL</a>
|
|
891
|
+
</td>
|
|
892
|
+
<td> 3.22x
|
|
893
|
+
</td>
|
|
894
|
+
<td> +
|
|
895
|
+
</td>
|
|
896
|
+
<td> +
|
|
897
|
+
</td>
|
|
898
|
+
<td> +
|
|
899
|
+
</td>
|
|
900
|
+
</tr>
|
|
901
|
+
<tr>
|
|
902
|
+
<td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/cql/cql_offline.py">CQL</a>
|
|
903
|
+
</td>
|
|
904
|
+
<td> 2.68x
|
|
905
|
+
</td>
|
|
906
|
+
<td> +
|
|
907
|
+
</td>
|
|
908
|
+
<td> +
|
|
909
|
+
</td>
|
|
910
|
+
<td> +
|
|
911
|
+
</td>
|
|
912
|
+
</tr>
|
|
913
|
+
<tr>
|
|
914
|
+
<td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/td3/td3.py">TD3</a>
|
|
915
|
+
</td>
|
|
916
|
+
<td> 2.27x
|
|
917
|
+
</td>
|
|
918
|
+
<td> +
|
|
919
|
+
</td>
|
|
920
|
+
<td> +
|
|
921
|
+
</td>
|
|
922
|
+
<td> - (continuous only)
|
|
923
|
+
</td>
|
|
924
|
+
</tr>
|
|
925
|
+
<tr>
|
|
926
|
+
<td>
|
|
927
|
+
<a href="https://github.com/pytorch/rl/blob/main/sota-implementations/td3_bc/td3_bc.py">TD3+BC</a>
|
|
928
|
+
</td>
|
|
929
|
+
<td> untested
|
|
930
|
+
</td>
|
|
931
|
+
<td> +
|
|
932
|
+
</td>
|
|
933
|
+
<td> +
|
|
934
|
+
</td>
|
|
935
|
+
<td> - (continuous only)
|
|
936
|
+
</td>
|
|
937
|
+
</tr>
|
|
938
|
+
<tr>
|
|
939
|
+
<td>
|
|
940
|
+
<a href="https://github.com/pytorch/rl/blob/main/sota-implementations/a2c/">A2C</a>
|
|
941
|
+
</td>
|
|
942
|
+
<td> 2.67x
|
|
943
|
+
</td>
|
|
944
|
+
<td> +
|
|
945
|
+
</td>
|
|
946
|
+
<td> -
|
|
947
|
+
</td>
|
|
948
|
+
<td> +
|
|
949
|
+
</td>
|
|
950
|
+
</tr>
|
|
951
|
+
<tr>
|
|
952
|
+
<td>
|
|
953
|
+
<a href="https://github.com/pytorch/rl/blob/main/sota-implementations/ppo/">PPO</a>
|
|
954
|
+
</td>
|
|
955
|
+
<td> 2.42x
|
|
956
|
+
</td>
|
|
957
|
+
<td> +
|
|
958
|
+
</td>
|
|
959
|
+
<td> -
|
|
960
|
+
</td>
|
|
961
|
+
<td> +
|
|
962
|
+
</td>
|
|
963
|
+
</tr>
|
|
964
|
+
<tr>
|
|
965
|
+
<td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/sac/sac.py">SAC</a>
|
|
966
|
+
</td>
|
|
967
|
+
<td> 2.62x
|
|
968
|
+
</td>
|
|
969
|
+
<td> +
|
|
970
|
+
</td>
|
|
971
|
+
<td> -
|
|
972
|
+
</td>
|
|
973
|
+
<td> +
|
|
974
|
+
</td>
|
|
975
|
+
</tr>
|
|
976
|
+
<tr>
|
|
977
|
+
<td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/redq/redq.py">REDQ</a>
|
|
978
|
+
</td>
|
|
979
|
+
<td> 2.28x
|
|
980
|
+
</td>
|
|
981
|
+
<td> +
|
|
982
|
+
</td>
|
|
983
|
+
<td> -
|
|
984
|
+
</td>
|
|
985
|
+
<td> - (continuous only)
|
|
986
|
+
</td>
|
|
987
|
+
</tr>
|
|
988
|
+
<tr>
|
|
989
|
+
<td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/dreamer/dreamer.py">Dreamer v1</a>
|
|
990
|
+
</td>
|
|
991
|
+
<td> untested
|
|
992
|
+
</td>
|
|
993
|
+
<td> +
|
|
994
|
+
</td>
|
|
995
|
+
<td> + (<a href="https://pytorch.org/rl/stable/reference/objectives.html#dreamer">different classes</a>)
|
|
996
|
+
</td>
|
|
997
|
+
<td> - (continuous only)
|
|
998
|
+
</td>
|
|
999
|
+
</tr>
|
|
1000
|
+
<tr>
|
|
1001
|
+
<td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/decision_transformer">Decision Transformers</a>
|
|
1002
|
+
</td>
|
|
1003
|
+
<td> untested
|
|
1004
|
+
</td>
|
|
1005
|
+
<td> +
|
|
1006
|
+
</td>
|
|
1007
|
+
<td> NA
|
|
1008
|
+
</td>
|
|
1009
|
+
<td> - (continuous only)
|
|
1010
|
+
</td>
|
|
1011
|
+
</tr>
|
|
1012
|
+
<tr>
|
|
1013
|
+
<td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/crossq">CrossQ</a>
|
|
1014
|
+
</td>
|
|
1015
|
+
<td> untested
|
|
1016
|
+
</td>
|
|
1017
|
+
<td> +
|
|
1018
|
+
</td>
|
|
1019
|
+
<td> +
|
|
1020
|
+
</td>
|
|
1021
|
+
<td> - (continuous only)
|
|
1022
|
+
</td>
|
|
1023
|
+
</tr>
|
|
1024
|
+
<tr>
|
|
1025
|
+
<td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/gail">Gail</a>
|
|
1026
|
+
</td>
|
|
1027
|
+
<td> untested
|
|
1028
|
+
</td>
|
|
1029
|
+
<td> +
|
|
1030
|
+
</td>
|
|
1031
|
+
<td> NA
|
|
1032
|
+
</td>
|
|
1033
|
+
<td> +
|
|
1034
|
+
</td>
|
|
1035
|
+
</tr>
|
|
1036
|
+
<tr>
|
|
1037
|
+
<td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/impala">Impala</a>
|
|
1038
|
+
</td>
|
|
1039
|
+
<td> untested
|
|
1040
|
+
</td>
|
|
1041
|
+
<td> +
|
|
1042
|
+
</td>
|
|
1043
|
+
<td> -
|
|
1044
|
+
</td>
|
|
1045
|
+
<td> +
|
|
1046
|
+
</td>
|
|
1047
|
+
</tr>
|
|
1048
|
+
<tr>
|
|
1049
|
+
<td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/multiagent/iql.py">IQL (MARL)</a>
|
|
1050
|
+
</td>
|
|
1051
|
+
<td> untested
|
|
1052
|
+
</td>
|
|
1053
|
+
<td> +
|
|
1054
|
+
</td>
|
|
1055
|
+
<td> +
|
|
1056
|
+
</td>
|
|
1057
|
+
<td> +
|
|
1058
|
+
</td>
|
|
1059
|
+
</tr>
|
|
1060
|
+
<tr>
|
|
1061
|
+
<td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/multiagent/maddpg_iddpg.py">DDPG (MARL)</a>
|
|
1062
|
+
</td>
|
|
1063
|
+
<td> untested
|
|
1064
|
+
</td>
|
|
1065
|
+
<td> +
|
|
1066
|
+
</td>
|
|
1067
|
+
<td> +
|
|
1068
|
+
</td>
|
|
1069
|
+
<td> - (continuous only)
|
|
1070
|
+
</td>
|
|
1071
|
+
</tr>
|
|
1072
|
+
<tr>
|
|
1073
|
+
<td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/multiagent/mappo_ippo.py">PPO (MARL)</a>
|
|
1074
|
+
</td>
|
|
1075
|
+
<td> untested
|
|
1076
|
+
</td>
|
|
1077
|
+
<td> +
|
|
1078
|
+
</td>
|
|
1079
|
+
<td> -
|
|
1080
|
+
</td>
|
|
1081
|
+
<td> +
|
|
1082
|
+
</td>
|
|
1083
|
+
</tr>
|
|
1084
|
+
<tr>
|
|
1085
|
+
<td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/multiagent/qmix_vdn.py">QMIX-VDN (MARL)</a>
|
|
1086
|
+
</td>
|
|
1087
|
+
<td> untested
|
|
1088
|
+
</td>
|
|
1089
|
+
<td> +
|
|
1090
|
+
</td>
|
|
1091
|
+
<td> NA
|
|
1092
|
+
</td>
|
|
1093
|
+
<td> +
|
|
1094
|
+
</td>
|
|
1095
|
+
</tr>
|
|
1096
|
+
<tr>
|
|
1097
|
+
<td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/multiagent/sac.py">SAC (MARL)</a>
|
|
1098
|
+
</td>
|
|
1099
|
+
<td> untested
|
|
1100
|
+
</td>
|
|
1101
|
+
<td> +
|
|
1102
|
+
</td>
|
|
1103
|
+
<td> -
|
|
1104
|
+
</td>
|
|
1105
|
+
<td> +
|
|
1106
|
+
</td>
|
|
1107
|
+
</tr>
|
|
1108
|
+
<tr>
|
|
1109
|
+
<td><a href="https://github.com/pytorch/rl/blob/main/examples/rlhf">RLHF</a>
|
|
1110
|
+
</td>
|
|
1111
|
+
<td> NA
|
|
1112
|
+
</td>
|
|
1113
|
+
<td> +
|
|
1114
|
+
</td>
|
|
1115
|
+
<td> NA
|
|
1116
|
+
</td>
|
|
1117
|
+
<td> NA
|
|
1118
|
+
</td>
|
|
1119
|
+
</tr>
|
|
1120
|
+
<tr>
|
|
1121
|
+
<td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/grpo">LLM API (GRPO)</a>
|
|
1122
|
+
</td>
|
|
1123
|
+
<td> NA
|
|
1124
|
+
</td>
|
|
1125
|
+
<td> +
|
|
1126
|
+
</td>
|
|
1127
|
+
<td> +
|
|
1128
|
+
</td>
|
|
1129
|
+
<td> NA
|
|
1130
|
+
</td>
|
|
1131
|
+
</tr>
|
|
1132
|
+
</table>
|
|
1133
|
+
|
|
1134
|
+
** The number indicates expected speed-up compared to eager mode when executed on CPU. Numbers may vary depending on
|
|
1135
|
+
architecture and device.
|
|
1136
|
+
|
|
1137
|
+
and many more to come!
|
|
1138
|
+
|
|
1139
|
+
[Code examples](examples/) displaying toy code snippets and training scripts are also available:
|
|
1140
|
+
- [LLM API & GRPO](sota-implementations/grpo) - Complete language model fine-tuning pipeline
|
|
1141
|
+
- [RLHF](examples/rlhf)
|
|
1142
|
+
- [Memory-mapped replay buffers](examples/torchrl_features)
|
|
1143
|
+
|
|
1144
|
+
|
|
1145
|
+
Check the [examples](https://github.com/pytorch/rl/blob/main/sota-implementations/) directory for more details
|
|
1146
|
+
about handling the various configuration settings.
|
|
1147
|
+
|
|
1148
|
+
We also provide [tutorials and demos](https://pytorch.org/rl/stable#tutorials) that give a sense of
|
|
1149
|
+
what the library can do.
|
|
1150
|
+
|
|
1151
|
+
## Citation
|
|
1152
|
+
|
|
1153
|
+
If you're using TorchRL, please refer to this BibTeX entry to cite this work:
|
|
1154
|
+
```
|
|
1155
|
+
@misc{bou2023torchrl,
|
|
1156
|
+
title={TorchRL: A data-driven decision-making library for PyTorch},
|
|
1157
|
+
author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
|
|
1158
|
+
year={2023},
|
|
1159
|
+
eprint={2306.00577},
|
|
1160
|
+
archivePrefix={arXiv},
|
|
1161
|
+
primaryClass={cs.LG}
|
|
1162
|
+
}
|
|
1163
|
+
```
|
|
1164
|
+
|
|
1165
|
+
## Installation
|
|
1166
|
+
|
|
1167
|
+
### Create a new virtual environment:
|
|
1168
|
+
```bash
|
|
1169
|
+
python -m venv torchrl
|
|
1170
|
+
source torchrl/bin/activate # On Windows use: torchrl\Scripts\activate
|
|
1171
|
+
```
|
|
1172
|
+
|
|
1173
|
+
Or create a conda environment where the packages will be installed.
|
|
1174
|
+
|
|
1175
|
+
```
|
|
1176
|
+
conda create --name torchrl python=3.10
|
|
1177
|
+
conda activate torchrl
|
|
1178
|
+
```
|
|
1179
|
+
|
|
1180
|
+
### Install dependencies:
|
|
1181
|
+
|
|
1182
|
+
#### PyTorch
|
|
1183
|
+
|
|
1184
|
+
Depending on the use of torchrl that you want to make, you may want to
|
|
1185
|
+
install the latest (nightly) PyTorch release or the latest stable version of PyTorch.
|
|
1186
|
+
See [here](https://pytorch.org/get-started/locally/) for a detailed list of commands,
|
|
1187
|
+
including `pip3` or other special installation instructions.
|
|
1188
|
+
|
|
1189
|
+
TorchRL offers a few pre-defined dependencies such as `"torchrl[tests]"`, `"torchrl[atari]"`, `"torchrl[utils]"` etc.
|
|
1190
|
+
|
|
1191
|
+
For the experimental training interface and configuration system, install:
|
|
1192
|
+
```bash
|
|
1193
|
+
pip3 install "torchrl[utils]" # Includes hydra-core and other utilities
|
|
1194
|
+
```
|
|
1195
|
+
|
|
1196
|
+
#### Torchrl
|
|
1197
|
+
|
|
1198
|
+
You can install the **latest stable release** by using
|
|
1199
|
+
```bash
|
|
1200
|
+
pip3 install torchrl
|
|
1201
|
+
```
|
|
1202
|
+
This should work on Linux (including AArch64 machines), Windows 10 and macOS (Metal chips only).
|
|
1203
|
+
On certain Windows machines (Windows 11), one should build the library locally.
|
|
1204
|
+
This can be done in two ways:
|
|
1205
|
+
|
|
1206
|
+
```bash
|
|
1207
|
+
# Install and build locally v0.8.1 of the library without cloning
|
|
1208
|
+
pip3 install git+https://github.com/pytorch/rl@v0.8.1
|
|
1209
|
+
# Clone the library and build it locally
|
|
1210
|
+
git clone https://github.com/pytorch/tensordict
|
|
1211
|
+
git clone https://github.com/pytorch/rl
|
|
1212
|
+
pip install -e tensordict
|
|
1213
|
+
pip install -e rl
|
|
1214
|
+
```
|
|
1215
|
+
|
|
1216
|
+
If you use `uv` (instead of `pip`) and you have already installed a specific PyTorch build (e.g. nightly),
|
|
1217
|
+
make sure `uv` doesn't re-resolve dependencies (which can downgrade PyTorch). Use `--no-deps` for the local installs:
|
|
1218
|
+
|
|
1219
|
+
```bash
|
|
1220
|
+
uv pip install --no-deps -e tensordict
|
|
1221
|
+
uv pip install --no-deps -e rl
|
|
1222
|
+
```
|
|
1223
|
+
|
|
1224
|
+
Note that tensordict local build requires `cmake` to be installed via [homebrew](https://brew.sh/) (MacOS) or another package manager
|
|
1225
|
+
such as `apt`, `apt-get`, `conda` or `yum` but NOT `pip`, as well as `pip install "pybind11[global]"`.
|
|
1226
|
+
|
|
1227
|
+
One can also build the wheels to distribute to co-workers using
|
|
1228
|
+
```bash
|
|
1229
|
+
pip install build
|
|
1230
|
+
python -m build --wheel
|
|
1231
|
+
```
|
|
1232
|
+
Your wheels will be stored in `./dist/torchrl<name>.whl` and installable via
|
|
1233
|
+
```bash
|
|
1234
|
+
pip install torchrl<name>.whl
|
|
1235
|
+
```
|
|
1236
|
+
|
|
1237
|
+
The **nightly build** can be installed via
|
|
1238
|
+
```bash
|
|
1239
|
+
pip3 install tensordict-nightly torchrl-nightly
|
|
1240
|
+
```
|
|
1241
|
+
which we currently only ship for Linux machines.
|
|
1242
|
+
Importantly, the nightly builds require the nightly builds of PyTorch too.
|
|
1243
|
+
Also, a local build of torchrl with the nightly build of tensordict may fail - install both nightlies or both local builds but do not mix them.
|
|
1244
|
+
|
|
1245
|
+
|
|
1246
|
+
**Disclaimer**: As of today, TorchRL requires Python 3.10+ and is roughly compatible with any pytorch version >= 2.1. Installing it will not
|
|
1247
|
+
directly require a newer version of pytorch to be installed. Indirectly though, tensordict still requires the latest
|
|
1248
|
+
PyTorch to be installed and we are working hard to loosen that requirement.
|
|
1249
|
+
The C++ binaries of TorchRL (mainly for prioritized replay buffers) will only work with PyTorch 2.7.0 and above.
|
|
1250
|
+
Some features (e.g., working with nested jagged tensors) may also
|
|
1251
|
+
be limited with older versions of pytorch. It is recommended to use the latest TorchRL with the latest PyTorch version
|
|
1252
|
+
unless there is a strong reason not to do so.
|
|
1253
|
+
|
|
1254
|
+
**Optional dependencies**
|
|
1255
|
+
|
|
1256
|
+
The following libraries can be installed depending on the usage one wants to
|
|
1257
|
+
make of torchrl:
|
|
1258
|
+
```
|
|
1259
|
+
# diverse
|
|
1260
|
+
pip3 install tqdm tensorboard "hydra-core>=1.1" hydra-submitit-launcher
|
|
1261
|
+
|
|
1262
|
+
# rendering
|
|
1263
|
+
pip3 install "moviepy<2.0.0"
|
|
1264
|
+
|
|
1265
|
+
# deepmind control suite
|
|
1266
|
+
pip3 install dm_control
|
|
1267
|
+
|
|
1268
|
+
# gym, atari games
|
|
1269
|
+
pip3 install "gym[atari]" "gym[accept-rom-license]" pygame
|
|
1270
|
+
|
|
1271
|
+
# tests
|
|
1272
|
+
pip3 install pytest pyyaml pytest-instafail
|
|
1273
|
+
|
|
1274
|
+
# tensorboard
|
|
1275
|
+
pip3 install tensorboard
|
|
1276
|
+
|
|
1277
|
+
# wandb
|
|
1278
|
+
pip3 install wandb
|
|
1279
|
+
```
|
|
1280
|
+
|
|
1281
|
+
Versioning issues can cause error messages of the type ```undefined symbol```
|
|
1282
|
+
and such. For these, refer to the [versioning issues document](https://github.com/pytorch/rl/blob/main/knowledge_base/VERSIONING_ISSUES.md)
|
|
1283
|
+
for a complete explanation and proposed workarounds.
|
|
1284
|
+
|
|
1285
|
+
## Asking a question
|
|
1286
|
+
|
|
1287
|
+
If you spot a bug in the library, please raise an issue in this repo.
|
|
1288
|
+
|
|
1289
|
+
If you have a more generic question regarding RL in PyTorch, post it on
|
|
1290
|
+
the [PyTorch forum](https://discuss.pytorch.org/c/reinforcement-learning/6).
|
|
1291
|
+
|
|
1292
|
+
## Contributing
|
|
1293
|
+
|
|
1294
|
+
Internal collaborations to torchrl are welcome! Feel free to fork, submit issues and PRs.
|
|
1295
|
+
You can check out the detailed contribution guide [here](https://github.com/pytorch/rl/blob/main/CONTRIBUTING.md).
|
|
1296
|
+
As mentioned above, a list of open contributions can be found [here](https://github.com/pytorch/rl/issues/509).
|
|
1297
|
+
|
|
1298
|
+
Contributors are recommended to install [pre-commit hooks](https://pre-commit.com/) (using `pre-commit install`). pre-commit will check for linting-related issues when the code is committed locally. You can disable the check by appending `-n` to your commit command: `git commit -m <commit message> -n`
|
|
1299
|
+
|
|
1300
|
+
|
|
1301
|
+
## Disclaimer
|
|
1302
|
+
|
|
1303
|
+
This library is released as a PyTorch beta feature.
|
|
1304
|
+
BC-breaking changes are likely to happen but they will be introduced with a deprecation
|
|
1305
|
+
warning after a few release cycles.
|
|
1306
|
+
|
|
1307
|
+
# License
|
|
1308
|
+
TorchRL is licensed under the MIT License. See [LICENSE](https://github.com/pytorch/rl/blob/main/LICENSE) for details.
|