torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl-0.11.0.dist-info/METADATA
@@ -0,0 +1,1308 @@
Metadata-Version: 2.4
Name: torchrl
Version: 0.11.0
Summary: A modular, primitive-first, python-first PyTorch library for Reinforcement Learning
Author-email: torchrl contributors <vmoens@fb.com>
Maintainer-email: torchrl contributors <vmoens@fb.com>
Project-URL: Homepage, https://github.com/pytorch/rl
Project-URL: Documentation, https://pytorch.org/rl
Project-URL: Repository, https://github.com/pytorch/rl
Project-URL: Bug Tracker, https://github.com/pytorch/rl/issues
Project-URL: twitter, https://x.com/torchrl1
Project-URL: linkedin, https://www.linkedin.com/company/torchrl
Project-URL: discord, https://discord.gg/cZs26Qq3Dd
Project-URL: benchmark, https://docs.pytorch.org/rl/dev/bench/
Keywords: reinforcement-learning,pytorch,rl,machine-learning
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Programming Language :: Python :: 3.14
Classifier: Operating System :: OS Independent
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.1.0
Requires-Dist: pyvers
Requires-Dist: numpy
Requires-Dist: packaging
Requires-Dist: cloudpickle
Requires-Dist: tensordict<0.12.0,>=0.11.0
Provides-Extra: atari
Requires-Dist: gymnasium[atari]; extra == "atari"
Provides-Extra: dm-control
Requires-Dist: dm_control; extra == "dm-control"
Provides-Extra: replay-buffer
Requires-Dist: torch>=2.7.0; extra == "replay-buffer"
Provides-Extra: gym-continuous
Requires-Dist: gymnasium<1.0; extra == "gym-continuous"
Requires-Dist: mujoco; extra == "gym-continuous"
Provides-Extra: rendering
Requires-Dist: moviepy<2.0.0; extra == "rendering"
Provides-Extra: tests
Requires-Dist: pytest; extra == "tests"
Requires-Dist: pyyaml; extra == "tests"
Requires-Dist: pytest-instafail; extra == "tests"
Requires-Dist: scipy; extra == "tests"
Requires-Dist: psutil; extra == "tests"
Requires-Dist: pytest-mock; extra == "tests"
Requires-Dist: pytest-cov; extra == "tests"
Requires-Dist: pytest-asyncio; extra == "tests"
Requires-Dist: pytest-benchmark; extra == "tests"
Requires-Dist: pytest-rerunfailures; extra == "tests"
Requires-Dist: pytest-error-for-skips; extra == "tests"
Requires-Dist: pytest-timeout; extra == "tests"
Requires-Dist: pytest-forked; extra == "tests"
Requires-Dist: pytest-random-order; extra == "tests"
Requires-Dist: pytest-repeat; extra == "tests"
Requires-Dist: pytest-isolate; extra == "tests"
Provides-Extra: utils
Requires-Dist: tensorboard; extra == "utils"
Requires-Dist: wandb; extra == "utils"
Requires-Dist: tqdm; extra == "utils"
Requires-Dist: hydra-core>=1.1; extra == "utils"
Requires-Dist: hydra-submitit-launcher; extra == "utils"
Provides-Extra: checkpointing
Requires-Dist: torchsnapshot; extra == "checkpointing"
Provides-Extra: offline-data
Requires-Dist: huggingface_hub; extra == "offline-data"
Requires-Dist: minari; extra == "offline-data"
Requires-Dist: requests; extra == "offline-data"
Requires-Dist: tqdm; extra == "offline-data"
Requires-Dist: torchvision; extra == "offline-data"
Requires-Dist: scikit-learn; extra == "offline-data"
Requires-Dist: pandas; extra == "offline-data"
Requires-Dist: h5py; extra == "offline-data"
Requires-Dist: pillow; extra == "offline-data"
Provides-Extra: marl
Requires-Dist: vmas>=1.2.10; extra == "marl"
Requires-Dist: pettingzoo>=1.24.1; extra == "marl"
Requires-Dist: dm-meltingpot; python_version >= "3.11" and extra == "marl"
Provides-Extra: open-spiel
Requires-Dist: open_spiel>=1.5; extra == "open-spiel"
Provides-Extra: brax
Requires-Dist: jax>=0.7.0; python_version >= "3.11" and extra == "brax"
Requires-Dist: brax; python_version >= "3.11" and extra == "brax"
Provides-Extra: procgen
Requires-Dist: procgen; extra == "procgen"
Provides-Extra: llm
Requires-Dist: transformers; extra == "llm"
Requires-Dist: vllm; extra == "llm"
Requires-Dist: playwright; extra == "llm"
Requires-Dist: datasets; extra == "llm"
Requires-Dist: langdetect; extra == "llm"
Requires-Dist: nltk; extra == "llm"
Requires-Dist: immutabledict; extra == "llm"
Requires-Dist: accelerate; extra == "llm"
Requires-Dist: sentencepiece; extra == "llm"
Requires-Dist: protobuf; extra == "llm"
Requires-Dist: einops; extra == "llm"
Requires-Dist: safetensors; extra == "llm"
Provides-Extra: grpo
Requires-Dist: datasets; extra == "grpo"
Requires-Dist: peft; extra == "grpo"
Requires-Dist: wandb; extra == "grpo"
Requires-Dist: vllm; extra == "grpo"
Requires-Dist: transformers; extra == "grpo"
Requires-Dist: accelerate; extra == "grpo"
Requires-Dist: ray; extra == "grpo"
Requires-Dist: tqdm; extra == "grpo"
Requires-Dist: flash-attn; extra == "grpo"
Requires-Dist: bitsandbytes; extra == "grpo"
Requires-Dist: xformers; extra == "grpo"
Requires-Dist: nltk; extra == "grpo"
Requires-Dist: langdetect; extra == "grpo"
Requires-Dist: immutabledict; extra == "grpo"
Provides-Extra: dev
Requires-Dist: pre-commit; extra == "dev"
Requires-Dist: autoflake; extra == "dev"
Dynamic: license-file

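The extras declared above (`llm`, `grpo`, `marl`, etc.) gate torchrl's optional dependencies. As a quick illustration (not part of the package itself), the declared requirements and their extra markers can be inspected at runtime with the standard library; the helper name below is hypothetical:

```python
# Hypothetical helper: list the dependencies torchrl declares for a given extra,
# using only the Python standard library (importlib.metadata).
from importlib.metadata import requires

def deps_for_extra(dist: str, extra: str) -> list[str]:
    # requires() returns strings such as 'vllm; extra == "llm"'
    return [
        req.split(";")[0].strip()
        for req in (requires(dist) or [])
        if f'extra == "{extra}"' in req
    ]

if __name__ == "__main__":
    print(deps_for_extra("torchrl", "llm"))
```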
[![Unit-tests](https://github.com/pytorch/rl/actions/workflows/test-linux.yml/badge.svg)](https://github.com/pytorch/rl/actions/workflows/test-linux.yml)
[![Documentation](https://img.shields.io/badge/Documentation-blue.svg)](https://pytorch.org/rl/)
[![Benchmarks](https://img.shields.io/badge/Benchmarks-blue.svg)](https://pytorch.github.io/rl/dev/bench/)
[![codecov](https://codecov.io/gh/pytorch/rl/branch/main/graph/badge.svg?token=HcpK1ILV6r)](https://codecov.io/gh/pytorch/rl)
[![Twitter Follow](https://img.shields.io/twitter/follow/torchrl1?style=social)](https://twitter.com/torchrl1)
[![Python version](https://img.shields.io/pypi/pyversions/torchrl.svg)](https://www.python.org/downloads/)
[![GitHub license](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/pytorch/rl/blob/main/LICENSE)
<a href="https://pypi.org/project/torchrl"><img src="https://img.shields.io/pypi/v/torchrl" alt="pypi version"></a>
<a href="https://pypi.org/project/torchrl-nightly"><img src="https://img.shields.io/pypi/v/torchrl-nightly?label=nightly" alt="pypi nightly version"></a>
[![Downloads](https://static.pepy.tech/personalized-badge/torchrl?period=total&units=international_system&left_color=blue&right_color=orange&left_text=Downloads)](https://pepy.tech/project/torchrl)
[![Downloads](https://static.pepy.tech/personalized-badge/torchrl-nightly?period=total&units=international_system&left_color=blue&right_color=orange&left_text=Downloads%20(nightly))](https://pepy.tech/project/torchrl-nightly)
[![Discord Shield](https://dcbadge.vercel.app/api/server/cZs26Qq3Dd)](https://discord.gg/cZs26Qq3Dd)

# TorchRL

<p align="center">
  <img src="docs/source/_static/img/icon.png" width="200" >
</p>

[**What's New**](#-whats-new) | [**LLM API**](#llm-api---complete-framework-for-language-model-fine-tuning) | [**Getting Started**](#getting-started) | [**Documentation**](#documentation-and-knowledge-base) | [**TensorDict**](#writing-simplified-and-portable-rl-codebase-with-tensordict) |
[**Features**](#features) | [**Examples, tutorials and demos**](#examples-tutorials-and-demos) | [**Citation**](#citation) | [**Installation**](#installation) |
[**Asking a question**](#asking-a-question) | [**Contributing**](#contributing)

**TorchRL** is an open-source Reinforcement Learning (RL) library for PyTorch.
## 🚀 What's New

### 🚀 **Command-Line Training Interface** - Train RL Agents Without Writing Code! (Experimental)

TorchRL now provides a **powerful command-line interface** that lets you train state-of-the-art RL agents with simple bash commands! No Python scripting required - just run training with customizable parameters:

- 🎯 **One-Command Training**: `python sota-implementations/ppo_trainer/train.py`
- ⚙️ **Full Customization**: Override any parameter via command line: `trainer.total_frames=2000000 optimizer.lr=0.0003`
- 🌍 **Multi-Environment Support**: Switch between Gym, Brax, DM Control, and more with `env=gym training_env.create_env_fn.base_env.env_name=HalfCheetah-v4`
- 📊 **Built-in Logging**: TensorBoard, Weights & Biases, CSV logging out of the box
- 🔧 **Hydra-Powered**: Leverages Hydra's powerful configuration system for maximum flexibility
- 🏃‍♂️ **Production Ready**: Same robust training pipeline as our SOTA implementations

**Perfect for**: Researchers, practitioners, and anyone who wants to train RL agents without diving into implementation details.

⚠️ **Note**: This is an experimental feature. The API may change in future versions. We welcome feedback and contributions to help improve this implementation!

📋 **Prerequisites**: The training interface requires Hydra for configuration management. Install with:
```bash
pip install "torchrl[utils]"
# or manually:
pip install hydra-core omegaconf
```

Check out the [complete CLI documentation](https://github.com/pytorch/rl/tree/main/sota-implementations/ppo_trainer) to get started!
### 🚀 **vLLM Revamp** - Major Enhancement to LLM Infrastructure (v0.10)

This release introduces a comprehensive revamp of TorchRL's vLLM integration, delivering significant improvements in performance, scalability, and usability for large language model inference and training workflows:

- 🔥 **AsyncVLLM Service**: Production-ready distributed vLLM inference with multi-replica scaling and automatic Ray actor management
- ⚖️ **Multiple Load Balancing Strategies**: Routing strategies including prefix-aware, request-based, and KV-cache load balancing for optimal performance
- 🏗️ **Unified vLLM Architecture**: New `RLvLLMEngine` interface standardizing all vLLM backends with simplified `vLLMUpdaterV2` for seamless weight updates
- 🌐 **Distributed Data Loading**: New `RayDataLoadingPrimer` for shared, distributed data loading across multiple environments
- 📈 **Enhanced Performance**: Native vLLM batching, concurrent request processing, and optimized resource allocation via Ray placement groups

```python
# Simple AsyncVLLM usage - production ready!
from torchrl.modules.llm import AsyncVLLM, vLLMWrapper

# Create distributed vLLM service with load balancing
service = AsyncVLLM.from_pretrained(
    "Qwen/Qwen2.5-7B",
    num_devices=2,   # Tensor parallel across 2 GPUs
    num_replicas=4,  # 4 replicas for high throughput
    max_model_len=4096
)

# Use with TorchRL's LLM wrappers
wrapper = vLLMWrapper(service, input_mode="history")

# Simplified weight updates
from torchrl.collectors.llm import vLLMUpdaterV2
updater = vLLMUpdaterV2(service)  # Auto-configures from engine
```

This revamp positions TorchRL as the leading platform for scalable LLM inference and training, providing production-ready tools for both research and deployment scenarios.
### 🧪 PPOTrainer (Experimental) - High-Level Training Interface

TorchRL now includes an **experimental PPOTrainer** that provides a complete, configurable PPO training solution! This prototype feature combines TorchRL's modular components into a cohesive training system with sensible defaults:

- 🎯 **Complete Training Pipeline**: Handles environment setup, data collection, loss computation, and optimization automatically
- ⚙️ **Extensive Configuration**: Comprehensive Hydra-based config system for easy experimentation and hyperparameter tuning
- 📊 **Built-in Logging**: Automatic tracking of rewards, actions, episode completion rates, and training statistics
- 🔧 **Modular Design**: Built on existing TorchRL components (collectors, losses, replay buffers) for maximum flexibility
- 📝 **Minimal Code**: Complete SOTA implementation in [just ~20 lines](sota-implementations/ppo_trainer/train.py)!

**Working Example**: See [`sota-implementations/ppo_trainer/`](sota-implementations/ppo_trainer/) for a complete, working PPO implementation that trains on Pendulum-v1 with full Hydra configuration support.

**Prerequisites**: Requires Hydra for configuration management: `pip install "torchrl[utils]"`

<details>
<summary>Complete Training Script (sota-implementations/ppo_trainer/train.py)</summary>

```python
import hydra
from torchrl.trainers.algorithms.configs import *

@hydra.main(config_path="config", config_name="config", version_base="1.1")
def main(cfg):
    trainer = hydra.utils.instantiate(cfg.trainer)
    trainer.train()

if __name__ == "__main__":
    main()
```
*Complete PPO training in ~20 lines with full configurability.*

</details>

<details>
<summary>API Usage Examples</summary>

```bash
# Basic usage - train PPO on Pendulum-v1 with default settings
python sota-implementations/ppo_trainer/train.py

# Custom configuration with command-line overrides
python sota-implementations/ppo_trainer/train.py \
    trainer.total_frames=2000000 \
    training_env.create_env_fn.base_env.env_name=HalfCheetah-v4 \
    networks.policy_network.num_cells=[256,256] \
    optimizer.lr=0.0003

# Use different environment and logger
python sota-implementations/ppo_trainer/train.py \
    env=gym \
    training_env.create_env_fn.base_env.env_name=Walker2d-v4 \
    logger=tensorboard

# See all available options
python sota-implementations/ppo_trainer/train.py --help
```

</details>

**Future Plans**: Additional algorithm trainers (SAC, TD3, DQN) and full integration of all TorchRL components within the configuration system are planned for upcoming releases.
## LLM API - Complete Framework for Language Model Fine-tuning

TorchRL includes a comprehensive **LLM API** for post-training and fine-tuning of language models! This framework provides everything you need for RLHF, supervised fine-tuning, and tool-augmented training:

- 🤖 **Unified LLM Wrappers**: Seamless integration with Hugging Face models and vLLM inference engines
- 💬 **Conversation Management**: Advanced [`History`](torchrl/data/llm/history.py) class for multi-turn dialogue with automatic chat template detection
- 🛠️ **Tool Integration**: [Built-in support](torchrl/envs/llm/transforms/) for Python code execution, function calling, and custom tool transforms
- 🎯 **Specialized Objectives**: [GRPO](torchrl/objectives/llm/grpo.py) (Group Relative Policy Optimization) and [SFT](torchrl/objectives/llm/sft.py) loss functions optimized for language models
- ⚡ **High-Performance Collectors**: [Async data collection](torchrl/collectors/llm/) with distributed training support
- 🔄 **Flexible Environments**: Transform-based architecture for reward computation, data loading, and conversation augmentation

The LLM API follows TorchRL's modular design principles, allowing you to mix and match components for your specific use case. Check out the [complete documentation](https://pytorch.org/rl/main/reference/llms.html) and [GRPO implementation example](https://github.com/pytorch/rl/tree/main/sota-implementations/grpo) to get started!
<details>
<summary>Quick LLM API Example</summary>

```python
from torchrl.envs.llm import ChatEnv
from torchrl.envs.llm.transforms import PythonInterpreter
from torchrl.modules.llm import TransformersWrapper
from torchrl.objectives.llm import GRPOLoss
from torchrl.collectors.llm import LLMCollector

# Create environment with Python tool execution
env = ChatEnv(
    tokenizer=tokenizer,
    system_prompt="You are an assistant that can execute Python code.",
    batch_size=[1]
).append_transform(PythonInterpreter())

# Wrap your language model
llm = TransformersWrapper(
    model=model,
    tokenizer=tokenizer,
    input_mode="history"
)

# Set up GRPO training
loss_fn = GRPOLoss(llm, critic, gamma=0.99)
collector = LLMCollector(env, llm, frames_per_batch=100)

# Training loop
for data in collector:
    loss = loss_fn(data)
    loss.backward()
    optimizer.step()
```

</details>
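The `History` class highlighted in the bullets above is not exercised in the quick example. Below is a minimal, hedged sketch only: the constructor fields (`role`, `content`) and the `append`/`apply_chat_template` methods are assumptions inferred from the feature description and the `torchrl/data/llm/history.py` module, not a verified signature:

```python
# Hedged sketch - field and method names below are assumptions,
# not a verified API; consult the LLM API docs for the exact signatures.
from torchrl.data.llm import History

chat = History(role="user", content="What is 2 + 2?")
# Append the assistant turn to build a multi-turn conversation
chat = chat.append(History(role="assistant", content="4"))
# Render the conversation through the tokenizer's chat template
prompt = chat.apply_chat_template(tokenizer=tokenizer, add_generation_prompt=True)
```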
## Key features

- 🐍 **Python-first**: Designed with Python as the primary language for ease of use and flexibility
- ⏱️ **Efficient**: Optimized for performance to support demanding RL research applications
- 🧮 **Modular, customizable, extensible**: Highly modular architecture allows for easy swapping, transformation, or creation of new components
- 📚 **Documented**: Thorough documentation ensures that users can quickly understand and utilize the library
- ✅ **Tested**: Rigorously tested to ensure reliability and stability
- ⚙️ **Reusable functionals**: Provides a set of highly reusable functions for cost functions, returns, and data processing

### Design Principles

- 🔥 **Aligns with PyTorch ecosystem**: Follows the structure and conventions of popular PyTorch libraries
  (e.g., dataset pillar, transforms, models, data utilities)
- ➖ **Minimal dependencies**: Only requires Python standard library, NumPy, and PyTorch; optional dependencies for
  common environment libraries (e.g., OpenAI Gym) and datasets (D4RL, OpenX...)

Read the [full paper](https://arxiv.org/abs/2306.00577) for a more curated description of the library.
## Getting started

Check our [Getting Started tutorials](https://pytorch.org/rl/stable/index.html#getting-started) to quickly ramp up with the basic
features of the library!

<p align="center">
  <img src="docs/ppo.png" width="800" >
</p>
## Documentation and knowledge base

The TorchRL documentation can be found [here](https://pytorch.org/rl).
It contains tutorials and the API reference.

TorchRL also provides an RL knowledge base to help you debug your code, or simply
learn the basics of RL. Check it out [here](https://pytorch.org/rl/stable/reference/knowledge_base.html).

We have some introductory videos for you to get to know the library better; check them out:

- [TalkRL podcast](https://www.talkrl.com/episodes/vincent-moens-on-torchrl)
- [TorchRL intro at PyTorch day 2022](https://youtu.be/cIKMhZoykEE)
- [PyTorch 2.0 Q&A: TorchRL](https://www.youtube.com/live/myEfUoYrbts?feature=share)
## Spotlight publications

Because TorchRL is domain-agnostic, you can use it across many different fields. Here are a few examples:

- [ACEGEN](https://pubs.acs.org/doi/10.1021/acs.jcim.4c00895): Reinforcement Learning of Generative Chemical Agents
  for Drug Discovery
- [BenchMARL](https://www.jmlr.org/papers/v25/23-1612.html): Benchmarking Multi-Agent Reinforcement Learning
- [BricksRL](https://arxiv.org/abs/2406.17490): A Platform for Democratizing Robotics and Reinforcement Learning
  Research and Education with LEGO
- [OmniDrones](https://ieeexplore.ieee.org/abstract/document/10409589): An Efficient and Flexible Platform for Reinforcement Learning in Drone Control
- [RL4CO](https://arxiv.org/abs/2306.17100): an Extensive Reinforcement Learning for Combinatorial Optimization Benchmark
- [Robohive](https://proceedings.neurips.cc/paper_files/paper/2023/file/8a84a4341c375b8441b36836bb343d4e-Paper-Datasets_and_Benchmarks.pdf): A unified framework for robot learning
## Writing simplified and portable RL codebase with `TensorDict`

RL algorithms are very heterogeneous, and it can be hard to recycle a codebase
across settings (e.g. from online to offline, from state-based to pixel-based
learning).
TorchRL solves this problem through [`TensorDict`](https://github.com/pytorch/tensordict/),
a convenient data structure<sup>(1)</sup> that can be used to streamline one's
RL codebase.
With this tool, one can write a *complete PPO training script in less than 100
lines of code*!
<details>
<summary>Code</summary>

```python
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn

from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import TensorDictReplayBuffer, \
    LazyTensorStorage, SamplerWithoutReplacement
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import ProbabilisticActor, ValueOperator, TanhNormal
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

env = GymEnv("Pendulum-v1")
model = TensorDictModule(
    nn.Sequential(
        nn.Linear(3, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 2),
        NormalParamExtractor()
    ),
    in_keys=["observation"],
    out_keys=["loc", "scale"]
)
critic = ValueOperator(
    nn.Sequential(
        nn.Linear(3, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 1),
    ),
    in_keys=["observation"],
)
actor = ProbabilisticActor(
    model,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={"low": -1.0, "high": 1.0},
    return_log_prob=True
)
buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(1000),
    sampler=SamplerWithoutReplacement(),
    batch_size=50,
)
collector = SyncDataCollector(
    env,
    actor,
    frames_per_batch=1000,
    total_frames=1_000_000,
)
loss_fn = ClipPPOLoss(actor, critic)
adv_fn = GAE(value_network=critic, average_gae=True, gamma=0.99, lmbda=0.95)
optim = torch.optim.Adam(loss_fn.parameters(), lr=2e-4)

for data in collector:  # collect data
    for epoch in range(10):
        adv_fn(data)  # compute advantage
        buffer.extend(data)
        for sample in buffer:  # consume data
            loss_vals = loss_fn(sample)
            loss_val = sum(
                value for key, value in loss_vals.items() if
                key.startswith("loss")
            )
            loss_val.backward()
            optim.step()
            optim.zero_grad()
    print(f"avg reward: {data['next', 'reward'].mean().item(): 4.4f}")
```
</details>
Here is an example of how the [environment API](https://pytorch.org/rl/stable/reference/envs.html)
relies on tensordict to carry data from one function to another during a rollout
execution:
![Alt Text](https://github.com/pytorch/rl/blob/main/docs/source/_static/img/rollout.gif)

`TensorDict` makes it easy to re-use pieces of code across environments, models and
algorithms.
<details>
<summary>Code</summary>

For instance, here's how to code a rollout in TorchRL:

```diff
- obs, done = env.reset()
+ tensordict = env.reset()
policy = SafeModule(
    model,
    in_keys=["observation_pixels", "observation_vector"],
    out_keys=["action"],
)
out = []
for i in range(n_steps):
-     action, log_prob = policy(obs)
-     next_obs, reward, done, info = env.step(action)
-     out.append((obs, next_obs, action, log_prob, reward, done))
-     obs = next_obs
+     tensordict = policy(tensordict)
+     tensordict = env.step(tensordict)
+     out.append(tensordict)
+     tensordict = step_mdp(tensordict)  # renames next_observation_* keys to observation_*
- obs, next_obs, action, log_prob, reward, done = [torch.stack(vals, 0) for vals in zip(*out)]
+ out = torch.stack(out, 0)  # TensorDict supports multiple tensor operations
```
</details>

Using this, TorchRL abstracts away the input / output signatures of the modules, env,
collectors, replay buffers and losses of the library, allowing all primitives
to be easily recycled across settings.
<details>
<summary>Code</summary>

Here's another example of an off-policy training loop in TorchRL (assuming
that a data collector, a replay buffer, a loss and an optimizer have been instantiated):

```diff
- for i, (obs, next_obs, action, hidden_state, reward, done) in enumerate(collector):
+ for i, tensordict in enumerate(collector):
-     replay_buffer.add((obs, next_obs, action, log_prob, reward, done))
+     replay_buffer.add(tensordict)
    for j in range(num_optim_steps):
-         obs, next_obs, action, hidden_state, reward, done = replay_buffer.sample(batch_size)
-         loss = loss_fn(obs, next_obs, action, hidden_state, reward, done)
+         tensordict = replay_buffer.sample(batch_size)
+         loss = loss_fn(tensordict)
        loss.backward()
        optim.step()
        optim.zero_grad()
```
This training loop can be re-used across algorithms as it makes a minimal number of assumptions about the structure of the data.
</details>
TensorDict supports multiple tensor operations on its device and shape
(the shape of a TensorDict, i.e. its batch size, is given by the first N dimensions common to all of its tensors):

<details>
<summary>Code</summary>

```python
# stack and cat
tensordict = torch.stack(list_of_tensordicts, 0)
tensordict = torch.cat(list_of_tensordicts, 0)
# reshape
tensordict = tensordict.view(-1)
tensordict = tensordict.permute(0, 2, 1)
tensordict = tensordict.unsqueeze(-1)
tensordict = tensordict.squeeze(-1)
# indexing
tensordict = tensordict[:2]
tensordict[:, 2] = sub_tensordict
# device and memory location
tensordict.cuda()
tensordict.to("cuda:1")
tensordict.share_memory_()
```
</details>
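To make the batch-size notion above concrete, here is a small, self-contained illustration (not taken from the README) of how indexing and reshaping act on every entry of a TensorDict at once; the keys are arbitrary examples:

```python
# Minimal illustration of TensorDict batch semantics; keys are arbitrary examples.
import torch
from tensordict import TensorDict

data = TensorDict(
    {
        "observation": torch.randn(4, 20, 3),  # 4 envs, 20 steps, 3 features
        "reward": torch.zeros(4, 20, 1),
    },
    batch_size=[4, 20],  # the leading dims shared by every entry
)

first_env = data[0]        # batch_size becomes [20]
flat = data.reshape(-1)    # batch_size becomes [80]
print(first_env.batch_size, flat.batch_size)
```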
TensorDict comes with a dedicated [`tensordict.nn`](https://pytorch.github.io/tensordict/reference/nn.html)
module that contains everything you might need to write your model with it.
And it is `functorch` and `torch.compile` compatible!

<details>
<summary>Code</summary>

```diff
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
+ td_module = SafeModule(transformer_model, in_keys=["src", "tgt"], out_keys=["out"])
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))
+ tensordict = TensorDict({"src": src, "tgt": tgt}, batch_size=[20, 32])
- out = transformer_model(src, tgt)
+ td_module(tensordict)
+ out = tensordict["out"]
```

The `TensorDictSequential` class allows you to branch sequences of `nn.Module` instances in a highly modular way.
For instance, here is an implementation of a transformer using the encoder and decoder blocks:
```python
encoder_module = TransformerEncoder(...)
encoder = TensorDictSequential(encoder_module, in_keys=["src", "src_mask"], out_keys=["memory"])
decoder_module = TransformerDecoder(...)
decoder = TensorDictModule(decoder_module, in_keys=["tgt", "memory"], out_keys=["output"])
transformer = TensorDictSequential(encoder, decoder)
assert transformer.in_keys == ["src", "src_mask", "tgt"]
assert transformer.out_keys == ["memory", "output"]
```

`TensorDictSequential` allows you to isolate subgraphs by querying a set of desired input / output keys:
```python
transformer.select_subsequence(out_keys=["memory"])  # returns the encoder
transformer.select_subsequence(in_keys=["tgt", "memory"])  # returns the decoder
```
</details>
Check [TensorDict tutorials](https://pytorch.github.io/tensordict/) to
learn more!

## Features

- A common [interface for environments](https://github.com/pytorch/rl/blob/main/torchrl/envs)
  which supports common libraries (OpenAI Gym, DeepMind Control Suite, etc.)<sup>(1)</sup> and state-less execution
  (e.g. Model-based environments).
  The [batched environments](https://github.com/pytorch/rl/blob/main/torchrl/envs/batched_envs.py) containers allow parallel execution<sup>(2)</sup>.
  A common PyTorch-first [tensor-specification class](https://github.com/pytorch/rl/blob/main/torchrl/data/tensor_specs.py) is also provided.
  TorchRL's environments API is simple but stringent and specific. Check the
  [documentation](https://pytorch.org/rl/stable/reference/envs.html)
  and [tutorial](https://pytorch.org/rl/stable/tutorials/pendulum.html) to learn more!
  <details>
  <summary>Code</summary>

  ```python
  env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
  env_parallel = ParallelEnv(4, env_make)  # creates 4 envs in parallel
  tensordict = env_parallel.rollout(max_steps=20, policy=None)  # random rollout (no policy given)
  assert tensordict.shape == [4, 20]  # 4 envs, 20 steps rollout
  env_parallel.action_spec.is_in(tensordict["action"])  # spec check returns True
  ```
  </details>
- multiprocess and distributed [data collectors](https://github.com/pytorch/rl/blob/main/torchrl/collectors/collectors.py)<sup>(2)</sup>
  that work synchronously or asynchronously.
  Through the use of TensorDict, TorchRL's training loops are made very similar
  to regular training loops in supervised
  learning (although the "dataloader" -- read data collector -- is modified on-the-fly):
  <details>
  <summary>Code</summary>

  ```python
  env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
  collector = MultiaSyncDataCollector(
      [env_make, env_make],
      policy=policy,
      devices=["cuda:0", "cuda:0"],
      total_frames=10000,
      frames_per_batch=50,
      ...
  )
  for i, tensordict_data in enumerate(collector):
      loss = loss_module(tensordict_data)
      loss.backward()
      optim.step()
      optim.zero_grad()
      collector.update_policy_weights_()
  ```
  </details>

  Check our [distributed collector examples](https://github.com/pytorch/rl/blob/main/examples/distributed/collectors) to
  learn more about ultra-fast data collection with TorchRL.
- efficient<sup>(2)</sup> and generic<sup>(1)</sup> [replay buffers](https://github.com/pytorch/rl/blob/main/torchrl/data/replay_buffers/replay_buffers.py) with modularized storage:
  <details>
  <summary>Code</summary>

  ```python
  storage = LazyMemmapStorage(  # memory-mapped (physical) storage
      cfg.buffer_size,
      scratch_dir="/tmp/"
  )
  buffer = TensorDictPrioritizedReplayBuffer(
      alpha=0.7,
      beta=0.5,
      collate_fn=lambda x: x,
      pin_memory=device != torch.device("cpu"),
      prefetch=10,  # multi-threaded sampling
      storage=storage
  )
  ```
  </details>

  Replay buffers are also offered as wrappers around common datasets for *offline RL*:
  <details>
  <summary>Code</summary>

  ```python
  from torchrl.data.replay_buffers import SamplerWithoutReplacement
  from torchrl.data.datasets.d4rl import D4RLExperienceReplay
  data = D4RLExperienceReplay(
      "maze2d-open-v0",
      split_trajs=True,
      batch_size=128,
      sampler=SamplerWithoutReplacement(drop_last=True),
  )
  for sample in data:  # or alternatively sample = data.sample()
      fun(sample)
  ```
  </details>
- cross-library [environment transforms](https://github.com/pytorch/rl/blob/main/torchrl/envs/transforms/transforms.py)<sup>(1)</sup>,
  executed on device and in a vectorized fashion<sup>(2)</sup>,
  which process and prepare the data coming out of the environments to be used by the agent:
  <details>
  <summary>Code</summary>

  ```python
  env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
  env_base = ParallelEnv(4, env_make, device="cuda:0")  # creates 4 envs in parallel
  env = TransformedEnv(
      env_base,
      Compose(
          ToTensorImage(),
          ObservationNorm(loc=0.5, scale=1.0)),  # executes the transforms once and on device
  )
  tensordict = env.reset()
  assert tensordict.device == torch.device("cuda:0")
  ```
  Other transforms include: reward scaling (`RewardScaling`), shape operations (concatenation of tensors, unsqueezing etc.), concatenation of
  successive frames (`CatFrames`), resizing (`Resize`) and many more.

  Unlike other libraries, the transforms are stacked as a list (and not wrapped in each other), which makes it
  easy to add and remove them at will:
  ```python
  env.insert_transform(0, NoopResetEnv())  # inserts the NoopResetEnv transform at the index 0
  ```
  Nevertheless, transforms can access and execute operations on the parent environment:
  ```python
  transform = env.transform[1]  # gathers the second transform of the list
  parent_env = transform.parent  # returns the base environment of the second transform, i.e. the base env + the first transform
  ```
  </details>
- various tools for distributed learning (e.g. [memory mapped tensors](https://github.com/pytorch/tensordict/blob/main/tensordict/memmap.py))<sup>(2)</sup>, as sketched below;
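  A minimal sketch (not from the README) of memory-mapped storage via TensorDict: `memmap_()` writes every entry to disk-backed files so they can be shared across processes without copying. The `prefix` argument name is assumed from the tensordict API and may differ across versions.

  ```python
  # Hedged sketch: persist a TensorDict to memory-mapped files on disk.
  import tempfile
  import torch
  from tensordict import TensorDict

  td = TensorDict(
      {"obs": torch.randn(100, 3), "action": torch.randn(100, 1)},
      batch_size=[100],
  )
  # After this call the tensors are backed by files under the given directory,
  # so worker processes can read them without sending data over pipes.
  td.memmap_(prefix=tempfile.mkdtemp())
  print(td.is_memmap())
  ```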
+ - various [architectures](https://github.com/pytorch/rl/blob/main/torchrl/modules/models/) and models (e.g. [actor-critic](https://github.com/pytorch/rl/blob/main/torchrl/modules/tensordict_module/actors.py))<sup>(1)</sup>:
714
+ <details>
715
+ <summary>Code</summary>
716
+
717
+ ```python
718
+ # create an nn.Module
719
+ common_module = ConvNet(
720
+ bias_last_layer=True,
721
+ depth=None,
722
+ num_cells=[32, 64, 64],
723
+ kernel_sizes=[8, 4, 3],
724
+ strides=[4, 2, 1],
725
+ )
726
+ # Wrap it in a SafeModule, indicating what key to read in and where to
727
+ # write out the output
728
+ common_module = SafeModule(
729
+ common_module,
730
+ in_keys=["pixels"],
731
+ out_keys=["hidden"],
732
+ )
733
+ # Wrap the policy module in NormalParamsWrapper, such that the output
734
+ # tensor is split in loc and scale, and scale is mapped onto a positive space
735
+ policy_module = SafeModule(
736
+ NormalParamsWrapper(
737
+ MLP(num_cells=[64, 64], out_features=32, activation=nn.ELU)
738
+ ),
739
+ in_keys=["hidden"],
740
+ out_keys=["loc", "scale"],
741
+ )
742
+ # Use a SafeProbabilisticTensorDictSequential to combine the SafeModule with a
743
+ # SafeProbabilisticModule, indicating how to build the
744
+ # torch.distribution.Distribution object and what to do with it
745
+ policy_module = SafeProbabilisticTensorDictSequential( # stochastic policy
746
+ policy_module,
747
+ SafeProbabilisticModule(
748
+ in_keys=["loc", "scale"],
749
+ out_keys="action",
750
+ distribution_class=TanhNormal,
751
+ ),
752
+ )
753
+ value_module = MLP(
754
+ num_cells=[64, 64],
755
+ out_features=1,
756
+ activation=nn.ELU,
757
+ )
758
+   # Wrap the policy and value function in a common module
759
+ actor_value = ActorValueOperator(common_module, policy_module, value_module)
760
+   # Get a standalone policy from this
761
+ standalone_policy = actor_value.get_policy_operator()
762
+ ```
763
+ </details>
764
+
765
+ - exploration [wrappers](https://github.com/pytorch/rl/blob/main/torchrl/modules/tensordict_module/exploration.py) and
766
+ [modules](https://github.com/pytorch/rl/blob/main/torchrl/modules/models/exploration.py) to easily swap between exploration and exploitation<sup>(1)</sup>:
767
+ <details>
768
+ <summary>Code</summary>
769
+
770
+ ```python
771
+ policy_explore = EGreedyWrapper(policy)
772
+ with set_exploration_type(ExplorationType.RANDOM):
773
+ tensordict = policy_explore(tensordict) # will use eps-greedy
774
+ with set_exploration_type(ExplorationType.DETERMINISTIC):
775
+ tensordict = policy_explore(tensordict) # will not use eps-greedy
776
+ ```
777
+ </details>
778
+
779
+ - A series of efficient [loss modules](https://github.com/pytorch/rl/tree/main/torchrl/objectives)
780
+ and highly vectorized
781
+ [functional return and advantage](https://github.com/pytorch/rl/blob/main/torchrl/objectives/value/functional.py)
782
+ computation.
783
+
784
+ <details>
785
+ <summary>Code</summary>
786
+
787
+ ### Loss modules
788
+ ```python
789
+ from torchrl.objectives import DQNLoss
790
+ loss_module = DQNLoss(value_network=value_network, gamma=0.99)
791
+ tensordict = replay_buffer.sample(batch_size)
792
+ loss = loss_module(tensordict)
793
+ ```
794
+
795
+ ### Advantage computation
796
+ ```python
797
+ from torchrl.objectives.value.functional import vec_td_lambda_return_estimate
798
+ advantage = vec_td_lambda_return_estimate(gamma, lmbda, next_state_value, reward, done, terminated)
799
+ ```
800
+
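+   Loss modules can also be told which value estimator to use internally. The following is a brief sketch of that
+   mechanism, assuming the `loss_module` built above; the estimator choice and hyperparameter keywords are assumptions
+   and may vary between versions:
+   ```python
+   from torchrl.objectives import ValueEstimators
+
+   # Ask the loss to build and use a TD(lambda) value estimator internally
+   loss_module.make_value_estimator(ValueEstimators.TDLambda, gamma=0.99, lmbda=0.95)
+   ```
+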
801
+ </details>
802
+
803
+ - a generic [trainer class](https://github.com/pytorch/rl/blob/main/torchrl/trainers/trainers.py)<sup>(1)</sup> that
804
+ executes the aforementioned training loop. Through a hooking mechanism,
805
+ it also supports any logging or data transformation operation at any given
806
+ time.
807
+
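+   A rough sketch of how such a trainer can be assembled is shown below. The constructor keywords and the
+   `UpdateWeights` hook are assumptions borrowed from the upstream tutorials and may differ between versions;
+   `collector`, `loss_module`, `optimizer` and `logger` are presumed to have been created as in the previous sections.
+   <details>
+     <summary>Code</summary>
+
+   ```python
+   from torchrl.trainers import Trainer, UpdateWeights
+
+   trainer = Trainer(
+       collector=collector,
+       total_frames=1_000_000,
+       loss_module=loss_module,
+       optimizer=optimizer,
+       logger=logger,
+       optim_steps_per_batch=8,
+   )
+   # Hooks plug extra behavior into predefined points of the loop, e.g. syncing the
+   # data-collection policy with the freshly optimized weights after each update.
+   UpdateWeights(collector, update_weights_interval=1).register(trainer)
+   trainer.train()
+   ```
+   </details>
+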
808
+ - various [recipes](https://github.com/pytorch/rl/blob/main/torchrl/trainers/helpers/models.py) to build models that
809
+ correspond to the environment being deployed.
810
+
811
+ - **LLM API**: Complete framework for language model fine-tuning with unified wrappers for Hugging Face and vLLM backends,
812
+ conversation management with automatic chat template detection, tool integration (Python execution, function calling),
813
+ specialized objectives (GRPO, SFT), and high-performance async collectors. Perfect for RLHF, supervised fine-tuning,
814
+ and tool-augmented training scenarios.
815
+ <details>
816
+ <summary>Code</summary>
817
+
818
+ ```python
819
+ from torchrl.envs.llm import ChatEnv
820
+ from torchrl.modules.llm import TransformersWrapper
821
+ from torchrl.envs.llm.transforms import PythonInterpreter
822
+
823
+ # Create environment with tool execution
824
+ env = ChatEnv(
825
+ tokenizer=tokenizer,
826
+ system_prompt="You can execute Python code.",
827
+ batch_size=[1]
828
+ ).append_transform(PythonInterpreter())
829
+
830
+ # Wrap language model for training
831
+ llm = TransformersWrapper(
832
+ model=model,
833
+ tokenizer=tokenizer,
834
+ input_mode="history"
835
+ )
836
+
837
+ # Multi-turn conversation with tool use
838
+ obs = env.reset(TensorDict({"query": "Calculate 2+2"}, batch_size=[1]))
839
+ llm_output = llm(obs) # Generates response
840
+ obs = env.step(llm_output) # Environment processes response
841
+ ```
842
+ </details>
843
+
844
+ If you feel a feature is missing from the library, please submit an issue!
845
+ If you would like to contribute to new features, check our [call for contributions](https://github.com/pytorch/rl/issues/509) and our [contribution](https://github.com/pytorch/rl/blob/main/CONTRIBUTING.md) page.
846
+
847
+
848
+ ## Examples, tutorials and demos
849
+
850
+   A series of [State-of-the-Art implementations](https://github.com/pytorch/rl/blob/main/sota-implementations/) is provided for illustrative purposes:
851
+
852
+ <table>
853
+ <tr>
854
+ <td><strong>Algorithm</strong>
855
+ </td>
856
+ <td><strong>Compile Support**</strong>
857
+ </td>
858
+ <td><strong>Tensordict-free API</strong>
859
+ </td>
860
+ <td><strong>Modular Losses</strong>
861
+ </td>
862
+ <td><strong>Continuous and Discrete</strong>
863
+ </td>
864
+ </tr>
865
+ <tr>
866
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/dqn">DQN</a>
867
+ </td>
868
+ <td> 1.9x
869
+ </td>
870
+ <td> +
871
+ </td>
872
+ <td> NA
873
+ </td>
874
+ <td> + (through <a href="https://pytorch.org/rl/stable/reference/generated/torchrl.envs.transforms.ActionDiscretizer.html?highlight=actiondiscretizer">ActionDiscretizer</a> transform)
875
+ </td>
876
+ </tr>
877
+ <tr>
878
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/ddpg/ddpg.py">DDPG</a>
879
+ </td>
880
+ <td> 1.87x
881
+ </td>
882
+ <td> +
883
+ </td>
884
+ <td> +
885
+ </td>
886
+ <td> - (continuous only)
887
+ </td>
888
+ </tr>
889
+ <tr>
890
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/iql/">IQL</a>
891
+ </td>
892
+ <td> 3.22x
893
+ </td>
894
+ <td> +
895
+ </td>
896
+ <td> +
897
+ </td>
898
+ <td> +
899
+ </td>
900
+ </tr>
901
+ <tr>
902
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/cql/cql_offline.py">CQL</a>
903
+ </td>
904
+ <td> 2.68x
905
+ </td>
906
+ <td> +
907
+ </td>
908
+ <td> +
909
+ </td>
910
+ <td> +
911
+ </td>
912
+ </tr>
913
+ <tr>
914
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/td3/td3.py">TD3</a>
915
+ </td>
916
+ <td> 2.27x
917
+ </td>
918
+ <td> +
919
+ </td>
920
+ <td> +
921
+ </td>
922
+ <td> - (continuous only)
923
+ </td>
924
+ </tr>
925
+ <tr>
926
+ <td>
927
+ <a href="https://github.com/pytorch/rl/blob/main/sota-implementations/td3_bc/td3_bc.py">TD3+BC</a>
928
+ </td>
929
+ <td> untested
930
+ </td>
931
+ <td> +
932
+ </td>
933
+ <td> +
934
+ </td>
935
+ <td> - (continuous only)
936
+ </td>
937
+ </tr>
938
+ <tr>
939
+ <td>
940
+ <a href="https://github.com/pytorch/rl/blob/main/examples/a2c/">A2C</a>
941
+ </td>
942
+ <td> 2.67x
943
+ </td>
944
+ <td> +
945
+ </td>
946
+ <td> -
947
+ </td>
948
+ <td> +
949
+ </td>
950
+ </tr>
951
+ <tr>
952
+ <td>
953
+ <a href="https://github.com/pytorch/rl/blob/main/sota-implementations/ppo/">PPO</a>
954
+ </td>
955
+ <td> 2.42x
956
+ </td>
957
+ <td> +
958
+ </td>
959
+ <td> -
960
+ </td>
961
+ <td> +
962
+ </td>
963
+ </tr>
964
+ <tr>
965
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/sac/sac.py">SAC</a>
966
+ </td>
967
+ <td> 2.62x
968
+ </td>
969
+ <td> +
970
+ </td>
971
+ <td> -
972
+ </td>
973
+ <td> +
974
+ </td>
975
+ </tr>
976
+ <tr>
977
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/redq/redq.py">REDQ</a>
978
+ </td>
979
+ <td> 2.28x
980
+ </td>
981
+ <td> +
982
+ </td>
983
+ <td> -
984
+ </td>
985
+ <td> - (continuous only)
986
+ </td>
987
+ </tr>
988
+ <tr>
989
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/dreamer/dreamer.py">Dreamer v1</a>
990
+ </td>
991
+ <td> untested
992
+ </td>
993
+ <td> +
994
+ </td>
995
+ <td> + (<a href="https://pytorch.org/rl/stable/reference/objectives.html#dreamer">different classes</a>)
996
+ </td>
997
+ <td> - (continuous only)
998
+ </td>
999
+ </tr>
1000
+ <tr>
1001
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/decision_transformer">Decision Transformers</a>
1002
+ </td>
1003
+ <td> untested
1004
+ </td>
1005
+ <td> +
1006
+ </td>
1007
+ <td> NA
1008
+ </td>
1009
+ <td> - (continuous only)
1010
+ </td>
1011
+ </tr>
1012
+ <tr>
1013
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/crossq">CrossQ</a>
1014
+ </td>
1015
+ <td> untested
1016
+ </td>
1017
+ <td> +
1018
+ </td>
1019
+ <td> +
1020
+ </td>
1021
+ <td> - (continuous only)
1022
+ </td>
1023
+ </tr>
1024
+ <tr>
1025
+   <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/gail">GAIL</a>
1026
+ </td>
1027
+ <td> untested
1028
+ </td>
1029
+ <td> +
1030
+ </td>
1031
+ <td> NA
1032
+ </td>
1033
+ <td> +
1034
+ </td>
1035
+ </tr>
1036
+ <tr>
1037
+   <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/impala">IMPALA</a>
1038
+ </td>
1039
+ <td> untested
1040
+ </td>
1041
+ <td> +
1042
+ </td>
1043
+ <td> -
1044
+ </td>
1045
+ <td> +
1046
+ </td>
1047
+ </tr>
1048
+ <tr>
1049
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/multiagent/iql.py">IQL (MARL)</a>
1050
+ </td>
1051
+ <td> untested
1052
+ </td>
1053
+ <td> +
1054
+ </td>
1055
+ <td> +
1056
+ </td>
1057
+ <td> +
1058
+ </td>
1059
+ </tr>
1060
+ <tr>
1061
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/multiagent/maddpg_iddpg.py">DDPG (MARL)</a>
1062
+ </td>
1063
+ <td> untested
1064
+ </td>
1065
+ <td> +
1066
+ </td>
1067
+ <td> +
1068
+ </td>
1069
+ <td> - (continuous only)
1070
+ </td>
1071
+ </tr>
1072
+ <tr>
1073
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/multiagent/mappo_ippo.py">PPO (MARL)</a>
1074
+ </td>
1075
+ <td> untested
1076
+ </td>
1077
+ <td> +
1078
+ </td>
1079
+ <td> -
1080
+ </td>
1081
+ <td> +
1082
+ </td>
1083
+ </tr>
1084
+ <tr>
1085
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/multiagent/qmix_vdn.py">QMIX-VDN (MARL)</a>
1086
+ </td>
1087
+ <td> untested
1088
+ </td>
1089
+ <td> +
1090
+ </td>
1091
+ <td> NA
1092
+ </td>
1093
+ <td> +
1094
+ </td>
1095
+ </tr>
1096
+ <tr>
1097
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/multiagent/sac.py">SAC (MARL)</a>
1098
+ </td>
1099
+ <td> untested
1100
+ </td>
1101
+ <td> +
1102
+ </td>
1103
+ <td> -
1104
+ </td>
1105
+ <td> +
1106
+ </td>
1107
+ </tr>
1108
+ <tr>
1109
+ <td><a href="https://github.com/pytorch/rl/blob/main/examples/rlhf">RLHF</a>
1110
+ </td>
1111
+ <td> NA
1112
+ </td>
1113
+ <td> +
1114
+ </td>
1115
+ <td> NA
1116
+ </td>
1117
+ <td> NA
1118
+ </td>
1119
+ </tr>
1120
+ <tr>
1121
+ <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/grpo">LLM API (GRPO)</a>
1122
+ </td>
1123
+ <td> NA
1124
+ </td>
1125
+ <td> +
1126
+ </td>
1127
+ <td> +
1128
+ </td>
1129
+ <td> NA
1130
+ </td>
1131
+ </tr>
1132
+ </table>
1133
+
1134
+ ** The number indicates expected speed-up compared to eager mode when executed on CPU. Numbers may vary depending on
1135
+ architecture and device.
1136
+
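+   Compilation in these scripts relies on `torch.compile`. As a purely illustrative sketch (not the scripts' actual
+   code), the policy or the loss module can be compiled directly:
+   ```python
+   import torch
+
+   # Compile the policy and/or the loss module; the achievable speed-up depends on the
+   # model, the device and the PyTorch version.
+   policy = torch.compile(policy)
+   loss_module = torch.compile(loss_module)
+   ```
+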
1137
+ and many more to come!
1138
+
1139
+   [Code examples](examples/) displaying toy code snippets and training scripts are also available:
1140
+ - [LLM API & GRPO](sota-implementations/grpo) - Complete language model fine-tuning pipeline
1141
+ - [RLHF](examples/rlhf)
1142
+ - [Memory-mapped replay buffers](examples/torchrl_features)
1143
+
1144
+
1145
+ Check the [examples](https://github.com/pytorch/rl/blob/main/sota-implementations/) directory for more details
1146
+ about handling the various configuration settings.
1147
+
1148
+ We also provide [tutorials and demos](https://pytorch.org/rl/stable#tutorials) that give a sense of
1149
+ what the library can do.
1150
+
1151
+ ## Citation
1152
+
1153
+ If you're using TorchRL, please refer to this BibTeX entry to cite this work:
1154
+ ```
1155
+ @misc{bou2023torchrl,
1156
+ title={TorchRL: A data-driven decision-making library for PyTorch},
1157
+ author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
1158
+ year={2023},
1159
+ eprint={2306.00577},
1160
+ archivePrefix={arXiv},
1161
+ primaryClass={cs.LG}
1162
+ }
1163
+ ```
1164
+
1165
+ ## Installation
1166
+
1167
+ ### Create a new virtual environment:
1168
+ ```bash
1169
+ python -m venv torchrl
1170
+   source torchrl/bin/activate  # On Windows use: torchrl\Scripts\activate
1171
+ ```
1172
+
1173
+ Or create a conda environment where the packages will be installed.
1174
+
1175
+   ```bash
1176
+ conda create --name torchrl python=3.10
1177
+ conda activate torchrl
1178
+ ```
1179
+
1180
+ ### Install dependencies:
1181
+
1182
+ #### PyTorch
1183
+
1184
+   Depending on how you intend to use torchrl, you may want to install either
1185
+   the latest (nightly) PyTorch release or the latest stable version of PyTorch.
1186
+ See [here](https://pytorch.org/get-started/locally/) for a detailed list of commands,
1187
+ including `pip3` or other special installation instructions.
1188
+
1189
+   TorchRL offers a few pre-defined optional dependency bundles, such as `"torchrl[tests]"`, `"torchrl[atari]"`, `"torchrl[utils]"` etc.
1190
+
1191
+ For the experimental training interface and configuration system, install:
1192
+ ```bash
1193
+ pip3 install "torchrl[utils]" # Includes hydra-core and other utilities
1194
+ ```
1195
+
1196
+   #### TorchRL
1197
+
1198
+ You can install the **latest stable release** by using
1199
+ ```bash
1200
+ pip3 install torchrl
1201
+ ```
1202
+   This should work on Linux (including AArch64 machines), Windows 10 and macOS (Apple Silicon / Metal chips only).
1203
+ On certain Windows machines (Windows 11), one should build the library locally.
1204
+ This can be done in two ways:
1205
+
1206
+ ```bash
1207
+   # Install and build v0.8.1 of the library locally, without cloning
1208
+ pip3 install git+https://github.com/pytorch/rl@v0.8.1
1209
+ # Clone the library and build it locally
1210
+ git clone https://github.com/pytorch/tensordict
1211
+ git clone https://github.com/pytorch/rl
1212
+ pip install -e tensordict
1213
+ pip install -e rl
1214
+ ```
1215
+
1216
+ If you use `uv` (instead of `pip`) and you have already installed a specific PyTorch build (e.g. nightly),
1217
+ make sure `uv` doesn't re-resolve dependencies (which can downgrade PyTorch). Use `--no-deps` for the local installs:
1218
+
1219
+ ```bash
1220
+ uv pip install --no-deps -e tensordict
1221
+ uv pip install --no-deps -e rl
1222
+ ```
1223
+
1224
+   Note that a local tensordict build requires `cmake` to be installed via [homebrew](https://brew.sh/) (macOS) or another package manager
1225
+   such as `apt`, `apt-get`, `conda` or `yum` (but NOT `pip`), as well as `pip install "pybind11[global]"`.
1226
+
1227
+ One can also build the wheels to distribute to co-workers using
1228
+ ```bash
1229
+ pip install build
1230
+ python -m build --wheel
1231
+ ```
1232
+   Your wheel will be stored under `./dist/torchrl<name>.whl` and can be installed via
1233
+ ```bash
1234
+ pip install torchrl<name>.whl
1235
+ ```
1236
+
1237
+ The **nightly build** can be installed via
1238
+ ```bash
1239
+ pip3 install tensordict-nightly torchrl-nightly
1240
+ ```
1241
+ which we currently only ship for Linux machines.
1242
+ Importantly, the nightly builds require the nightly builds of PyTorch too.
1243
+   Also, a local build of torchrl with the nightly build of tensordict may fail: install both nightlies or both local builds, but do not mix them.
1244
+
1245
+
1246
+   **Disclaimer**: As of today, TorchRL requires Python 3.10+ and is roughly compatible with any PyTorch version >= 2.1. Installing it will not
1247
+   directly require a newer version of PyTorch to be installed. Indirectly though, tensordict still requires the latest
1248
+ PyTorch to be installed and we are working hard to loosen that requirement.
1249
+ The C++ binaries of TorchRL (mainly for prioritized replay buffers) will only work with PyTorch 2.7.0 and above.
1250
+ Some features (e.g., working with nested jagged tensors) may also
1251
+   be limited with older versions of PyTorch. It is recommended to use the latest TorchRL with the latest PyTorch version
1252
+ unless there is a strong reason not to do so.
1253
+
1254
+ **Optional dependencies**
1255
+
1256
+ The following libraries can be installed depending on the usage one wants to
1257
+ make of torchrl:
1258
+ ```
1259
+ # diverse
1260
+ pip3 install tqdm tensorboard "hydra-core>=1.1" hydra-submitit-launcher
1261
+
1262
+ # rendering
1263
+ pip3 install "moviepy<2.0.0"
1264
+
1265
+ # deepmind control suite
1266
+ pip3 install dm_control
1267
+
1268
+ # gym, atari games
1269
+ pip3 install "gym[atari]" "gym[accept-rom-license]" pygame
1270
+
1271
+ # tests
1272
+ pip3 install pytest pyyaml pytest-instafail
1273
+
1274
+ # tensorboard
1275
+ pip3 install tensorboard
1276
+
1277
+ # wandb
1278
+ pip3 install wandb
1279
+ ```
1280
+
1281
+   Versioning issues can cause error messages of the type `undefined symbol`
1282
+   and similar. For these, refer to the [versioning issues document](https://github.com/pytorch/rl/blob/main/knowledge_base/VERSIONING_ISSUES.md)
1283
+ for a complete explanation and proposed workarounds.
1284
+
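+   As a quick, purely illustrative first check, printing the installed versions often narrows the problem down, since
+   mismatched `torch`/`tensordict`/`torchrl` builds are a common cause of such errors:
+   ```python
+   import tensordict
+   import torch
+   import torchrl
+
+   # Compare these against the compatibility notes above before rebuilding anything
+   print(torch.__version__, tensordict.__version__, torchrl.__version__)
+   ```
+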
1285
+ ## Asking a question
1286
+
1287
+ If you spot a bug in the library, please raise an issue in this repo.
1288
+
1289
+ If you have a more generic question regarding RL in PyTorch, post it on
1290
+ the [PyTorch forum](https://discuss.pytorch.org/c/reinforcement-learning/6).
1291
+
1292
+ ## Contributing
1293
+
1294
+   Contributions to torchrl are welcome! Feel free to fork, submit issues and PRs.
1295
+   You can check out the detailed contribution guide [here](https://github.com/pytorch/rl/blob/main/CONTRIBUTING.md).
1296
+   As mentioned above, a list of open contributions can be found [here](https://github.com/pytorch/rl/issues/509).
1297
+
1298
+   Contributors are recommended to install [pre-commit hooks](https://pre-commit.com/) (using `pre-commit install`). pre-commit will check for linting-related issues when the code is committed locally. You can disable the check by appending `-n` to your commit command: `git commit -m <commit message> -n`.
1299
+
1300
+
1301
+ ## Disclaimer
1302
+
1303
+ This library is released as a PyTorch beta feature.
1304
+   BC-breaking changes are likely to happen, but they will be introduced with a deprecation
1305
+   warning period spanning a few release cycles.
1306
+
1307
+ # License
1308
+ TorchRL is licensed under the MIT License. See [LICENSE](https://github.com/pytorch/rl/blob/main/LICENSE) for details.