torchrl 0.11.0__cp314-cp314t-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cpython-314t-aarch64-linux-gnu.so +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/METADATA +1308 -0
  390. torchrl-0.11.0.dist-info/RECORD +394 -0
  391. torchrl-0.11.0.dist-info/WHEEL +5 -0
  392. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  393. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,104 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """
6
+ Benchmarking different types of batched environments
7
+ ====================================================
8
+ Compares runtime for different environments which allow performing operations in a batch.
9
+ - SerialEnv executes the operations sequentially
10
+ - ParallelEnv uses multiprocess parallelism
11
+ - MultiThreadedEnv uses multithreaded parallelism and is based on envpool library.
12
+
13
+ Run as "python benchmarks/benchmark_batched_envs.py"
14
+ Requires pandas ("pip install pandas").
15
+
16
+ """
17
+
18
+ import pandas as pd
19
+ from torchrl._utils import logger as torchrl_logger
20
+
21
+ pd.set_option("display.max_columns", 100)
22
+ pd.set_option("display.width", 1000)
23
+ import torch
24
+ from torch.utils.benchmark import Timer
25
+ from torchrl.envs import MultiThreadedEnv, ParallelEnv, SerialEnv
26
+ from torchrl.envs.libs.gym import GymEnv
27
+
28
+ N_STEPS = 1000
29
+
30
+
31
def create_multithreaded(num_workers, device):
    """Build a warmed-up ``MultiThreadedEnv`` with ``num_workers`` Pendulum-v1 copies.

    Note: the underlying envpool library runs on CPU only, so moving the env to
    a GPU device gives no speedup for this backend.
    """
    env = MultiThreadedEnv(num_workers=num_workers, env_name="Pendulum-v1").to(
        device=torch.device(device)
    )
    # Short warm-up rollout so one-off setup costs stay out of the timed runs.
    env.rollout(policy=None, max_steps=5)
    return env
37
+
38
+
39
def factory():
    """Return a fresh single (non-batched) Pendulum-v1 environment."""
    return GymEnv("Pendulum-v1")
41
+
42
+
43
def create_serial(num_workers, device):
    """Build a warmed-up ``SerialEnv`` running ``num_workers`` envs sequentially."""
    env = SerialEnv(num_workers=num_workers, create_env_fn=factory).to(
        device=torch.device(device)
    )
    # Short warm-up rollout so one-off setup costs stay out of the timed runs.
    env.rollout(policy=None, max_steps=5)
    return env
48
+
49
+
50
def create_parallel(num_workers, device):
    """Build a warmed-up ``ParallelEnv`` running ``num_workers`` envs in subprocesses."""
    env = ParallelEnv(num_workers=num_workers, create_env_fn=factory).to(
        device=torch.device(device)
    )
    # Short warm-up rollout so one-off setup costs stay out of the timed runs.
    env.rollout(policy=None, max_steps=5)
    return env
55
+
56
+
57
def run_env(env):
    """Roll out ``env`` for ``N_STEPS`` steps with random actions (no policy); timed by the benchmark."""
    env.rollout(policy=None, max_steps=N_STEPS)
