torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,254 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import argparse
6
+ import functools
7
+
8
+ import pytest
9
+ import torch
10
+ from tensordict import TensorDict
11
+
12
+ from torchrl.data import (
13
+ LazyMemmapStorage,
14
+ LazyTensorStorage,
15
+ ListStorage,
16
+ ReplayBuffer,
17
+ TensorDictPrioritizedReplayBuffer,
18
+ TensorDictReplayBuffer,
19
+ )
20
+ from torchrl.data.replay_buffers import (
21
+ RandomSampler,
22
+ SamplerWithoutReplacement,
23
+ SliceSampler,
24
+ )
25
+
26
# Prioritized TensorDict replay buffer with fixed priority exponents
# (alpha=1, beta=0.9), usable in the parametrizations below exactly like the
# plain buffer classes.
_TensorDictPrioritizedReplayBuffer = functools.partial(
    TensorDictPrioritizedReplayBuffer,
    alpha=1,
    beta=0.9,
)
# A ``functools.partial`` has no ``__name__`` of its own; copy the wrapped
# class's name onto it (presumably so that generated test ids stay readable).
setattr(
    _TensorDictPrioritizedReplayBuffer,
    "__name__",
    TensorDictPrioritizedReplayBuffer.__name__,
)
31
+
32
+
33
class create_rb:
    """Setup callable producing the ``(args, kwargs)`` pair for a replay-buffer benchmark.

    Every call builds a brand-new buffer of type ``rb`` (batch size 256) plus a
    synthetic :class:`~tensordict.TensorDict` holding ``size`` zero-filled entries.

    Args:
        rb: replay-buffer class (or factory) to instantiate.
        storage: storage class/factory called with ``size``; ``None`` keeps the
            buffer's default storage.
        sampler: zero-argument sampler factory; ``None`` keeps the buffer's
            default sampler.
        populated: when ``True`` the synthetic data is written into the buffer
            and the call returns ``((rb,), {})``; otherwise it returns
            ``((rb, data), {})`` and leaves the writing to the benchmarked function.
        size: number of entries for both the storage and the synthetic data.
    """

    def __init__(self, rb, storage, sampler, populated, size=1_000_000):
        self.storage = storage
        self.rb = rb
        self.sampler = sampler
        self.populated = populated
        self.size = size

    def __call__(self):
        rb_kwargs = {"batch_size": 256}
        if self.sampler is not None:
            rb_kwargs["sampler"] = self.sampler()
        if self.storage is not None:
            rb_kwargs["storage"] = self.storage(self.size)
        buffer = self.rb(**rb_kwargs)

        n_items = self.size
        data = TensorDict(
            {
                "a": torch.zeros(n_items, 5),
                ("b", "c"): torch.zeros(n_items, 3, 32, 32, dtype=torch.uint8),
            },
            batch_size=[n_items],
        )
        # The SliceSampler parametrizations read trajectory ids from "traj":
        # split the data into consecutive trajectories of 123 steps.
        if isinstance(rb_kwargs.get("sampler"), SliceSampler):
            data["traj"] = torch.arange(n_items) // 123

        if not self.populated:
            return ((buffer, data), {})
        buffer.extend(data)
        return ((buffer,), {})
64
+
65
+
66
def populate(rb, td):
    """Benchmark body: write all of ``td`` into ``rb`` with one ``extend`` call."""
    write = rb.extend
    write(td)
68
+
69
+
70
def sample(rb):
    """Benchmark body: draw one batch from ``rb`` and discard it (only timing matters)."""
    _ = rb.sample()
72
+
73
+
74
def iterate(rb):
    """Benchmark body: advance the iterator ``rb`` by a single item, dropping it."""
    _ = next(rb)
