torchrl 0.11.0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,1307 @@
1
+ Metadata-Version: 2.1
2
+ Name: torchrl
3
+ Version: 0.11.0
4
+ Summary: A modular, primitive-first, python-first PyTorch library for Reinforcement Learning
5
+ Author-email: torchrl contributors <vmoens@fb.com>
6
+ Maintainer-email: torchrl contributors <vmoens@fb.com>
7
+ Project-URL: Homepage, https://github.com/pytorch/rl
8
+ Project-URL: Documentation, https://pytorch.org/rl
9
+ Project-URL: Repository, https://github.com/pytorch/rl
10
+ Project-URL: Bug Tracker, https://github.com/pytorch/rl/issues
11
+ Project-URL: twitter, https://x.com/torchrl1
12
+ Project-URL: linkedin, https://www.linkedin.com/company/torchrl
13
+ Project-URL: discord, https://discord.gg/cZs26Qq3Dd
14
+ Project-URL: benchmark, https://docs.pytorch.org/rl/dev/bench/
15
+ Keywords: reinforcement-learning,pytorch,rl,machine-learning
16
+ Classifier: Programming Language :: Python :: 3.10
17
+ Classifier: Programming Language :: Python :: 3.11
18
+ Classifier: Programming Language :: Python :: 3.12
19
+ Classifier: Programming Language :: Python :: 3.13
20
+ Classifier: Programming Language :: Python :: 3.14
21
+ Classifier: Operating System :: OS Independent
22
+ Classifier: Development Status :: 4 - Beta
23
+ Classifier: Intended Audience :: Developers
24
+ Classifier: Intended Audience :: Science/Research
25
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
26
+ Requires-Python: >=3.10
27
+ Description-Content-Type: text/markdown
28
+ License-File: LICENSE
29
+ Requires-Dist: torch (>=2.1.0)
30
+ Requires-Dist: pyvers
31
+ Requires-Dist: numpy
32
+ Requires-Dist: packaging
33
+ Requires-Dist: cloudpickle
34
+ Requires-Dist: tensordict (<0.12.0,>=0.11.0)
35
+ Provides-Extra: atari
36
+ Requires-Dist: gymnasium[atari] ; extra == 'atari'
37
+ Provides-Extra: brax
38
+ Requires-Dist: jax (>=0.7.0) ; (python_version >= "3.11") and extra == 'brax'
39
+ Requires-Dist: brax ; (python_version >= "3.11") and extra == 'brax'
40
+ Provides-Extra: checkpointing
41
+ Requires-Dist: torchsnapshot ; extra == 'checkpointing'
42
+ Provides-Extra: dev
43
+ Requires-Dist: pre-commit ; extra == 'dev'
44
+ Requires-Dist: autoflake ; extra == 'dev'
45
+ Provides-Extra: dm_control
46
+ Requires-Dist: dm-control ; extra == 'dm_control'
47
+ Provides-Extra: grpo
48
+ Requires-Dist: datasets ; extra == 'grpo'
49
+ Requires-Dist: peft ; extra == 'grpo'
50
+ Requires-Dist: wandb ; extra == 'grpo'
51
+ Requires-Dist: vllm ; extra == 'grpo'
52
+ Requires-Dist: transformers ; extra == 'grpo'
53
+ Requires-Dist: accelerate ; extra == 'grpo'
54
+ Requires-Dist: ray ; extra == 'grpo'
55
+ Requires-Dist: tqdm ; extra == 'grpo'
56
+ Requires-Dist: flash-attn ; extra == 'grpo'
57
+ Requires-Dist: bitsandbytes ; extra == 'grpo'
58
+ Requires-Dist: xformers ; extra == 'grpo'
59
+ Requires-Dist: nltk ; extra == 'grpo'
60
+ Requires-Dist: langdetect ; extra == 'grpo'
61
+ Requires-Dist: immutabledict ; extra == 'grpo'
62
+ Provides-Extra: gym_continuous
63
+ Requires-Dist: gymnasium (<1.0) ; extra == 'gym_continuous'
64
+ Requires-Dist: mujoco ; extra == 'gym_continuous'
65
+ Provides-Extra: llm
66
+ Requires-Dist: transformers ; extra == 'llm'
67
+ Requires-Dist: vllm ; extra == 'llm'
68
+ Requires-Dist: playwright ; extra == 'llm'
69
+ Requires-Dist: datasets ; extra == 'llm'
70
+ Requires-Dist: langdetect ; extra == 'llm'
71
+ Requires-Dist: nltk ; extra == 'llm'
72
+ Requires-Dist: immutabledict ; extra == 'llm'
73
+ Requires-Dist: accelerate ; extra == 'llm'
74
+ Requires-Dist: sentencepiece ; extra == 'llm'
75
+ Requires-Dist: protobuf ; extra == 'llm'
76
+ Requires-Dist: einops ; extra == 'llm'
77
+ Requires-Dist: safetensors ; extra == 'llm'
78
+ Provides-Extra: marl
79
+ Requires-Dist: vmas (>=1.2.10) ; extra == 'marl'
80
+ Requires-Dist: pettingzoo (>=1.24.1) ; extra == 'marl'
81
+ Requires-Dist: dm-meltingpot ; (python_version >= "3.11") and extra == 'marl'
82
+ Provides-Extra: offline-data
83
+ Requires-Dist: huggingface-hub ; extra == 'offline-data'
84
+ Requires-Dist: minari ; extra == 'offline-data'
85
+ Requires-Dist: requests ; extra == 'offline-data'
86
+ Requires-Dist: tqdm ; extra == 'offline-data'
87
+ Requires-Dist: torchvision ; extra == 'offline-data'
88
+ Requires-Dist: scikit-learn ; extra == 'offline-data'
89
+ Requires-Dist: pandas ; extra == 'offline-data'
90
+ Requires-Dist: h5py ; extra == 'offline-data'
91
+ Requires-Dist: pillow ; extra == 'offline-data'
92
+ Provides-Extra: open_spiel
93
+ Requires-Dist: open-spiel (>=1.5) ; extra == 'open_spiel'
94
+ Provides-Extra: procgen
95
+ Requires-Dist: procgen ; extra == 'procgen'
96
+ Provides-Extra: rendering
97
+ Requires-Dist: moviepy (<2.0.0) ; extra == 'rendering'
98
+ Provides-Extra: replay_buffer
99
+ Requires-Dist: torch (>=2.7.0) ; extra == 'replay_buffer'
100
+ Provides-Extra: tests
101
+ Requires-Dist: pytest ; extra == 'tests'
102
+ Requires-Dist: pyyaml ; extra == 'tests'
103
+ Requires-Dist: pytest-instafail ; extra == 'tests'
104
+ Requires-Dist: scipy ; extra == 'tests'
105
+ Requires-Dist: psutil ; extra == 'tests'
106
+ Requires-Dist: pytest-mock ; extra == 'tests'
107
+ Requires-Dist: pytest-cov ; extra == 'tests'
108
+ Requires-Dist: pytest-asyncio ; extra == 'tests'
109
+ Requires-Dist: pytest-benchmark ; extra == 'tests'
110
+ Requires-Dist: pytest-rerunfailures ; extra == 'tests'
111
+ Requires-Dist: pytest-error-for-skips ; extra == 'tests'
112
+ Requires-Dist: pytest-timeout ; extra == 'tests'
113
+ Requires-Dist: pytest-forked ; extra == 'tests'
114
+ Requires-Dist: pytest-random-order ; extra == 'tests'
115
+ Requires-Dist: pytest-repeat ; extra == 'tests'
116
+ Requires-Dist: pytest-isolate ; extra == 'tests'
117
+ Provides-Extra: utils
118
+ Requires-Dist: tensorboard ; extra == 'utils'
119
+ Requires-Dist: wandb ; extra == 'utils'
120
+ Requires-Dist: tqdm ; extra == 'utils'
121
+ Requires-Dist: hydra-core (>=1.1) ; extra == 'utils'
122
+ Requires-Dist: hydra-submitit-launcher ; extra == 'utils'
123
+
124
+ [![Unit-tests](https://github.com/pytorch/rl/actions/workflows/test-linux.yml/badge.svg)](https://github.com/pytorch/rl/actions/workflows/test-linux.yml)
125
+ [![Documentation](https://img.shields.io/badge/Documentation-blue.svg)](https://pytorch.org/rl/)
126
+ [![Benchmarks](https://img.shields.io/badge/Benchmarks-blue.svg)](https://pytorch.github.io/rl/dev/bench/)
127
+ [![codecov](https://codecov.io/gh/pytorch/rl/branch/main/graph/badge.svg?token=HcpK1ILV6r)](https://codecov.io/gh/pytorch/rl)
128
+ [![Twitter Follow](https://img.shields.io/twitter/follow/torchrl1?style=social)](https://twitter.com/torchrl1)
129
+ [![Python version](https://img.shields.io/pypi/pyversions/torchrl.svg)](https://www.python.org/downloads/)
130
+ [![GitHub license](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/pytorch/rl/blob/main/LICENSE)
131
+ <a href="https://pypi.org/project/torchrl"><img src="https://img.shields.io/pypi/v/torchrl" alt="pypi version"></a>
132
+ <a href="https://pypi.org/project/torchrl-nightly"><img src="https://img.shields.io/pypi/v/torchrl-nightly?label=nightly" alt="pypi nightly version"></a>
133
+ [![Downloads](https://static.pepy.tech/personalized-badge/torchrl?period=total&units=international_system&left_color=blue&right_color=orange&left_text=Downloads)](https://pepy.tech/project/torchrl)
134
+ [![Downloads](https://static.pepy.tech/personalized-badge/torchrl-nightly?period=total&units=international_system&left_color=blue&right_color=orange&left_text=Downloads%20(nightly))](https://pepy.tech/project/torchrl-nightly)
135
+ [![Discord Shield](https://dcbadge.vercel.app/api/server/cZs26Qq3Dd)](https://discord.gg/cZs26Qq3Dd)
136
+
137
+ # TorchRL
138
+
139
+ <p align="center">
140
+ <img src="docs/source/_static/img/icon.png" width="200" >
141
+ </p>
142
+
143
+ [**What's New**](#-whats-new) | [**LLM API**](#llm-api---complete-framework-for-language-model-fine-tuning) | [**Getting Started**](#getting-started) | [**Documentation**](#documentation-and-knowledge-base) | [**TensorDict**](#writing-simplified-and-portable-rl-codebase-with-tensordict) |
144
+ [**Features**](#features) | [**Examples, tutorials and demos**](#examples-tutorials-and-demos) | [**Citation**](#citation) | [**Installation**](#installation) |
145
+ [**Asking a question**](#asking-a-question) | [**Contributing**](#contributing)
146
+
147
+ **TorchRL** is an open-source Reinforcement Learning (RL) library for PyTorch.
148
+
149
+ ## What's New
150
+
151
+ ### **Command-Line Training Interface** - Train RL Agents Without Writing Code! (Experimental)
152
+
153
+ TorchRL now provides a **powerful command-line interface** that lets you train state-of-the-art RL agents with simple bash commands! No Python scripting required - just run training with customizable parameters:
154
+
155
+ - **One-Command Training**: `python sota-implementations/ppo_trainer/train.py`
156
+ - **Full Customization**: Override any parameter via command line: `trainer.total_frames=2000000 optimizer.lr=0.0003`
157
+ - **Multi-Environment Support**: Switch between Gym, Brax, DM Control, and more with `env=gym training_env.create_env_fn.base_env.env_name=HalfCheetah-v4`
158
+ - **Built-in Logging**: TensorBoard, Weights & Biases, CSV logging out of the box
159
+ - **Hydra-Powered**: Leverages Hydra's powerful configuration system for maximum flexibility
160
+ - **Production Ready**: Same robust training pipeline as our SOTA implementations
161
+
162
+ **Perfect for**: Researchers, practitioners, and anyone who wants to train RL agents without diving into implementation details.
163
+
164
+ **Note**: This is an experimental feature. The API may change in future versions. We welcome feedback and contributions to help improve this implementation!
165
+
166
+ **Prerequisites**: The training interface requires Hydra for configuration management. Install with:
167
+ ```bash
168
+ pip install "torchrl[utils]"
169
+ # or manually:
170
+ pip install hydra-core omegaconf
171
+ ```
172
+
173
+ Check out the [complete CLI documentation](https://github.com/pytorch/rl/tree/main/sota-implementations/ppo_trainer) to get started!
174
+
175
+ ### **vLLM Revamp** - Major Enhancement to LLM Infrastructure (v0.10)
176
+
177
+ This release introduces a comprehensive revamp of TorchRL's vLLM integration, delivering significant improvements in performance, scalability, and usability for large language model inference and training workflows:
178
+
179
+ - **AsyncVLLM Service**: Production-ready distributed vLLM inference with multi-replica scaling and automatic Ray actor management
180
+ - **Multiple Load Balancing Strategies**: Routing strategies including prefix-aware, request-based, and KV-cache load balancing for optimal performance
181
+ - **Unified vLLM Architecture**: New `RLvLLMEngine` interface standardizing all vLLM backends with simplified `vLLMUpdaterV2` for seamless weight updates
182
+ - **Distributed Data Loading**: New `RayDataLoadingPrimer` for shared, distributed data loading across multiple environments
183
+ - **Enhanced Performance**: Native vLLM batching, concurrent request processing, and optimized resource allocation via Ray placement groups
184
+
185
+ ```python
186
+ # Simple AsyncVLLM usage - production ready!
187
+ from torchrl.modules.llm import AsyncVLLM, vLLMWrapper
188
+
189
+ # Create distributed vLLM service with load balancing
190
+ service = AsyncVLLM.from_pretrained(
191
+ "Qwen/Qwen2.5-7B",
192
+ num_devices=2, # Tensor parallel across 2 GPUs
193
+ num_replicas=4, # 4 replicas for high throughput
194
+ max_model_len=4096
195
+ )
196
+
197
+ # Use with TorchRL's LLM wrappers
198
+ wrapper = vLLMWrapper(service, input_mode="history")
199
+
200
+ # Simplified weight updates
201
+ from torchrl.collectors.llm import vLLMUpdaterV2
202
+ updater = vLLMUpdaterV2(service) # Auto-configures from engine
203
+ ```
204
+
205
+ This revamp positions TorchRL as the leading platform for scalable LLM inference and training, providing production-ready tools for both research and deployment scenarios.
206
+
207
+ ### PPOTrainer (Experimental) - High-Level Training Interface
208
+
209
+ TorchRL now includes an **experimental PPOTrainer** that provides a complete, configurable PPO training solution! This prototype feature combines TorchRL's modular components into a cohesive training system with sensible defaults:
210
+
211
+ - **Complete Training Pipeline**: Handles environment setup, data collection, loss computation, and optimization automatically
212
+ - **Extensive Configuration**: Comprehensive Hydra-based config system for easy experimentation and hyperparameter tuning
213
+ - **Built-in Logging**: Automatic tracking of rewards, actions, episode completion rates, and training statistics
214
+ - **Modular Design**: Built on existing TorchRL components (collectors, losses, replay buffers) for maximum flexibility
215
+ - **Minimal Code**: Complete SOTA implementation in [just ~20 lines](sota-implementations/ppo_trainer/train.py)!
216
+
217
+ **Working Example**: See [`sota-implementations/ppo_trainer/`](sota-implementations/ppo_trainer/) for a complete, working PPO implementation that trains on Pendulum-v1 with full Hydra configuration support.
218
+
219
+ **Prerequisites**: Requires Hydra for configuration management: `pip install "torchrl[utils]"`
220
+
221
+ <details>
222
+ <summary>Complete Training Script (sota-implementations/ppo_trainer/train.py)</summary>
223
+
224
+ ```python
225
+ import hydra
226
+ from torchrl.trainers.algorithms.configs import *
227
+
228
+ @hydra.main(config_path="config", config_name="config", version_base="1.1")
229
+ def main(cfg):
230
+ trainer = hydra.utils.instantiate(cfg.trainer)
231
+ trainer.train()
232
+
233
+ if __name__ == "__main__":
234
+ main()
235
+ ```
236
+ *Complete PPO training in ~20 lines with full configurability.*
237
+
238
+ </details>
239
+
240
+ <details>
241
+ <summary>API Usage Examples</summary>
242
+
243
+ ```bash
244
+ # Basic usage - train PPO on Pendulum-v1 with default settings
245
+ python sota-implementations/ppo_trainer/train.py
246
+
247
+ # Custom configuration with command-line overrides
248
+ python sota-implementations/ppo_trainer/train.py \
249
+ trainer.total_frames=2000000 \
250
+ training_env.create_env_fn.base_env.env_name=HalfCheetah-v4 \
251
+ networks.policy_network.num_cells=[256,256] \
252
+ optimizer.lr=0.0003
253
+
254
+ # Use different environment and logger
255
+ python sota-implementations/ppo_trainer/train.py \
256
+ env=gym \
257
+ training_env.create_env_fn.base_env.env_name=Walker2d-v4 \
258
+ logger=tensorboard
259
+
260
+ # See all available options
261
+ python sota-implementations/ppo_trainer/train.py --help
262
+ ```
263
+
264
+ </details>
265
+
266
+ **Future Plans**: Additional algorithm trainers (SAC, TD3, DQN) and full integration of all TorchRL components within the configuration system are planned for upcoming releases.
267
+
268
+ ## LLM API - Complete Framework for Language Model Fine-tuning
269
+
270
+ TorchRL includes a comprehensive **LLM API** for post-training and fine-tuning of language models! This framework provides everything you need for RLHF, supervised fine-tuning, and tool-augmented training:
271
+
272
+ - **Unified LLM Wrappers**: Seamless integration with Hugging Face models and vLLM inference engines
273
+ - **Conversation Management**: Advanced [`History`](torchrl/data/llm/history.py) class for multi-turn dialogue with automatic chat template detection
274
+ - **Tool Integration**: [Built-in support](torchrl/envs/llm/transforms/) for Python code execution, function calling, and custom tool transforms
275
+ - **Specialized Objectives**: [GRPO](torchrl/objectives/llm/grpo.py) (Group Relative Policy Optimization) and [SFT](torchrl/objectives/llm/sft.py) loss functions optimized for language models
276
+ - **High-Performance Collectors**: [Async data collection](torchrl/collectors/llm/) with distributed training support
277
+ - **Flexible Environments**: Transform-based architecture for reward computation, data loading, and conversation augmentation
278
+
279
+ The LLM API follows TorchRL's modular design principles, allowing you to mix and match components for your specific use case. Check out the [complete documentation](https://pytorch.org/rl/main/reference/llms.html) and [GRPO implementation example](https://github.com/pytorch/rl/tree/main/sota-implementations/grpo) to get started!
280
+
281
+ <details>
282
+ <summary>Quick LLM API Example</summary>
283
+
284
+ ```python
285
+ from torchrl.envs.llm import ChatEnv
286
+ from torchrl.modules.llm import TransformersWrapper
287
+ from torchrl.objectives.llm import GRPOLoss
288
+ from torchrl.collectors.llm import LLMCollector
289
+
290
+ # Create environment with Python tool execution
291
+ env = ChatEnv(
292
+ tokenizer=tokenizer,
293
+ system_prompt="You are an assistant that can execute Python code.",
294
+ batch_size=[1]
295
+ ).append_transform(PythonInterpreter())
296
+
297
+ # Wrap your language model
298
+ llm = TransformersWrapper(
299
+ model=model,
300
+ tokenizer=tokenizer,
301
+ input_mode="history"
302
+ )
303
+
304
+ # Set up GRPO training
305
+ loss_fn = GRPOLoss(llm, critic, gamma=0.99)
306
+ collector = LLMCollector(env, llm, frames_per_batch=100)
307
+
308
+ # Training loop
309
+ for data in collector:
310
+ loss = loss_fn(data)
311
+ loss.backward()
312
+ optimizer.step()
313
+ ```
314
+
315
+ </details>
316
+
317
+ ## Key features
318
+
319
+ - **Python-first**: Designed with Python as the primary language for ease of use and flexibility
320
+ - **Efficient**: Optimized for performance to support demanding RL research applications
321
+ - **Modular, customizable, extensible**: Highly modular architecture allows for easy swapping, transformation, or creation of new components
322
+ - **Documented**: Thorough documentation ensures that users can quickly understand and utilize the library
323
+ - **Tested**: Rigorously tested to ensure reliability and stability
324
+ - **Reusable functionals**: Provides a set of highly reusable functions for cost functions, returns, and data processing
325
+
326
+ ### Design Principles
327
+
328
+ - **Aligns with PyTorch ecosystem**: Follows the structure and conventions of popular PyTorch libraries
329
+ (e.g., dataset pillar, transforms, models, data utilities)
330
+ - **Minimal dependencies**: Only requires Python standard library, NumPy, and PyTorch; optional dependencies for
331
+ common environment libraries (e.g., OpenAI Gym) and datasets (D4RL, OpenX...)
332
+
333
+ Read the [full paper](https://arxiv.org/abs/2306.00577) for a more curated description of the library.
334
+
335
+ ## Getting started
336
+
337
+ Check our [Getting Started tutorials](https://pytorch.org/rl/stable/index.html#getting-started) to quickly ramp up with the basic
338
+ features of the library!
339
+
340
+ <p align="center">
341
+ <img src="docs/ppo.png" width="800" >
342
+ </p>
343
+
344
+ ## Documentation and knowledge base
345
+
346
+ The TorchRL documentation can be found [here](https://pytorch.org/rl).
347
+ It contains tutorials and the API reference.
348
+
349
+ TorchRL also provides a RL knowledge base to help you debug your code, or simply
350
+ learn the basics of RL. Check it out [here](https://pytorch.org/rl/stable/reference/knowledge_base.html).
351
+
352
+ We have some introductory videos for you to get to know the library better, check them out:
353
+
354
+ - [TalkRL podcast](https://www.talkrl.com/episodes/vincent-moens-on-torchrl)
355
+ - [TorchRL intro at PyTorch day 2022](https://youtu.be/cIKMhZoykEE)
356
+ - [PyTorch 2.0 Q&A: TorchRL](https://www.youtube.com/live/myEfUoYrbts?feature=share)
357
+
358
+ ## Spotlight publications
359
+
360
+ Because TorchRL is domain-agnostic, you can use it across many different fields. Here are a few examples:
361
+
362
+ - [ACEGEN](https://pubs.acs.org/doi/10.1021/acs.jcim.4c00895): Reinforcement Learning of Generative Chemical Agents
363
+ for Drug Discovery
364
+ - [BenchMARL](https://www.jmlr.org/papers/v25/23-1612.html): Benchmarking Multi-Agent Reinforcement Learning
365
+ - [BricksRL](https://arxiv.org/abs/2406.17490): A Platform for Democratizing Robotics and Reinforcement Learning
366
+ Research and Education with LEGO
367
+ - [OmniDrones](https://ieeexplore.ieee.org/abstract/document/10409589): An Efficient and Flexible Platform for Reinforcement Learning in Drone Control
368
+ - [RL4CO](https://arxiv.org/abs/2306.17100): an Extensive Reinforcement Learning for Combinatorial Optimization Benchmark
369
+ - [Robohive](https://proceedings.neurips.cc/paper_files/paper/2023/file/8a84a4341c375b8441b36836bb343d4e-Paper-Datasets_and_Benchmarks.pdf): A unified framework for robot learning
370
+
371
+ ## Writing simplified and portable RL codebase with `TensorDict`
372
+
373
+ RL algorithms are very heterogeneous, and it can be hard to recycle a codebase
374
+ across settings (e.g. from online to offline, from state-based to pixel-based
375
+ learning).
376
+ TorchRL solves this problem through [`TensorDict`](https://github.com/pytorch/tensordict/),
377
+ a convenient data structure<sup>(1)</sup> that can be used to streamline one's
378
+ RL codebase.
379
+ With this tool, one can write a *complete PPO training script in less than 100
380
+ lines of code*!
381
+
382
+ <details>
383
+ <summary>Code</summary>
384
+
385
+ ```python
386
+ import torch
387
+ from tensordict.nn import TensorDictModule
388
+ from tensordict.nn.distributions import NormalParamExtractor
389
+ from torch import nn
390
+
391
+ from torchrl.collectors import SyncDataCollector
392
+ from torchrl.data.replay_buffers import TensorDictReplayBuffer, \
393
+ LazyTensorStorage, SamplerWithoutReplacement
394
+ from torchrl.envs.libs.gym import GymEnv
395
+ from torchrl.modules import ProbabilisticActor, ValueOperator, TanhNormal
396
+ from torchrl.objectives import ClipPPOLoss
397
+ from torchrl.objectives.value import GAE
398
+
399
+ env = GymEnv("Pendulum-v1")
400
+ model = TensorDictModule(
401
+ nn.Sequential(
402
+ nn.Linear(3, 128), nn.Tanh(),
403
+ nn.Linear(128, 128), nn.Tanh(),
404
+ nn.Linear(128, 128), nn.Tanh(),
405
+ nn.Linear(128, 2),
406
+ NormalParamExtractor()
407
+ ),
408
+ in_keys=["observation"],
409
+ out_keys=["loc", "scale"]
410
+ )
411
+ critic = ValueOperator(
412
+ nn.Sequential(
413
+ nn.Linear(3, 128), nn.Tanh(),
414
+ nn.Linear(128, 128), nn.Tanh(),
415
+ nn.Linear(128, 128), nn.Tanh(),
416
+ nn.Linear(128, 1),
417
+ ),
418
+ in_keys=["observation"],
419
+ )
420
+ actor = ProbabilisticActor(
421
+ model,
422
+ in_keys=["loc", "scale"],
423
+ distribution_class=TanhNormal,
424
+ distribution_kwargs={"low": -1.0, "high": 1.0},
425
+ return_log_prob=True
426
+ )
427
+ buffer = TensorDictReplayBuffer(
428
+ storage=LazyTensorStorage(1000),
429
+ sampler=SamplerWithoutReplacement(),
430
+ batch_size=50,
431
+ )
432
+ collector = SyncDataCollector(
433
+ env,
434
+ actor,
435
+ frames_per_batch=1000,
436
+ total_frames=1_000_000,
437
+ )
438
+ loss_fn = ClipPPOLoss(actor, critic)
439
+ adv_fn = GAE(value_network=critic, average_gae=True, gamma=0.99, lmbda=0.95)
440
+ optim = torch.optim.Adam(loss_fn.parameters(), lr=2e-4)
441
+
442
+ for data in collector: # collect data
443
+ for epoch in range(10):
444
+ adv_fn(data) # compute advantage
445
+ buffer.extend(data)
446
+ for sample in buffer: # consume data
447
+ loss_vals = loss_fn(sample)
448
+ loss_val = sum(
449
+ value for key, value in loss_vals.items() if
450
+ key.startswith("loss")
451
+ )
452
+ loss_val.backward()
453
+ optim.step()
454
+ optim.zero_grad()
455
+ print(f"avg reward: {data['next', 'reward'].mean().item(): 4.4f}")
456
+ ```
457
+ </details>
458
+
459
+ Here is an example of how the [environment API](https://pytorch.org/rl/stable/reference/envs.html)
460
+ relies on tensordict to carry data from one function to another during a rollout
461
+ execution:
462
+ ![Rollout data flow with TensorDict](https://github.com/pytorch/rl/blob/main/docs/source/_static/img/rollout.gif)
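+
+ To make this data flow concrete, here is a minimal sketch (illustrative only; it assumes a Gym/Gymnasium backend is installed so that `GymEnv` can be built):
+
+ ```python
+ from torchrl.envs.libs.gym import GymEnv
+
+ env = GymEnv("Pendulum-v1")
+ # rollout returns a single TensorDict of shape [10] stacking all steps,
+ # with entries such as "observation", "action", ("next", "observation") and ("next", "reward")
+ rollout = env.rollout(max_steps=10)
+ print(rollout)
+ ```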
463
+
464
+ `TensorDict` makes it easy to re-use pieces of code across environments, models and
465
+ algorithms.
466
+ <details>
467
+ <summary>Code</summary>
468
+
469
+ For instance, here's how to code a rollout in TorchRL:
470
+
471
+ ```diff
472
+ - obs, done = env.reset()
473
+ + tensordict = env.reset()
474
+ policy = SafeModule(
475
+ model,
476
+ in_keys=["observation_pixels", "observation_vector"],
477
+ out_keys=["action"],
478
+ )
479
+ out = []
480
+ for i in range(n_steps):
481
+ - action, log_prob = policy(obs)
482
+ - next_obs, reward, done, info = env.step(action)
483
+ - out.append((obs, next_obs, action, log_prob, reward, done))
484
+ - obs = next_obs
485
+ + tensordict = policy(tensordict)
486
+ + tensordict = env.step(tensordict)
487
+ + out.append(tensordict)
488
+ + tensordict = step_mdp(tensordict) # renames next_observation_* keys to observation_*
489
+ - obs, next_obs, action, log_prob, reward, done = [torch.stack(vals, 0) for vals in zip(*out)]
490
+ + out = torch.stack(out, 0) # TensorDict supports multiple tensor operations
491
+ ```
492
+ </details>
493
+
494
+ Using this, TorchRL abstracts away the input / output signatures of the modules, env,
495
+ collectors, replay buffers and losses of the library, allowing all primitives
496
+ to be easily recycled across settings.
497
+
498
+ <details>
499
+ <summary>Code</summary>
500
+
501
+ Here's another example of an off-policy training loop in TorchRL (assuming
502
+ that a data collector, a replay buffer, a loss and an optimizer have been instantiated):
503
+
504
+ ```diff
505
+ - for i, (obs, next_obs, action, hidden_state, reward, done) in enumerate(collector):
506
+ + for i, tensordict in enumerate(collector):
507
+ - replay_buffer.add((obs, next_obs, action, log_prob, reward, done))
508
+ + replay_buffer.add(tensordict)
509
+ for j in range(num_optim_steps):
510
+ - obs, next_obs, action, hidden_state, reward, done = replay_buffer.sample(batch_size)
511
+ - loss = loss_fn(obs, next_obs, action, hidden_state, reward, done)
512
+ + tensordict = replay_buffer.sample(batch_size)
513
+ + loss = loss_fn(tensordict)
514
+ loss.backward()
515
+ optim.step()
516
+ optim.zero_grad()
517
+ ```
518
+ This training loop can be re-used across algorithms as it makes a minimal number of assumptions about the structure of the data.
519
+ </details>
520
+
521
+ TensorDict supports multiple tensor operations on its device and shape
522
+ (the shape of TensorDict, or its batch size, is the common arbitrary N first dimensions of all its contained tensors):
523
+
524
+ <details>
525
+ <summary>Code</summary>
526
+
527
+ ```python
528
+ # stack and cat
529
+ tensordict = torch.stack(list_of_tensordicts, 0)
530
+ tensordict = torch.cat(list_of_tensordicts, 0)
531
+ # reshape
532
+ tensordict = tensordict.view(-1)
533
+ tensordict = tensordict.permute(0, 2, 1)
534
+ tensordict = tensordict.unsqueeze(-1)
535
+ tensordict = tensordict.squeeze(-1)
536
+ # indexing
537
+ tensordict = tensordict[:2]
538
+ tensordict[:, 2] = sub_tensordict
539
+ # device and memory location
540
+ tensordict.cuda()
541
+ tensordict.to("cuda:1")
542
+ tensordict.share_memory_()
543
+ ```
544
+ </details>
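+
+ As a concrete illustration of the batch-size convention above, here is a minimal, self-contained sketch:
+
+ ```python
+ import torch
+ from tensordict import TensorDict
+
+ # batch_size is the leading shape shared by every entry; trailing dimensions may differ
+ td = TensorDict(
+     {"observation": torch.zeros(4, 20, 84, 84), "reward": torch.zeros(4, 20, 1)},
+     batch_size=[4, 20],
+ )
+ assert td.shape == torch.Size([4, 20])
+ ```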
545
+
546
+ TensorDict comes with a dedicated [`tensordict.nn`](https://pytorch.github.io/tensordict/reference/nn.html)
547
+ module that contains everything you might need to write your model with it.
548
+ And it is `functorch` and `torch.compile` compatible!
549
+
550
+ <details>
551
+ <summary>Code</summary>
552
+
553
+ ```diff
554
+ transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
555
+ + td_module = SafeModule(transformer_model, in_keys=["src", "tgt"], out_keys=["out"])
556
+ src = torch.rand((10, 32, 512))
557
+ tgt = torch.rand((20, 32, 512))
558
+ + tensordict = TensorDict({"src": src, "tgt": tgt}, batch_size=[20, 32])
559
+ - out = transformer_model(src, tgt)
560
+ + td_module(tensordict)
561
+ + out = tensordict["out"]
562
+ ```
563
+
564
+ The `TensorDictSequential` class allows you to branch sequences of `nn.Module` instances in a highly modular way.
565
+ For instance, here is an implementation of a transformer using the encoder and decoder blocks:
566
+ ```python
567
+ encoder_module = TransformerEncoder(...)
568
+ encoder = TensorDictSequential(encoder_module, in_keys=["src", "src_mask"], out_keys=["memory"])
569
+ decoder_module = TransformerDecoder(...)
570
+ decoder = TensorDictModule(decoder_module, in_keys=["tgt", "memory"], out_keys=["output"])
571
+ transformer = TensorDictSequential(encoder, decoder)
572
+ assert transformer.in_keys == ["src", "src_mask", "tgt"]
573
+ assert transformer.out_keys == ["memory", "output"]
574
+ ```
575
+
576
+ `TensorDictSequential` allows you to isolate subgraphs by querying a set of desired input / output keys:
577
+ ```python
578
+ transformer.select_subsequence(out_keys=["memory"]) # returns the encoder
579
+ transformer.select_subsequence(in_keys=["tgt", "memory"]) # returns the decoder
580
+ ```
581
+ </details>
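+
+ A minimal sketch of the `torch.compile` compatibility mentioned above (illustrative; it assumes a recent PyTorch with `torch.compile`, and any speed-up depends on the model):
+
+ ```python
+ import torch
+ from torch import nn
+ from tensordict import TensorDict
+ from tensordict.nn import TensorDictModule
+
+ module = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["hidden"])
+ compiled_module = torch.compile(module)  # same TensorDict-in / TensorDict-out interface
+ td = TensorDict({"obs": torch.randn(8, 3)}, batch_size=[8])
+ td = compiled_module(td)
+ assert "hidden" in td.keys()
+ ```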
582
+
583
+ Check [TensorDict tutorials](https://pytorch.github.io/tensordict/) to
584
+ learn more!
585
+
586
+
587
+ ## Features
588
+
589
+ - A common [interface for environments](https://github.com/pytorch/rl/blob/main/torchrl/envs)
590
+ which supports common libraries (OpenAI Gym, DeepMind Control, etc.)<sup>(1)</sup> and state-less execution
591
+ (e.g. Model-based environments).
592
+ The [batched environment](https://github.com/pytorch/rl/blob/main/torchrl/envs/batched_envs.py) containers allow parallel execution<sup>(2)</sup>.
593
+ A common, PyTorch-first [tensor specification class](https://github.com/pytorch/rl/blob/main/torchrl/data/tensor_specs.py) is also provided.
594
+ TorchRL's environments API is simple but stringent and specific. Check the
595
+ [documentation](https://pytorch.org/rl/stable/reference/envs.html)
596
+ and [tutorial](https://pytorch.org/rl/stable/tutorials/pendulum.html) to learn more!
597
+ <details>
598
+ <summary>Code</summary>
599
+
600
+ ```python
601
+ env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
602
+ env_parallel = ParallelEnv(4, env_make) # creates 4 envs in parallel
603
+ tensordict = env_parallel.rollout(max_steps=20, policy=None) # random rollout (no policy given)
604
+ assert tensordict.shape == [4, 20] # 4 envs, 20 steps rollout
605
+ env_parallel.action_spec.is_in(tensordict["action"]) # spec check returns True
606
+ ```
607
+ </details>
608
+
609
+ - multiprocess and distributed [data collectors](https://github.com/pytorch/rl/blob/main/torchrl/collectors/collectors.py)<sup>(2)</sup>
610
+ that work synchronously or asynchronously.
611
+ Through the use of TensorDict, TorchRL's training loops are made very similar
612
+ to regular training loops in supervised
613
+ learning (although the "dataloader" -- read data collector -- is modified on-the-fly):
614
+ <details>
615
+ <summary>Code</summary>
616
+
617
+ ```python
618
+ env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
619
+ collector = MultiaSyncDataCollector(
620
+ [env_make, env_make],
621
+ policy=policy,
622
+ devices=["cuda:0", "cuda:0"],
623
+ total_frames=10000,
624
+ frames_per_batch=50,
625
+ ...
626
+ )
627
+ for i, tensordict_data in enumerate(collector):
628
+ loss = loss_module(tensordict_data)
629
+ loss.backward()
630
+ optim.step()
631
+ optim.zero_grad()
632
+ collector.update_policy_weights_()
633
+ ```
634
+ </details>
635
+
636
+ Check our [distributed collector examples](https://github.com/pytorch/rl/blob/main/examples/distributed/collectors) to
637
+ learn more about ultra-fast data collection with TorchRL.
638
+
639
+ - efficient<sup>(2)</sup> and generic<sup>(1)</sup> [replay buffers](https://github.com/pytorch/rl/blob/main/torchrl/data/replay_buffers/replay_buffers.py) with modularized storage:
640
+ <details>
641
+ <summary>Code</summary>
642
+
643
+ ```python
644
+ storage = LazyMemmapStorage( # memory-mapped (physical) storage
645
+ cfg.buffer_size,
646
+ scratch_dir="/tmp/"
647
+ )
648
+ buffer = TensorDictPrioritizedReplayBuffer(
649
+ alpha=0.7,
650
+ beta=0.5,
651
+ collate_fn=lambda x: x,
652
+ pin_memory=device != torch.device("cpu"),
653
+ prefetch=10, # multi-threaded sampling
654
+ storage=storage
655
+ )
656
+ ```
657
+ </details>
658
+
659
+ Replay buffers are also offered as wrappers around common datasets for *offline RL*:
660
+ <details>
661
+ <summary>Code</summary>
662
+
663
+ ```python
664
+ from torchrl.data.replay_buffers import SamplerWithoutReplacement
665
+ from torchrl.data.datasets.d4rl import D4RLExperienceReplay
666
+ data = D4RLExperienceReplay(
667
+ "maze2d-open-v0",
668
+ split_trajs=True,
669
+ batch_size=128,
670
+ sampler=SamplerWithoutReplacement(drop_last=True),
671
+ )
672
+ for sample in data: # or alternatively sample = data.sample()
673
+ fun(sample)
674
+ ```
675
+ </details>
676
+
677
+
678
+ - cross-library [environment transforms](https://github.com/pytorch/rl/blob/main/torchrl/envs/transforms/transforms.py)<sup>(1)</sup>,
679
+ executed on device and in a vectorized fashion<sup>(2)</sup>,
680
+ which process and prepare the data coming out of the environments to be used by the agent:
681
+ <details>
682
+ <summary>Code</summary>
683
+
684
+ ```python
685
+ env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
686
+ env_base = ParallelEnv(4, env_make, device="cuda:0") # creates 4 envs in parallel
687
+ env = TransformedEnv(
688
+ env_base,
689
+ Compose(
690
+ ToTensorImage(),
691
+ ObservationNorm(loc=0.5, scale=1.0)), # executes the transforms once and on device
692
+ )
693
+ tensordict = env.reset()
694
+ assert tensordict.device == torch.device("cuda:0")
695
+ ```
696
+ Other transforms include: reward scaling (`RewardScaling`), shape operations (concatenation of tensors, unsqueezing etc.), concatenation of
697
+ successive frames (`CatFrames`), resizing (`Resize`) and many more.
698
+
699
+ Unlike other libraries, the transforms are stacked as a list (and not wrapped in each other), which makes it
700
+ easy to add and remove them at will:
701
+ ```python
702
+ env.insert_transform(0, NoopResetEnv()) # inserts the NoopResetEnv transform at the index 0
703
+ ```
704
+ Nevertheless, transforms can access and execute operations on the parent environment:
705
+ ```python
706
+ transform = env.transform[1] # gathers the second transform of the list
707
+ parent_env = transform.parent # returns the base environment of the second transform, i.e. the base env + the first transform
708
+ ```
709
+ </details>
710
+
711
+ - various tools for distributed learning (e.g. [memory mapped tensors](https://github.com/pytorch/tensordict/blob/main/tensordict/memmap.py))<sup>(2)</sup>;
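+   <details>
+     <summary>Code</summary>
+
+   A minimal sketch of memory-mapped storage with `TensorDict` (the exact API is documented in the tensordict package):
+
+   ```python
+   import torch
+   from tensordict import TensorDict
+
+   td = TensorDict({"observation": torch.zeros(1000, 84, 84)}, batch_size=[1000])
+   td.memmap_()  # moves the underlying storage to memory-mapped files that can be shared across processes
+   assert td.is_memmap()
+   ```
+   </details>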
712
+ - various [architectures](https://github.com/pytorch/rl/blob/main/torchrl/modules/models/) and models (e.g. [actor-critic](https://github.com/pytorch/rl/blob/main/torchrl/modules/tensordict_module/actors.py))<sup>(1)</sup>:
713
+ <details>
714
+ <summary>Code</summary>
715
+
716
+ ```python
717
+ # create an nn.Module
718
+ common_module = ConvNet(
719
+ bias_last_layer=True,
720
+ depth=None,
721
+ num_cells=[32, 64, 64],
722
+ kernel_sizes=[8, 4, 3],
723
+ strides=[4, 2, 1],
724
+ )
725
+ # Wrap it in a SafeModule, indicating what key to read in and where to
726
+ # write out the output
727
+ common_module = SafeModule(
728
+ common_module,
729
+ in_keys=["pixels"],
730
+ out_keys=["hidden"],
731
+ )
732
+ # Wrap the policy module in NormalParamsWrapper, such that the output
733
+ # tensor is split in loc and scale, and scale is mapped onto a positive space
734
+ policy_module = SafeModule(
735
+ NormalParamsWrapper(
736
+ MLP(num_cells=[64, 64], out_features=32, activation=nn.ELU)
737
+ ),
738
+ in_keys=["hidden"],
739
+ out_keys=["loc", "scale"],
740
+ )
741
+ # Use a SafeProbabilisticTensorDictSequential to combine the SafeModule with a
742
+ # SafeProbabilisticModule, indicating how to build the
743
+ # torch.distribution.Distribution object and what to do with it
744
+ policy_module = SafeProbabilisticTensorDictSequential( # stochastic policy
745
+ policy_module,
746
+ SafeProbabilisticModule(
747
+ in_keys=["loc", "scale"],
748
+ out_keys="action",
749
+ distribution_class=TanhNormal,
750
+ ),
751
+ )
752
+ value_module = MLP(
753
+ num_cells=[64, 64],
754
+ out_features=1,
755
+ activation=nn.ELU,
756
+ )
757
+ # Wrap the policy and value function in a common module
758
+ actor_value = ActorValueOperator(common_module, policy_module, value_module)
759
+ # standalone policy from this
760
+ standalone_policy = actor_value.get_policy_operator()
761
+ ```
762
+ </details>
763
+
764
+ - exploration [wrappers](https://github.com/pytorch/rl/blob/main/torchrl/modules/tensordict_module/exploration.py) and
765
+ [modules](https://github.com/pytorch/rl/blob/main/torchrl/modules/models/exploration.py) to easily swap between exploration and exploitation<sup>(1)</sup>:
766
+ <details>
767
+ <summary>Code</summary>
768
+
769
+ ```python
770
+ policy_explore = EGreedyWrapper(policy)
771
+ with set_exploration_type(ExplorationType.RANDOM):
772
+ tensordict = policy_explore(tensordict) # will use eps-greedy
773
+ with set_exploration_type(ExplorationType.DETERMINISTIC):
774
+ tensordict = policy_explore(tensordict) # will not use eps-greedy
775
+ ```
776
+ </details>
777
+
778
+ - A series of efficient [loss modules](https://github.com/pytorch/rl/tree/main/torchrl/objectives)
779
+ and highly vectorized
780
+ [functional return and advantage](https://github.com/pytorch/rl/blob/main/torchrl/objectives/value/functional.py)
781
+ computation.
782
+
783
+ <details>
784
+ <summary>Code</summary>
785
+
786
+ ### Loss modules
787
+ ```python
788
+ from torchrl.objectives import DQNLoss
789
+ loss_module = DQNLoss(value_network=value_network, gamma=0.99)
790
+ tensordict = replay_buffer.sample(batch_size)
791
+ loss = loss_module(tensordict)
792
+ ```
793
+
794
+ ### Advantage computation
795
+ ```python
796
+ from torchrl.objectives.value.functional import vec_td_lambda_return_estimate
797
+ advantage = vec_td_lambda_return_estimate(gamma, lmbda, next_state_value, reward, done, terminated)
798
+ ```
799
+
800
+ </details>
801
+
802
+ - a generic [trainer class](https://github.com/pytorch/rl/blob/main/torchrl/trainers/trainers.py)<sup>(1)</sup> that
803
+ executes the aforementioned training loop. Through a hooking mechanism,
804
+ it also supports any logging or data transformation operation at any given
805
+ time.
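+ <details>
+ <summary>Code</summary>
+
+ A minimal sketch of the trainer and its hooking mechanism (the `collector`, `loss_module` and `optimizer` are assumed to have been built as in the snippets above, and the exact constructor arguments may differ slightly across versions):
+ ```python
+ from torchrl.trainers import Trainer
+
+ trainer = Trainer(
+     collector=collector,        # any data collector
+     total_frames=1_000_000,     # total number of frames to collect
+     loss_module=loss_module,    # e.g. the DQNLoss built above
+     optimizer=optimizer,
+     optim_steps_per_batch=8,    # optimization steps per collected batch
+ )
+
+ # Hooks can be registered at predefined points of the training loop,
+ # e.g. to post-process each collected batch before the optimization steps.
+ def scale_reward(batch):
+     batch.set(("next", "reward"), batch.get(("next", "reward")) * 0.1)
+     return batch
+
+ trainer.register_op("batch_process", scale_reward)
+ trainer.train()
+ ```
+ </details>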
806
+
807
+ - various [recipes](https://github.com/pytorch/rl/blob/main/torchrl/trainers/helpers/models.py) to build models that
808
+ correspond to the environment being deployed.
809
+
810
+ - **LLM API**: Complete framework for language model fine-tuning with unified wrappers for Hugging Face and vLLM backends,
811
+ conversation management with automatic chat template detection, tool integration (Python execution, function calling),
812
+ specialized objectives (GRPO, SFT), and high-performance async collectors. Perfect for RLHF, supervised fine-tuning,
813
+ and tool-augmented training scenarios.
814
+ <details>
815
+ <summary>Code</summary>
816
+
817
+ ```python
818
+ from torchrl.envs.llm import ChatEnv
819
+ from torchrl.modules.llm import TransformersWrapper
820
+ from torchrl.envs.llm.transforms import PythonInterpreter
821
+
822
+ # Create environment with tool execution
823
+ env = ChatEnv(
824
+ tokenizer=tokenizer,
825
+ system_prompt="You can execute Python code.",
826
+ batch_size=[1]
827
+ ).append_transform(PythonInterpreter())
828
+
829
+ # Wrap language model for training
830
+ llm = TransformersWrapper(
831
+ model=model,
832
+ tokenizer=tokenizer,
833
+ input_mode="history"
834
+ )
835
+
836
+ # Multi-turn conversation with tool use
837
+ obs = env.reset(TensorDict({"query": "Calculate 2+2"}, batch_size=[1]))
838
+ llm_output = llm(obs) # Generates response
839
+ obs = env.step(llm_output) # Environment processes response
840
+ ```
841
+ </details>
842
+
843
+ If you feel a feature is missing from the library, please submit an issue!
844
+ If you would like to contribute to new features, check our [call for contributions](https://github.com/pytorch/rl/issues/509) and our [contribution](https://github.com/pytorch/rl/blob/main/CONTRIBUTING.md) page.
845
+
846
+
847
+ ## Examples, tutorials and demos
848
+
849
+ A series of [State-of-the-Art implementations](https://github.com/pytorch/rl/blob/main/sota-implementations/) is provided for illustrative purposes:
850
+
851
+ <table>
852
+ <tr>
853
+ <td><strong>Algorithm</strong>
854
+ </td>
855
+ <td><strong>Compile Support**</strong>
856
+ </td>
857
+ <td><strong>Tensordict-free API</strong>
858
+ </td>
859
+ <td><strong>Modular Losses</strong>
860
+ </td>
861
+ <td><strong>Continuous and Discrete</strong>
862
+ </td>
863
+ </tr>
864
+ <tr>
865
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/dqn">DQN</a>
866
+ </td>
867
+ <td> 1.9x
868
+ </td>
869
+ <td> +
870
+ </td>
871
+ <td> NA
872
+ </td>
873
+ <td> + (through <a href="https://pytorch.org/rl/stable/reference/generated/torchrl.envs.transforms.ActionDiscretizer.html?highlight=actiondiscretizer">ActionDiscretizer</a> transform)
874
+ </td>
875
+ </tr>
876
+ <tr>
877
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/ddpg/ddpg.py">DDPG</a>
878
+ </td>
879
+ <td> 1.87x
880
+ </td>
881
+ <td> +
882
+ </td>
883
+ <td> +
884
+ </td>
885
+ <td> - (continuous only)
886
+ </td>
887
+ </tr>
888
+ <tr>
889
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/iql/">IQL</a>
890
+ </td>
891
+ <td> 3.22x
892
+ </td>
893
+ <td> +
894
+ </td>
895
+ <td> +
896
+ </td>
897
+ <td> +
898
+ </td>
899
+ </tr>
900
+ <tr>
901
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/cql/cql_offline.py">CQL</a>
902
+ </td>
903
+ <td> 2.68x
904
+ </td>
905
+ <td> +
906
+ </td>
907
+ <td> +
908
+ </td>
909
+ <td> +
910
+ </td>
911
+ </tr>
912
+ <tr>
913
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/td3/td3.py">TD3</a>
914
+ </td>
915
+ <td> 2.27x
916
+ </td>
917
+ <td> +
918
+ </td>
919
+ <td> +
920
+ </td>
921
+ <td> - (continuous only)
922
+ </td>
923
+ </tr>
924
+ <tr>
925
+ <td>
926
+ <a href="https://github.com/pytorch/rl/blob/main/sota-implementations/td3_bc/td3_bc.py">TD3+BC</a>
927
+ </td>
928
+ <td> untested
929
+ </td>
930
+ <td> +
931
+ </td>
932
+ <td> +
933
+ </td>
934
+ <td> - (continuous only)
935
+ </td>
936
+ </tr>
937
+ <tr>
938
+ <td>
939
+ <a href="https://github.com/pytorch/rl/blob/main/examples/a2c/">A2C</a>
940
+ </td>
941
+ <td> 2.67x
942
+ </td>
943
+ <td> +
944
+ </td>
945
+ <td> -
946
+ </td>
947
+ <td> +
948
+ </td>
949
+ </tr>
950
+ <tr>
951
+ <td>
952
+ <a href="https://github.com/pytorch/rl/blob/main/sota-implementations/ppo/">PPO</a>
953
+ </td>
954
+ <td> 2.42x
955
+ </td>
956
+ <td> +
957
+ </td>
958
+ <td> -
959
+ </td>
960
+ <td> +
961
+ </td>
962
+ </tr>
963
+ <tr>
964
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/sac/sac.py">SAC</a>
965
+ </td>
966
+ <td> 2.62x
967
+ </td>
968
+ <td> +
969
+ </td>
970
+ <td> -
971
+ </td>
972
+ <td> +
973
+ </td>
974
+ </tr>
975
+ <tr>
976
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/redq/redq.py">REDQ</a>
977
+ </td>
978
+ <td> 2.28x
979
+ </td>
980
+ <td> +
981
+ </td>
982
+ <td> -
983
+ </td>
984
+ <td> - (continuous only)
985
+ </td>
986
+ </tr>
987
+ <tr>
988
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/dreamer/dreamer.py">Dreamer v1</a>
989
+ </td>
990
+ <td> untested
991
+ </td>
992
+ <td> +
993
+ </td>
994
+ <td> + (<a href="https://pytorch.org/rl/stable/reference/objectives.html#dreamer">different classes</a>)
995
+ </td>
996
+ <td> - (continuous only)
997
+ </td>
998
+ </tr>
999
+ <tr>
1000
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/decision_transformer">Decision Transformers</a>
1001
+ </td>
1002
+ <td> untested
1003
+ </td>
1004
+ <td> +
1005
+ </td>
1006
+ <td> NA
1007
+ </td>
1008
+ <td> - (continuous only)
1009
+ </td>
1010
+ </tr>
1011
+ <tr>
1012
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/crossq">CrossQ</a>
1013
+ </td>
1014
+ <td> untested
1015
+ </td>
1016
+ <td> +
1017
+ </td>
1018
+ <td> +
1019
+ </td>
1020
+ <td> - (continuous only)
1021
+ </td>
1022
+ </tr>
1023
+ <tr>
1024
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/gail">Gail</a>
1025
+ </td>
1026
+ <td> untested
1027
+ </td>
1028
+ <td> +
1029
+ </td>
1030
+ <td> NA
1031
+ </td>
1032
+ <td> +
1033
+ </td>
1034
+ </tr>
1035
+ <tr>
1036
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/impala">Impala</a>
1037
+ </td>
1038
+ <td> untested
1039
+ </td>
1040
+ <td> +
1041
+ </td>
1042
+ <td> -
1043
+ </td>
1044
+ <td> +
1045
+ </td>
1046
+ </tr>
1047
+ <tr>
1048
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/multiagent/iql.py">IQL (MARL)</a>
1049
+ </td>
1050
+ <td> untested
1051
+ </td>
1052
+ <td> +
1053
+ </td>
1054
+ <td> +
1055
+ </td>
1056
+ <td> +
1057
+ </td>
1058
+ </tr>
1059
+ <tr>
1060
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/multiagent/maddpg_iddpg.py">DDPG (MARL)</a>
1061
+ </td>
1062
+ <td> untested
1063
+ </td>
1064
+ <td> +
1065
+ </td>
1066
+ <td> +
1067
+ </td>
1068
+ <td> - (continuous only)
1069
+ </td>
1070
+ </tr>
1071
+ <tr>
1072
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/multiagent/mappo_ippo.py">PPO (MARL)</a>
1073
+ </td>
1074
+ <td> untested
1075
+ </td>
1076
+ <td> +
1077
+ </td>
1078
+ <td> -
1079
+ </td>
1080
+ <td> +
1081
+ </td>
1082
+ </tr>
1083
+ <tr>
1084
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/multiagent/qmix_vdn.py">QMIX-VDN (MARL)</a>
1085
+ </td>
1086
+ <td> untested
1087
+ </td>
1088
+ <td> +
1089
+ </td>
1090
+ <td> NA
1091
+ </td>
1092
+ <td> +
1093
+ </td>
1094
+ </tr>
1095
+ <tr>
1096
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/multiagent/sac.py">SAC (MARL)</a>
1097
+ </td>
1098
+ <td> untested
1099
+ </td>
1100
+ <td> +
1101
+ </td>
1102
+ <td> -
1103
+ </td>
1104
+ <td> +
1105
+ </td>
1106
+ </tr>
1107
+ <tr>
1108
+ <td><a href="https://github.com/pytorch/rl/blob/main/examples/rlhf">RLHF</a>
1109
+ </td>
1110
+ <td> NA
1111
+ </td>
1112
+ <td> +
1113
+ </td>
1114
+ <td> NA
1115
+ </td>
1116
+ <td> NA
1117
+ </td>
1118
+ </tr>
1119
+ <tr>
1120
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/grpo">LLM API (GRPO)</a>
1121
+ </td>
1122
+ <td> NA
1123
+ </td>
1124
+ <td> +
1125
+ </td>
1126
+ <td> +
1127
+ </td>
1128
+ <td> NA
1129
+ </td>
1130
+ </tr>
1131
+ </table>
1132
+
1133
+ ** The number indicates expected speed-up compared to eager mode when executed on CPU. Numbers may vary depending on
1134
+ architecture and device.
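+
+ As a rough illustration of what compile support means in practice, a policy or loss module can be compiled ahead of the training loop (a sketch only, reusing the `loss_module`, `replay_buffer` and `batch_size` names from the snippets above; actual speed-ups depend on the model, the batch size and the device):
+ ```python
+ import torch
+
+ loss_module = torch.compile(loss_module)               # compile once before training
+ loss = loss_module(replay_buffer.sample(batch_size))   # later calls reuse the compiled graph
+ ```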
1135
+
1136
+ and many more to come!
1137
+
1138
+ [Code examples](examples/) displaying toy code snippets and training scripts are also available:
1139
+ - [LLM API & GRPO](sota-implementations/grpo) - Complete language model fine-tuning pipeline
1140
+ - [RLHF](examples/rlhf)
1141
+ - [Memory-mapped replay buffers](examples/torchrl_features)
1142
+
1143
+
1144
+ Check the [examples](https://github.com/pytorch/rl/blob/main/sota-implementations/) directory for more details
1145
+ about handling the various configuration settings.
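+
+ Most of these scripts are driven by hydra configuration files, so settings can typically be overridden from the command line. The override keys below are only illustrative; check each script's config for the exact names:
+ ```bash
+ # hypothetical overrides; key names vary from script to script
+ python sota-implementations/dqn/dqn_atari.py collector.total_frames=500000 optim.lr=2.5e-4
+ ```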
1146
+
1147
+ We also provide [tutorials and demos](https://pytorch.org/rl/stable#tutorials) that give a sense of
1148
+ what the library can do.
1149
+
1150
+ ## Citation
1151
+
1152
+ If you're using TorchRL, please refer to this BibTeX entry to cite this work:
1153
+ ```
1154
+ @misc{bou2023torchrl,
1155
+ title={TorchRL: A data-driven decision-making library for PyTorch},
1156
+ author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
1157
+ year={2023},
1158
+ eprint={2306.00577},
1159
+ archivePrefix={arXiv},
1160
+ primaryClass={cs.LG}
1161
+ }
1162
+ ```
1163
+
1164
+ ## Installation
1165
+
1166
+ ### Create a new virtual environment:
1167
+ ```bash
1168
+ python -m venv torchrl
1169
+ source torchrl/bin/activate # On Windows use: torchrl\Scripts\activate
1170
+ ```
1171
+
1172
+ Or create a conda environment where the packages will be installed.
1173
+
1174
+ ```bash
1175
+ conda create --name torchrl python=3.10
1176
+ conda activate torchrl
1177
+ ```
1178
+
1179
+ ### Install dependencies:
1180
+
1181
+ #### PyTorch
1182
+
1183
+ Depending on how you plan to use torchrl, you may want to
1184
+ install the latest (nightly) PyTorch release or the latest stable version of PyTorch.
1185
+ See [here](https://pytorch.org/get-started/locally/) for a detailed list of commands,
1186
+ including `pip3` or other special installation instructions.
1187
+
1188
+ TorchRL offers a few pre-defined optional dependency sets (extras) such as `"torchrl[tests]"`, `"torchrl[atari]"`, `"torchrl[utils]"` etc.
1189
+
1190
+ For the experimental training interface and configuration system, install:
1191
+ ```bash
1192
+ pip3 install "torchrl[utils]" # Includes hydra-core and other utilities
1193
+ ```
1194
+
1195
+ #### TorchRL
1196
+
1197
+ You can install the **latest stable release** by using
1198
+ ```bash
1199
+ pip3 install torchrl
1200
+ ```
1201
+ This should work on Linux (including AArch64 machines), Windows 10 and macOS (Apple Silicon only).
1202
+ On certain Windows machines (Windows 11), one should build the library locally.
1203
+ This can be done in two ways:
1204
+
1205
+ ```bash
1206
+ # Install and build v0.8.1 of the library locally, without cloning
1207
+ pip3 install git+https://github.com/pytorch/rl@v0.8.1
1208
+ # Clone the library and build it locally
1209
+ git clone https://github.com/pytorch/tensordict
1210
+ git clone https://github.com/pytorch/rl
1211
+ pip install -e tensordict
1212
+ pip install -e rl
1213
+ ```
1214
+
1215
+ If you use `uv` (instead of `pip`) and you have already installed a specific PyTorch build (e.g. nightly),
1216
+ make sure `uv` doesn't re-resolve dependencies (which can downgrade PyTorch). Use `--no-deps` for the local installs:
1217
+
1218
+ ```bash
1219
+ uv pip install --no-deps -e tensordict
1220
+ uv pip install --no-deps -e rl
1221
+ ```
1222
+
1223
+ Note that a local build of tensordict requires `cmake`, installed via [homebrew](https://brew.sh/) (macOS) or another package manager
1224
+ such as `apt`, `apt-get`, `conda` or `yum` (but NOT `pip`), as well as `pip install "pybind11[global]"`.
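+
+ For instance, on a macOS machine with homebrew available this typically amounts to:
+ ```bash
+ brew install cmake
+ pip install "pybind11[global]"
+ ```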
1225
+
1226
+ One can also build the wheels to distribute to co-workers using
1227
+ ```bash
1228
+ pip install build
1229
+ python -m build --wheel
1230
+ ```
1231
+ Your wheels will be stored in `./dist/torchrl<name>.whl` and are installable via
1232
+ ```bash
1233
+ pip install torchrl<name>.whl
1234
+ ```
1235
+
1236
+ The **nightly build** can be installed via
1237
+ ```bash
1238
+ pip3 install tensordict-nightly torchrl-nightly
1239
+ ```
1240
+ which we currently only ship for Linux machines.
1241
+ Importantly, the nightly builds require the nightly builds of PyTorch too.
1242
+ Also, a local build of torchrl with the nightly build of tensordict may fail: install both nightlies or both local builds, but do not mix them.
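+
+ For reference, the PyTorch nightly can typically be installed with the command below (CPU build shown; adapt the index URL to your CUDA version):
+ ```bash
+ pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
+ ```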
1243
+
1244
+
1245
+ **Disclaimer**: As of today, TorchRL requires Python 3.10+ and is roughly compatible with any PyTorch version >= 2.1. Installing it will not
1246
+ directly require a newer version of PyTorch to be installed. Indirectly though, tensordict still requires the latest
1247
+ PyTorch to be installed and we are working hard to loosen that requirement.
1248
+ The C++ binaries of TorchRL (mainly for prioritized replay buffers) will only work with PyTorch 2.7.0 and above.
1249
+ Some features (e.g., working with nested jagged tensors) may also
1250
+ be limited with older versions of PyTorch. It is recommended to use the latest TorchRL with the latest PyTorch version
1251
+ unless there is a strong reason not to do so.
1252
+
1253
+ **Optional dependencies**
1254
+
1255
+ The following libraries can be installed depending on how you intend to use
1256
+ torchrl:
1257
+ ```bash
1258
+ # diverse
1259
+ pip3 install tqdm tensorboard "hydra-core>=1.1" hydra-submitit-launcher
1260
+
1261
+ # rendering
1262
+ pip3 install "moviepy<2.0.0"
1263
+
1264
+ # deepmind control suite
1265
+ pip3 install dm_control
1266
+
1267
+ # gym, atari games
1268
+ pip3 install "gym[atari]" "gym[accept-rom-license]" pygame
1269
+
1270
+ # tests
1271
+ pip3 install pytest pyyaml pytest-instafail
1272
+
1273
+ # tensorboard
1274
+ pip3 install tensorboard
1275
+
1276
+ # wandb
1277
+ pip3 install wandb
1278
+ ```
1279
+
1280
+ Versioning issues can cause error messages of the type ```undefined symbol```
1281
+ and similar. For these, refer to the [versioning issues document](https://github.com/pytorch/rl/blob/main/knowledge_base/VERSIONING_ISSUES.md)
1282
+ for a complete explanation and proposed workarounds.
1283
+
1284
+ ## Asking a question
1285
+
1286
+ If you spot a bug in the library, please raise an issue in this repo.
1287
+
1288
+ If you have a more generic question regarding RL in PyTorch, post it on
1289
+ the [PyTorch forum](https://discuss.pytorch.org/c/reinforcement-learning/6).
1290
+
1291
+ ## Contributing
1292
+
1293
+ Contributions to torchrl are welcome! Feel free to fork, submit issues and PRs.
1294
+ You can check out the detailed contribution guide [here](https://github.com/pytorch/rl/blob/main/CONTRIBUTING.md).
1295
+ As mentioned above, a list of open contributions can be found [here](https://github.com/pytorch/rl/issues/509).
1296
+
1297
+ Contributors are recommended to install [pre-commit hooks](https://pre-commit.com/) (using `pre-commit install`). pre-commit will check for linting-related issues when the code is committed locally. You can disable this check by appending `-n` to your commit command: `git commit -m <commit message> -n`.
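+
+ For instance:
+ ```bash
+ pip install pre-commit
+ pre-commit install               # run once inside the cloned repo
+ git commit -m "my change" -n     # -n skips the hooks for a single commit
+ ```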
1298
+
1299
+
1300
+ ## Disclaimer
1301
+
1302
+ This library is released as a PyTorch beta feature.
1303
+ BC-breaking changes are likely to happen but they will be introduced with a deprecation
1304
+ warning after a few release cycles.
1305
+
1306
+ ## License
1307
+ TorchRL is licensed under the MIT License. See [LICENSE](https://github.com/pytorch/rl/blob/main/LICENSE) for details.