59
+
60
+
61
if __name__ == "__main__":
    # Benchmark Serial/Parallel/Multithreaded batched envs for a few worker
    # counts on every available device, then dump a comparison table to CSV.

    def _time_rollout(env) -> float:
        """Mean wall-clock seconds of one ``run_env`` call, via blocked autorange."""
        timer = Timer(
            stmt="run_env(env)",
            setup="from __main__ import run_env",
            globals={"env": env},
        )
        return timer.blocked_autorange().mean

    res = {}
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    for device in devices:
        for num_workers in [1, 4, 16]:
            torchrl_logger.info(f"With num_workers={num_workers}, {device}")

            torchrl_logger.info("Multithreaded...")
            env_multithreaded = create_multithreaded(num_workers, device)
            time_multithreaded = _time_rollout(env_multithreaded)

            torchrl_logger.info("Serial...")
            env_serial = create_serial(num_workers, device)
            time_serial = _time_rollout(env_serial)

            torchrl_logger.info("Parallel...")
            env_parallel = create_parallel(num_workers, device)
            time_parallel = _time_rollout(env_parallel)

            res[f"num_workers_{num_workers}_{device}"] = {
                "Serial, s": time_serial,
                "Parallel, s": time_parallel,
                "Multithreaded, s": time_multithreaded,
            }
    df = pd.DataFrame(res)
    # Compute the gain from the raw timings BEFORE rounding: deriving it from
    # values already truncated to 3 decimals needlessly degrades precision.
    gain = 1 - df.loc["Multithreaded, s"] / df.loc["Parallel, s"]
    df = df.round(3)
    df.loc["Gain, %", :] = (gain * 100).round(1)
    df.to_csv("multithreaded_benchmark.csv")
benchmarks/conftest.py ADDED
@@ -0,0 +1,91 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import os
6
+ import time
7
+ import warnings
8
+ from collections import defaultdict
9
+
10
+ import pytest
11
+ from torchrl._utils import logger as torchrl_logger
12
+
13
+ CALL_TIMES = defaultdict(float)
14
+
15
+
16
def pytest_sessionfinish(maxprint=50):
    """Log the cumulative per-test durations gathered in ``CALL_TIMES``.

    Called automatically by pytest at the end of the session (``maxprint`` has
    a default, so pluggy ignores it when dispatching the hook). Prints at most
    ``maxprint`` entries, slowest first, with the timing column aligned.
    """
    if not CALL_TIMES:
        # Nothing recorded (e.g. collection-only run): stay silent, as before.
        return
    out_str = """
Call times:
===========
"""
    # Width of the widest test name, used to align the timing column.
    # A single max() over a generator replaces the original two-branch
    # max(*[...]) idiom, which needed a special case for a single key.
    maxchar = max(len(key) for key in CALL_TIMES)
    for i, (key, item) in enumerate(
        sorted(CALL_TIMES.items(), key=lambda x: x[1], reverse=True)
    ):
        spaces = " " + " " * (maxchar - len(key))
        out_str += f"\t{key}{spaces}{item: 4.4f}s\n"
        if i == maxprint - 1:
            break
    torchrl_logger.info(out_str)
36
+
37
+
38
@pytest.fixture(autouse=True)
def measure_duration(request: pytest.FixtureRequest):
    """Accumulate each test's wall-clock duration into ``CALL_TIMES``.

    The key is ``<file>::<Class>::<test_name>`` with the parametrization id
    (``[...]`` suffix) stripped, so all parametrized runs of one test
    aggregate under a single entry.
    """
    t_start = time.time()

    def _record():
        elapsed = time.time() - t_start
        # Strip the parametrization suffix, e.g. "test_foo[cpu]" -> "test_foo".
        test_name = request.node.name.split("[")[0]
        if request.cls is not None:
            test_name = f"{request.cls.__name__}::{test_name}"
        test_name = f"{os.path.basename(request.path)}::{test_name}"
        CALL_TIMES[test_name] += elapsed

    request.addfinalizer(_record)
54
+
55
+
56
def pytest_addoption(parser):
    """Register the ``--rank`` command-line option (stored as a plain value)."""
    parser.addoption("--rank", action="store")
58
+
59
+
60
@pytest.fixture(scope="session", autouse=True)
def set_warnings() -> None:
    """Silence known-noisy warnings emitted by third-party dependencies.

    Registered once per session; each entry below matches a warning message
    by regex and ignores it.
    """
    ignored = (
        (UserWarning, r"Lazy modules are a new feature under heavy development"),
        (
            UserWarning,
            r"Couldn't cast the policy onto the desired device on remote process",
        ),
        (DeprecationWarning, r"Deprecated call to `pkg_resources.declare_namespace"),
        (DeprecationWarning, r"Using or importing the ABCs"),
        (
            DeprecationWarning,
            r"Please use `coo_matrix` from the `scipy.sparse` namespace",
        ),
        (
            DeprecationWarning,
            r"jax.tree_util.register_keypaths is deprecated|jax.ShapedArray is deprecated",
        ),
    )
    for category, message in ignored:
        warnings.filterwarnings("ignore", category=category, message=message)