76
+
77
+
78
@pytest.mark.parametrize(
    "rb,storage,sampler,size",
    [
        (TensorDictReplayBuffer, ListStorage, RandomSampler, 4000),
        (TensorDictReplayBuffer, LazyMemmapStorage, RandomSampler, 10_000),
        (TensorDictReplayBuffer, LazyTensorStorage, RandomSampler, 10_000),
        (TensorDictReplayBuffer, ListStorage, SamplerWithoutReplacement, 4000),
        (TensorDictReplayBuffer, LazyMemmapStorage, SamplerWithoutReplacement, 10_000),
        (TensorDictReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement, 10_000),
        (
            TensorDictReplayBuffer,
            LazyMemmapStorage,
            functools.partial(SliceSampler, num_slices=8, traj_key="traj"),
            10_000,
        ),
        (
            TensorDictReplayBuffer,
            LazyTensorStorage,
            functools.partial(SliceSampler, num_slices=8, traj_key="traj"),
            10_000,
        ),
        (_TensorDictPrioritizedReplayBuffer, ListStorage, None, 4000),
        (_TensorDictPrioritizedReplayBuffer, LazyMemmapStorage, None, 10_000),
        (_TensorDictPrioritizedReplayBuffer, LazyTensorStorage, None, 10_000),
    ],
)
def test_rb_sample(benchmark, rb, storage, sampler, size):
    """Benchmark ``rb.sample()`` on a buffer already filled with ``size`` entries."""
    factory = create_rb(
        rb=rb,
        storage=storage,
        sampler=sampler,
        populated=True,
        size=size,
    )
    call_args, _ = factory()
    (buffer,) = call_args
    # Fixed seed so random sampling starts from the same state on every run.
    torch.manual_seed(0)
    benchmark(sample, buffer)
114
+
115
+
116
def infinite_iter(obj):
    """Yield the items of ``obj`` endlessly, re-iterating it after each full pass.

    ``torch.manual_seed(0)`` runs once, when the first item is requested, so any
    torch RNG draws made while iterating (e.g. by a sampler) start from a fixed
    state.

    If a pass over ``obj`` produces no items, the generator stops instead of
    spinning forever without yielding. This can happen with an empty buffer, or
    with a one-shot iterator after its first pass. The behaviour mirrors
    :func:`itertools.cycle` on an empty iterable.

    Args:
        obj: a re-iterable object; ``iter(obj)`` is called once per pass.

    Yields:
        The items of successive passes over ``obj``.
    """
    torch.manual_seed(0)
    while True:
        produced = False
        for item in obj:
            produced = True
            yield item
        if not produced:
            # An empty pass would otherwise busy-loop forever without yielding.
            return
120
+
121
+
122
@pytest.mark.parametrize(
    "rb,storage,sampler,size",
    [
        (TensorDictReplayBuffer, ListStorage, RandomSampler, 4000),
        (TensorDictReplayBuffer, LazyMemmapStorage, RandomSampler, 10_000),
        (TensorDictReplayBuffer, LazyTensorStorage, RandomSampler, 10_000),
        (TensorDictReplayBuffer, ListStorage, SamplerWithoutReplacement, 4000),
        (TensorDictReplayBuffer, LazyMemmapStorage, SamplerWithoutReplacement, 10_000),
        (TensorDictReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement, 10_000),
        (_TensorDictPrioritizedReplayBuffer, ListStorage, None, 4000),
        (_TensorDictPrioritizedReplayBuffer, LazyMemmapStorage, None, 10_000),
        (_TensorDictPrioritizedReplayBuffer, LazyTensorStorage, None, 10_000),
    ],
)
def test_rb_iterate(benchmark, rb, storage, sampler, size):
    """Benchmark fetching one batch at a time by iterating a pre-filled buffer.

    The buffer is wrapped in ``infinite_iter`` so that the benchmark never
    exhausts the iterator, whatever the number of rounds.
    """
    factory = create_rb(
        rb=rb,
        storage=storage,
        sampler=sampler,
        populated=True,
        size=size,
    )
    call_args, _ = factory()
    (buffer,) = call_args
    endless = infinite_iter(buffer)
    benchmark(iterate, endless)
145
+
146
+
147
@pytest.mark.parametrize(
    "rb,storage,sampler,size",
    [
        (TensorDictReplayBuffer, ListStorage, RandomSampler, 400),
        (TensorDictReplayBuffer, LazyMemmapStorage, RandomSampler, 400),
        (TensorDictReplayBuffer, LazyTensorStorage, RandomSampler, 400),
        (TensorDictReplayBuffer, ListStorage, SamplerWithoutReplacement, 400),
        (TensorDictReplayBuffer, LazyMemmapStorage, SamplerWithoutReplacement, 400),
        (TensorDictReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement, 400),
        (_TensorDictPrioritizedReplayBuffer, ListStorage, None, 400),
        (_TensorDictPrioritizedReplayBuffer, LazyMemmapStorage, None, 400),
        (_TensorDictPrioritizedReplayBuffer, LazyTensorStorage, None, 400),
    ],
)
def test_rb_populate(benchmark, rb, storage, sampler, size):
    """Benchmark one ``extend`` of ``size`` entries into a freshly built, empty buffer."""
    # pytest-benchmark invokes ``setup`` before each round, so every round gets a
    # new (buffer, data) pair from create_rb(populated=False).
    fresh_buffer_and_data = create_rb(
        rb=rb,
        storage=storage,
        sampler=sampler,
        populated=False,
        size=size,
    )
    benchmark.pedantic(
        populate,
        setup=fresh_buffer_and_data,
        iterations=1,
        rounds=50,
    )
174
+
175
+
176
class create_compiled_tensor_rb:
    """Setup callable for the extend-and-sample benchmark.

    Each call builds a new buffer of type ``rb`` with batch size 3. The
    ``compilable`` flag is forwarded to both the storage and the buffer. The
    call returns ``((rb, data, iters), {})``, where ``data`` is a
    ``(data_size, 1)`` tensor of standard-normal samples.

    Args:
        rb: replay-buffer class (or factory) to instantiate.
        storage: storage class/factory called as
            ``storage(storage_size, compilable=...)``; ``None`` keeps the default.
        sampler: zero-argument sampler factory; ``None`` keeps the default.
        storage_size: capacity passed to the storage factory.
        data_size: number of rows in the generated data tensor.
        iters: loop count handed through to the benchmarked function.
        compilable: whether storage and buffer are built in compilable mode.
    """

    def __init__(
        self, rb, storage, sampler, storage_size, data_size, iters, compilable=False
    ):
        self.storage = storage
        self.rb = rb
        self.sampler = sampler
        self.storage_size = storage_size
        self.data_size = data_size
        self.iters = iters
        self.compilable = compilable

    def __call__(self):
        extra_kwargs = {}
        if self.sampler is not None:
            extra_kwargs["sampler"] = self.sampler()
        if self.storage is not None:
            extra_kwargs["storage"] = self.storage(
                self.storage_size, compilable=self.compilable
            )
        buffer = self.rb(batch_size=3, compilable=self.compilable, **extra_kwargs)
        payload = torch.randn(self.data_size, 1)
        return ((buffer, payload, self.iters), {})
200
+
201
+
202
def extend_and_sample(rb, td, iters):
    """Eager benchmark body: call ``rb.extend(td)`` then ``rb.sample()``, ``iters`` times.

    Args:
        rb: replay buffer to write to and sample from.
        td: data written on every step.
        iters: number of extend/sample steps.
    """
    for _step in range(iters):
        rb.extend(td)  # write first ...
        rb.sample()  # ... then draw one batch, discarding it
206
+
207
+
208
def extend_and_sample_compiled(rb, td, iters):
    """Compiled benchmark body: run ``iters`` extend/sample steps via ``torch.compile``.

    Each extend-then-sample step is wrapped in a ``torch.compile``'d closure over
    ``rb``. The closure is rebuilt on every call of this function.

    Args:
        rb: replay buffer to write to and sample from (captured by the closure).
        td: data written on every step.
        iters: number of compiled extend/sample steps.
    """

    @torch.compile
    def _compiled_step(batch):
        rb.extend(batch)
        rb.sample()

    for _step in range(iters):
        _compiled_step(td)