@@ -0,0 +1,321 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """This script executes some envs across the Gym library with the explicit scope of testing the throughput using the various TorchRL components.
7
+
8
+ We test:
9
+ - gym async envs embedded in a TorchRL's GymEnv wrapper,
10
+ - ParallelEnv with regular GymEnv instances,
11
+ - Data collector
12
+ - Multiprocessed data collectors with parallel envs.
13
+
14
+ The tests are executed with various number of cpus, and on different devices.
15
+
16
+ """
17
+ import time
18
+
19
+ # import myosuite # noqa: F401
20
+ import torch
21
+ import tqdm
22
+ from torchrl._utils import timeit
23
+ from torchrl.collectors import (
24
+ MultiaSyncDataCollector,
25
+ MultiSyncDataCollector,
26
+ SyncDataCollector,
27
+ )
28
+ from torchrl.envs import EnvCreator, GymEnv, ParallelEnv
29
+ from torchrl.envs.libs.gym import gym_backend as gym_bc, set_gym_backend
30
+ from torchrl.modules import RandomPolicy
31
+
32
+ if __name__ == "__main__":
33
+ avail_devices = ("cpu",)
34
+ if torch.cuda.is_available():
35
+ avail_devices = avail_devices + ("cuda:0",)
36
+
37
+ for envname in [
38
+ "CartPole-v1",
39
+ "HalfCheetah-v4",
40
+ "myoHandReachRandom-v0",
41
+ "ALE/Breakout-v5",
42
+ ]:
43
+ # the number of collectors won't affect the resources, just impacts how the envs are split in sub-sub-processes
44
+ for num_workers, num_collectors in zip((32, 64, 8, 16), (8, 8, 2, 4)):
45
+ with open(f"{envname}_{num_workers}.txt".replace("/", "-"), "w+") as log:
46
+ if "myo" in envname:
47
+ gym_backend = "gym"
48
+ else:
49
+ gym_backend = "gymnasium"
50
+
51
+ total_frames = num_workers * 10_000
52
+
53
+ # pure gym
54
+ def make(envname=envname, gym_backend=gym_backend):
55
+ with set_gym_backend(gym_backend):
56
+ return gym_bc().make(envname)
57
+
58
+ with set_gym_backend(gym_backend):
59
+ env = gym_bc().vector.AsyncVectorEnv(
60
+ [make for _ in range(num_workers)]
61
+ )
62
+ env.reset()
63
+ global_step = 0
64
+ times = []
65
+ start = time.time()
66
+ for _ in tqdm.tqdm(range(total_frames // num_workers)):
67
+ env.step(env.action_space.sample())
68
+ global_step += num_workers
69
+ env.close()
70
+ log.write(
71
+ f"pure gym: {num_workers * 10_000 / (time.time() - start): 4.4f} fps\n"
72
+ )
73
+ log.flush()
74
+
75
+ # regular parallel env
76
+ for device in avail_devices:
77
+
78
+ def make(envname=envname, gym_backend=gym_backend):
79
+ with set_gym_backend(gym_backend):
80
+ return GymEnv(envname, device="cpu")
81
+
82
+ # env_make = EnvCreator(make)
83
+ penv = ParallelEnv(num_workers, EnvCreator(make), device=device)
84
+ with torch.inference_mode():
85
+ # warmup
86
+ penv.rollout(2)
87
+ pbar = tqdm.tqdm(total=num_workers * 10_000)
88
+ t0 = time.time()
89
+ data = None
90
+ for _ in range(100):
91
+ data = penv.rollout(
92
+ 100, break_when_any_done=False, out=data
93
+ )
94
+ pbar.update(100 * num_workers)
95
+ log.write(
96
+ f"penv {device}: {num_workers * 10_000 / (time.time() - t0): 4.4f} fps\n"
97
+ )
98
+ log.flush()
99
+ penv.close()
100
+ timeit.print()
101
+ del penv
102
+
103
+ for device in avail_devices:
104
+
105
+ def make(envname=envname, gym_backend=gym_backend):
106
+ with set_gym_backend(gym_backend):
107
+ return GymEnv(envname, device="cpu")
108
+
109
+ env_make = EnvCreator(make)
110
+ # penv = SerialEnv(num_workers, env_make)
111
+ penv = ParallelEnv(num_workers, env_make, device=device)
112
+ collector = SyncDataCollector(
113
+ penv,
114
+ RandomPolicy(penv.action_spec),
115
+ frames_per_batch=1024,
116
+ total_frames=num_workers * 10_000,
117
+ device=device,
118
+ )
119
+ pbar = tqdm.tqdm(total=num_workers * 10_000)
120
+ total_frames = 0
121
+ t0 = time.time()
122
+ for data in collector:
123
+ total_frames += data.numel()
124
+ pbar.update(data.numel())
125
+ pbar.set_description(
126
+ f"single collector + torchrl penv: {total_frames / (time.time() - t0): 4.4f} fps"
127
+ )
128
+ log.write(
129
+ f"single collector + torchrl penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n"
130
+ )
131
+ log.flush()
132
+ collector.shutdown()
133
+ del collector
134
+
135
+ for device in avail_devices:
136
+ # gym parallel env
137
+ def make_env(
138
+ envname=envname,
139
+ num_workers=num_workers,
140
+ gym_backend=gym_backend,
141
+ device=device,
142
+ ):
143
+ with set_gym_backend(gym_backend):
144
+ penv = GymEnv(envname, num_envs=num_workers, device=device)
145
+ return penv
146
+
147
+ penv = make_env()
148
+ # warmup
149
+ penv.rollout(2)
150
+ pbar = tqdm.tqdm(total=num_workers * 10_000)
151
+ t0 = time.time()
152
+ for _ in range(100):
153
+ data = penv.rollout(100, break_when_any_done=False)
154
+ pbar.update(100 * num_workers)
155
+ log.write(
156
+ f"gym penv {device}: {num_workers * 10_000 / (time.time() - t0): 4.4f} fps\n"
157
+ )
158
+ log.flush()
159
+ penv.close()
160
+ del penv
161
+
162
for device in avail_devices:
    # Benchmark: MultiaSyncDataCollector fanning out over torchrl ParallelEnvs,
    # each holding its share of the workers.
    def make_env(envname=envname, gym_backend=gym_backend):
        with set_gym_backend(gym_backend):
            return GymEnv(envname, device="cpu")

    penv = ParallelEnv(
        num_workers // num_collectors,
        EnvCreator(make_env),
        device=device,
    )
    collector = MultiaSyncDataCollector(
        [penv] * num_collectors,
        policy=RandomPolicy(penv.action_spec),
        frames_per_batch=1024,
        total_frames=num_workers * 10_000,
        device=device,
    )
    progress = tqdm.tqdm(total=num_workers * 10_000)
    total_frames = 0
    for i, batch in enumerate(collector):
        if i < num_collectors:
            # Warm-up batches (one per collector) are excluded from timing.
            continue
        if i == num_collectors:
            t0 = time.time()
        total_frames += batch.numel()
        progress.update(batch.numel())
        progress.set_description(
            f"collector + torchrl penv: {total_frames / (time.time() - t0): 4.4f} fps"
        )
    log.write(
        f"async collector + torchrl penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n"
    )
    log.flush()
    collector.shutdown()
    del collector
198
+
199
for device in avail_devices:
    # Benchmark: MultiaSyncDataCollector where each worker process runs a
    # gym-native vectorized env instead of a torchrl ParallelEnv.
    def make_env(
        envname=envname,
        num_workers=num_workers,
        gym_backend=gym_backend,
    ):
        with set_gym_backend(gym_backend):
            return GymEnv(envname, num_envs=num_workers, device="cpu")

    workers_per_collector = num_workers // num_collectors
    penv = EnvCreator(
        lambda num_workers=workers_per_collector: make_env(num_workers=num_workers)
    )
    collector = MultiaSyncDataCollector(
        [penv] * num_collectors,
        policy=RandomPolicy(penv().action_spec),
        frames_per_batch=1024,
        total_frames=num_workers * 10_000,
        num_sub_threads=num_workers // num_collectors,
        device=device,
    )
    progress = tqdm.tqdm(total=num_workers * 10_000)
    total_frames = 0
    for i, batch in enumerate(collector):
        if i < num_collectors:
            # Warm-up batches are not timed.
            continue
        if i == num_collectors:
            t0 = time.time()
        total_frames += batch.numel()
        progress.update(batch.numel())
        progress.set_description(
            f"{i} collector + gym penv: {total_frames / (time.time() - t0): 4.4f} fps"
        )
    log.write(
        f"async collector + gym penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n"
    )
    log.flush()
    collector.shutdown()
    del collector
241
+
242
for device in avail_devices:
    # Benchmark: MultiSyncDataCollector (synchronous multi-collector) over
    # torchrl ParallelEnvs.
    def make_env(envname=envname, gym_backend=gym_backend):
        with set_gym_backend(gym_backend):
            return GymEnv(envname, device="cpu")

    penv = ParallelEnv(
        num_workers // num_collectors,
        EnvCreator(make_env),
        device=device,
    )
    collector = MultiSyncDataCollector(
        [penv] * num_collectors,
        policy=RandomPolicy(penv.action_spec),
        frames_per_batch=1024,
        total_frames=num_workers * 10_000,
        device=device,
    )
    progress = tqdm.tqdm(total=num_workers * 10_000)
    total_frames = 0
    for i, batch in enumerate(collector):
        if i < num_collectors:
            # Skip warm-up batches before starting the clock.
            continue
        if i == num_collectors:
            t0 = time.time()
        total_frames += batch.numel()
        progress.update(batch.numel())
        progress.set_description(
            f"collector + torchrl penv: {total_frames / (time.time() - t0): 4.4f} fps"
        )
    log.write(
        f"sync collector + torchrl penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n"
    )
    log.flush()
    collector.shutdown()
    del collector
278
+
279
for device in avail_devices:
    # Benchmark: MultiSyncDataCollector over gym-native vectorized envs.
    def make_env(
        envname=envname,
        num_workers=num_workers,
        gym_backend=gym_backend,
    ):
        with set_gym_backend(gym_backend):
            return GymEnv(envname, num_envs=num_workers, device="cpu")

    workers_per_collector = num_workers // num_collectors
    penv = EnvCreator(
        lambda num_workers=workers_per_collector: make_env(num_workers=num_workers)
    )
    collector = MultiSyncDataCollector(
        [penv] * num_collectors,
        policy=RandomPolicy(penv().action_spec),
        frames_per_batch=1024,
        total_frames=num_workers * 10_000,
        num_sub_threads=num_workers // num_collectors,
        device=device,
    )
    progress = tqdm.tqdm(total=num_workers * 10_000)
    total_frames = 0
    for i, batch in enumerate(collector):
        if i < num_collectors:
            # Warm-up batches are excluded from the measurement.
            continue
        if i == num_collectors:
            t0 = time.time()
        total_frames += batch.numel()
        progress.update(batch.numel())
        progress.set_description(
            f"{i} collector + gym penv: {total_frames / (time.time() - t0): 4.4f} fps"
        )
    log.write(
        f"sync collector + gym penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n"
    )
    log.flush()
    collector.shutdown()
    del collector
# All benchmarks done: terminate rather than fall through.
exit()