216
+
217
+
218
@pytest.mark.parametrize(
    "rb,storage,sampler,storage_size,data_size,iters,compiled",
    [
        (ReplayBuffer, LazyTensorStorage, RandomSampler, 10_000, 10_000, 100, True),
        (ReplayBuffer, LazyTensorStorage, RandomSampler, 10_000, 10_000, 100, False),
        (ReplayBuffer, LazyTensorStorage, RandomSampler, 100_000, 10_000, 100, True),
        (ReplayBuffer, LazyTensorStorage, RandomSampler, 100_000, 10_000, 100, False),
        (ReplayBuffer, LazyTensorStorage, RandomSampler, 1_000_000, 10_000, 100, True),
        (ReplayBuffer, LazyTensorStorage, RandomSampler, 1_000_000, 10_000, 100, False),
    ],
)
def test_rb_extend_sample(
    benchmark, rb, storage, sampler, storage_size, data_size, iters, compiled
):
    """Benchmark ``iters`` interleaved extend/sample steps, eager or compiled."""
    if compiled:
        # Drop dynamo's cached code first (presumably so compiled artifacts from
        # earlier parametrizations are not reused); warmup rounds below absorb
        # the recompilation.
        torch._dynamo.reset_code_caches()

    body = extend_and_sample_compiled if compiled else extend_and_sample
    setup = create_compiled_tensor_rb(
        rb=rb,
        storage=storage,
        sampler=sampler,
        storage_size=storage_size,
        data_size=data_size,
        iters=iters,
        compilable=compiled,
    )
    benchmark.pedantic(
        body,
        setup=setup,
        iterations=1,
        warmup_rounds=10,
        rounds=50,
    )
250
+
251
+
252
if __name__ == "__main__":
    # The parser defines no options: every command-line flag is forwarded to pytest.
    _, pytest_extra_args = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *pytest_extra_args])
sota-check/README.md ADDED
@@ -0,0 +1,35 @@
1
+ # SOTA Performance checks
2
+
3
+ This folder contains a `submitit-release-check.sh` file that executes all
4
+ the training scripts using `sbatch` with the default configuration and logs them
5
+ into a common WandB project.
6
+
7
+ This script is to be executed before every release to assess the performance of
8
+ the various algorithms available in torchrl. The name of the project will include
9
+ the specific commit of torchrl used to run the scripts (e.g. `torchrl-examples-check-<commit>`).
10
+
11
+ ## Usage
12
+
13
+ To display the script usage, you can use the `--help` option:
14
+
15
+ ```bash
16
+ ./submitit-release-check.sh --help
17
+ ```
18
+
19
+ ## Setup
20
+
21
+ The following setup should allow you to run the scripts:
22
+
23
+ ```bash
24
+ export MUJOCO_GL=egl
25
+
26
+ conda create -n rl-sota-bench python=3.10 -y
+ conda activate rl-sota-bench
27
+ conda install anaconda::libglu -y
28
+ pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121
29
+ pip3 install "gymnasium[atari,mujoco]" vmas tqdm wandb pygame "moviepy<2.0.0" imageio submitit hydra-core transformers
30
+
31
+ cd /path/to/tensordict
32
+ python setup.py develop
33
+ cd /path/to/torchrl
34
+ python setup.py develop
35
+ ```
@@ -0,0 +1,142 @@
1
+ # Examples
2
+
3
+ We provide examples to train the following algorithms:
4
+ - [CQL](../sota-implementations/cql/)
5
+ - [DDPG](ddpg/ddpg.py)
6
+ - [DQN](../sota-implementations/dqn/)
7
+ - [Decision Transformers](../sota-implementations/decision_transformer)
8
+ - [Discrete SAC](discrete_sac/discrete_sac.py)
9
+ - [Dreamer](../sota-implementations/dreamer)
10
+ - [IQL](iql/)
11
+ - [Impala](impala/)
12
+ - [PPO](../sota-implementations/ppo/)
13
+ - [REDQ](redq/redq.py)
14
+ - [SAC](sac/sac.py)
15
+ - [TD3](../sota-implementations/td3/td3.py)
16
+ - [Various multiagent examples](multiagent/)
17
+
18
+ To run these examples, make sure you have installed hydra:
19
+ ```
20
+ pip install hydra-core
21
+ ```
22
+
23
+ Scripts can be run from the directory of interest using:
24
+ ```
25
+ python sac.py
26
+ ```
27
+ or similar. Hyperparameters can be easily changed by providing the arguments to hydra:
28
+ ```
29
+ python sac.py collector.frames_per_batch=63
30
+ ```
31
+
32
+ [//]: # (# Results)
33
+
34
+ [//]: # ()
35
+ [//]: # (Here we can see some results for the SAC and REDQ algorithm.)
36
+
37
+ [//]: # (We average the results over 5 different seeds and plot the standard error.)
38
+
39
+ [//]: # (## Gym's HalfCheetah-v4)
40
+
41
+ [//]: # ()
42
+ [//]: # (<p align="center">)
43
+
44
+ [//]: # (<img src="media/halfcheetah_chart.png" width="600px">)
45
+
46
+ [//]: # (</p>)
47
+
48
+ [//]: # (To reproduce a single run:)
49
+
50
+ [//]: # ()
51
+ [//]: # (```)
52
+
53
+ [//]: # (python sac/sac.py env.name="HalfCheetah-v4" env.task="" env.library="gym")
54
+
55
+ [//]: # (```)
56
+
57
+ [//]: # ()
58
+ [//]: # (``` )
59
+
60
+ [//]: # (python redq/redq.py env.name="HalfCheetah-v4" env.library="gymnasium")
61
+
62
+ [//]: # (```)
63
+
64
+ [//]: # ()
65
+ [//]: # ()
66
+ [//]: # (## dm_control's cheetah-run)
67
+
68
+ [//]: # ()
69
+ [//]: # (<p align="center">)
70
+
71
+ [//]: # (<img src="media/cheetah_chart.png" width="600px">)
72
+
73
+ [//]: # (</p>)
74
+
75
+ [//]: # (To reproduce a single run:)
76
+
77
+ [//]: # ()
78
+ [//]: # (```)
79
+
80
+ [//]: # (python sac/sac.py env.name="cheetah" env.task="run" env.library="dm_control")
81
+
82
+ [//]: # (```)
83
+
84
+ [//]: # ()
85
+ [//]: # (``` )
86
+
87
+ [//]: # (python redq/redq.py env.name="cheetah" env.task="run" env.library="dm_control")
88
+
89
+ [//]: # (```)
90
+
91
+ [//]: # ()
92
+ [//]: # ([//]: # &#40;TODO: adapt these scripts&#41;)
93
+ [//]: # ([//]: # &#40;## Gym's Ant-v4&#41;)
94
+ [//]: # ()
95
+ [//]: # ([//]: # &#40;&#41;)
96
+ [//]: # ([//]: # &#40;<p align="center">&#41;)
97
+ [//]: # ()
98
+ [//]: # ([//]: # &#40;<img src="media/ant_chart.png" width="600px">&#41;)
99
+ [//]: # ()
100
+ [//]: # ([//]: # &#40;</p>&#41;)
101
+ [//]: # ()
102
+ [//]: # ([//]: # &#40;To reproduce a single run:&#41;)
103
+ [//]: # ()
104
+ [//]: # ([//]: # &#40;&#41;)
105
+ [//]: # ([//]: # &#40;```&#41;)
106
+ [//]: # ()
107
+ [//]: # ([//]: # &#40;python sac/sac.py env.name="Ant-v4" env.task="" env.library="gym"&#41;)
108
+ [//]: # ()
109
+ [//]: # ([//]: # &#40;```&#41;)
110
+ [//]: # ()
111
+ [//]: # ([//]: # &#40;&#41;)
112
+ [//]: # ([//]: # &#40;``` &#41;)
113
+ [//]: # ()
114
+ [//]: # ([//]: # &#40;python redq/redq.py env_name="Ant-v4" env_task="" env_library="gym"&#41;)
115
+ [//]: # ()
116
+ [//]: # ([//]: # &#40;```&#41;)
117
+ [//]: # ()
118
+ [//]: # ([//]: # &#40;&#41;)
119
+ [//]: # ([//]: # &#40;## Gym's Walker2D-v4&#41;)
120
+ [//]: # ()
121
+ [//]: # ([//]: # &#40;&#41;)
122
+ [//]: # ([//]: # &#40;<p align="center">&#41;)
123
+ [//]: # ()
124
+ [//]: # ([//]: # &#40;<img src="media/walker2d_chart.png" width="600px">&#41;)
125
+ [//]: # ()
126
+ [//]: # ([//]: # &#40;</p>&#41;)
127
+ [//]: # ()
128
+ [//]: # ([//]: # &#40;To reproduce a single run:&#41;)
129
+ [//]: # ()
130
+ [//]: # ([//]: # &#40;&#41;)
131
+ [//]: # ([//]: # &#40;```&#41;)
132
+ [//]: # ()
133
+ [//]: # ([//]: # &#40;python sac/sac.py env_name="Walker2D-v4" env_task="" env_library="gym"&#41;)
134
+ [//]: # ()
135
+ [//]: # ([//]: # &#40;```&#41;)
136
+ [//]: # ()
137
+ [//]: # ([//]: # &#40;&#41;)
138
+ [//]: # ([//]: # &#40;``` &#41;)
139
+ [//]: # ()
140
+ [//]: # ([//]: # &#40;python redq/redq.py env_name="Walker2D-v4" env_task="" env_library="gym"&#41;)
141
+ [//]: # ()
142
+ [//]: # ([//]: # &#40;```&#41;)
@@ -0,0 +1,39 @@
1
+ ## Reproducing Advantage Actor Critic (A2C) Algorithm Results
2
+
3
+ This repository contains scripts that enable training agents using the Advantage Actor Critic (A2C) Algorithm on MuJoCo and Atari environments. We follow the original paper [Asynchronous Methods for Deep Reinforcement Learning](https://arxiv.org/abs/1602.01783) by Mnih et al. (2016) to implement the A2C algorithm but fix the number of steps during the collection phase.
4
+
5
+
6
+ ## Examples Structure
7
+
8
+ Please note that, for the sake of simplicity, each example is independent of the others. Each example contains the following files:
9
+
10
+ 1. **Main Script:** The definition of algorithm components and the training loop can be found in the main script (e.g. a2c_atari.py).
11
+
12
+ 2. **Utils File:** A utility file is provided to contain various helper functions, generally to create the environment and the models (e.g. utils_atari.py).
13
+
14
+ 3. **Configuration File:** This file includes default hyperparameters specified in the original paper. Users can modify these hyperparameters to customize their experiments (e.g. config_atari.yaml).
15
+
16
+
17
+ ## Running the Examples
18
+
19
+ You can execute the A2C algorithm on Atari environments by running the following command:
20
+
21
+ ```bash
22
+ python a2c_atari.py compile.compile=1 compile.cudagraphs=1
23
+ ```
24
+
25
+
26
+ You can execute the A2C algorithm on MuJoCo environments by running the following command:
27
+
28
+ ```bash
29
+ python a2c_mujoco.py compile.compile=1 compile.cudagraphs=1
30
+ ```
31
+
32
+ ## Runtimes
33
+
34
+ Runtimes when executed on an H100 GPU:
35
+
36
+ | Environment | Eager | Compile | Compile+cudagraphs |
37
+ |-------------|-----------|-----------|--------------------|
38
+ | MUJOCO | < 25 mins | < 23 mins | < 20 mins |
39
+ | ATARI | < 85 mins | < 60 mins | < 45 mins |