torchrl 0.11.0__cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
benchmarks/test_compressed_storage_benchmark.py ADDED
@@ -0,0 +1,145 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import io
6
+ import pickle
7
+
8
+ import pytest
9
+ import torch
10
+
11
+
12
+ try:
13
+ from safetensors.torch import save
14
+ except ImportError:
15
+ save = None
16
+
17
+
18
class TestCompressedStorageBenchmark:
    """Benchmark tests for CompressedListStorage."""

    @staticmethod
    def make_compressible_mock_data(num_experiences: int, device=None) -> dict:
        """Build an all-zeros (hence easily compressible) batch of transitions.

        Args:
            num_experiences: leading dimension of every tensor in the batch.
            device: device the tensors are allocated on; CPU when ``None``.

        Returns:
            A dict of tensors keyed by field name, plus a ``"batch_size"``
            entry holding ``[num_experiences]``.
        """
        if device is None:
            device = torch.device("cpu")

        return {
            "observations": torch.zeros(
                (num_experiences, 4, 84, 84),
                dtype=torch.uint8,
                device=device,
            ),
            "actions": torch.zeros((num_experiences,), device=device),
            "rewards": torch.zeros((num_experiences,), device=device),
            "next_observations": torch.zeros(
                (num_experiences, 4, 84, 84),
                dtype=torch.uint8,
                device=device,
            ),
            "terminations": torch.zeros(
                (num_experiences,), dtype=torch.bool, device=device
            ),
            "truncations": torch.zeros(
                (num_experiences,), dtype=torch.bool, device=device
            ),
            "batch_size": [num_experiences],
        }

    @staticmethod
    def make_uncompressible_mock_data(num_experiences: int, device=None) -> dict:
        """Build a randomly filled (hence poorly compressible) batch of transitions.

        Args:
            num_experiences: leading dimension of every tensor in the batch.
            device: device the tensors are allocated on; CPU when ``None``.

        Returns:
            A dict of tensors keyed by field name, plus a ``"batch_size"``
            entry holding ``[num_experiences]``.
        """
        if device is None:
            device = torch.device("cpu")
        return {
            "observations": torch.randn(
                (num_experiences, 4, 84, 84),
                dtype=torch.float32,
                device=device,
            ),
            "actions": torch.randint(0, 10, (num_experiences,), device=device),
            "rewards": torch.randn(
                (num_experiences,), dtype=torch.float32, device=device
            ),
            "next_observations": torch.randn(
                (num_experiences, 4, 84, 84),
                dtype=torch.float32,
                device=device,
            ),
            "terminations": torch.rand((num_experiences,), device=device)
            < 0.2,  # ~20% True
            "truncations": torch.rand((num_experiences,), device=device)
            < 0.1,  # ~10% True
            "batch_size": [num_experiences],
        }

    @pytest.mark.benchmark(
        group="tensor_serialization_speed",
        min_time=0.1,
        max_time=0.5,
        min_rounds=5,
        disable_gc=True,
        warmup=False,
    )
    @pytest.mark.parametrize(
        "serialization_method",
        ["pickle", "torch.save", "untyped_storage", "numpy", "safetensors"],
    )
    def test_tensor_to_bytestream_speed(self, benchmark, serialization_method: str):
        """Benchmark the speed of different tensor serialization methods.

        The ``safetensors`` case is skipped when the optional ``safetensors``
        package could not be imported.

        TODO: we might need to also test which methods work on the gpu.
        pytest benchmarks/test_compressed_storage_benchmark.py::TestCompressedStorageBenchmark::test_tensor_to_bytestream_speed -v --benchmark-only --benchmark-sort='mean' --benchmark-columns='mean, ops'

        Sample output (test and group names in the table are those printed by
        the run it was copied from and may differ from the current ones):

        ------------------------ benchmark 'tensor_to_bytestream_speed': 5 tests -------------------------
        Name (time in us)                                   Mean (smaller is better)   OPS (bigger is better)
        --------------------------------------------------------------------------------------------------
        test_tensor_serialization_speed[numpy]                       2.3520 (1.0)      425,162.1779 (1.0)
        test_tensor_serialization_speed[safetensors]                14.7170 (6.26)      67,948.7129 (0.16)
        test_tensor_serialization_speed[pickle]                     19.0711 (8.11)      52,435.3333 (0.12)
        test_tensor_serialization_speed[torch.save]                 32.0648 (13.63)     31,186.8261 (0.07)
        test_tensor_serialization_speed[untyped_storage]        38,227.0224 (>1000.0)       26.1595 (0.00)
        --------------------------------------------------------------------------------------------------
        """

        def serialize_with_pickle(data: torch.Tensor) -> bytes:
            """Serialize tensor using pickle."""
            buffer = io.BytesIO()
            pickle.dump(data, buffer)
            return buffer.getvalue()

        def serialize_with_untyped_storage(data: torch.Tensor) -> bytes:
            """Serialize tensor by converting its untyped storage to bytes."""
            return bytes(data.untyped_storage())

        def serialize_with_numpy(data: torch.Tensor) -> bytes:
            """Serialize tensor using numpy."""
            return data.numpy().tobytes()

        def serialize_with_safetensors(data: torch.Tensor) -> bytes:
            """Serialize tensor using ``safetensors.torch.save``."""
            return save({"0": data})

        def serialize_with_torch(data: torch.Tensor) -> bytes:
            """Serialize tensor using ``torch.save``."""
            buffer = io.BytesIO()
            torch.save(data, buffer)
            return buffer.getvalue()

        # Map each parametrized method name to its serializer.
        serializers = {
            "pickle": serialize_with_pickle,
            "torch.save": serialize_with_torch,
            "untyped_storage": serialize_with_untyped_storage,
            "numpy": serialize_with_numpy,
            "safetensors": serialize_with_safetensors,
        }
        if serialization_method not in serializers:
            raise ValueError(f"Unknown serialization method: {serialization_method}")
        serialize_fn = serializers[serialization_method]

        # ``save`` is None when the module-level safetensors import failed;
        # calling it would raise TypeError rather than report a missing extra.
        if serialization_method == "safetensors" and save is None:
            pytest.skip("safetensors is not installed")

        data = self.make_compressible_mock_data(1)["observations"]

        # Run the actual benchmark
        benchmark(serialize_fn, data)
benchmarks/test_envs_benchmark.py ADDED
@@ -0,0 +1,133 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import argparse
6
+
7
+ import pytest
8
+ import torch
9
+
10
+ from tensordict import TensorDict
11
+ from torchrl.envs import ParallelEnv, SerialEnv, step_mdp, StepCounter, TransformedEnv
12
+ from torchrl.envs.libs.dm_control import DMControlEnv
13
+
14
+
15
def make_simple_env():
    """Create a DMControl cheetah/run env and return it as ``(args, kwargs)``.

    The env is placed on ``cuda:0`` when at least one CUDA device is
    reported, on CPU otherwise, and a 3-step rollout is run before returning.

    Returns:
        The pair ``((env,), {})``.
    """
    if torch.cuda.device_count():
        device = "cuda:0"
    else:
        device = "cpu"
    env = DMControlEnv("cheetah", "run", device=device)
    env.rollout(3)
    return (env,), {}
20
+
21
+
22
def make_transformed_env():
    """Create a cheetah/run env wrapped with ``StepCounter(50)``.

    The env is placed on ``cuda:0`` when at least one CUDA device is
    reported, on CPU otherwise, and a 3-step rollout is run before returning.

    Returns:
        The pair ``((env,), {})``.
    """
    if torch.cuda.device_count():
        device = "cuda:0"
    else:
        device = "cpu"
    base_env = DMControlEnv("cheetah", "run", device=device)
    env = TransformedEnv(base_env, StepCounter(50))
    env.rollout(3)
    return (env,), {}
27
+
28
+
29
def make_serial_env():
    """Create a ``SerialEnv`` of three cheetah/run envs.

    The sub-envs are placed on ``cuda:0`` when at least one CUDA device is
    reported, on CPU otherwise, and a 3-step rollout is run before returning.

    Returns:
        The pair ``((env,), {})``.
    """
    if torch.cuda.device_count():
        device = "cuda:0"
    else:
        device = "cpu"

    def _create_env():
        return DMControlEnv("cheetah", "run", device=device)

    env = SerialEnv(3, _create_env)
    env.rollout(3)
    return (env,), {}
34
+
35
+
36
def make_parallel_env():
    """Create a ``ParallelEnv`` of three cheetah/run envs.

    The sub-envs are placed on ``cuda:0`` when at least one CUDA device is
    reported, on CPU otherwise, and a 3-step rollout is run before returning.

    Returns:
        The pair ``((env,), {})``.
    """
    if torch.cuda.device_count():
        device = "cuda:0"
    else:
        device = "cpu"

    def _create_env():
        return DMControlEnv("cheetah", "run", device=device)

    env = ParallelEnv(3, _create_env)
    env.rollout(3)
    return (env,), {}
41
+
42
+
43
def make_nested_td():
    """Build a batch-less TensorDict whose entries live under an ``"agent"`` key.

    Root-level agent entries are filled with 0 and the ``("next", "agent", ...)``
    entries with 1.
    """
    root_entries = {
        ("agent", name): 0 for name in ("action", "done", "obs", "other")
    }
    next_entries = {
        ("next", "agent", name): 1 for name in ("action", "reward", "done", "obs")
    }
    return TensorDict({**root_entries, **next_entries}, batch_size=[])
57
+
58
+
59
def make_flat_td():
    """Build a batch-less TensorDict with un-nested root and ``"next"`` entries.

    Root-level entries are filled with 0 and the ``("next", ...)`` entries with 1.
    """
    root_entries = {name: 0 for name in ("action", "done", "obs", "other")}
    next_entries = {
        ("next", name): 1 for name in ("action", "reward", "done", "obs")
    }
    return TensorDict({**root_entries, **next_entries}, batch_size=[])
73
+
74
+
75
def execute_env(env):
    """Run a 1000-step rollout on ``env`` with ``break_when_any_done=False``.

    The rollout result is discarded; this is the callable handed to the
    benchmark fixture by the tests in this file.
    """
    max_steps = 1000
    env.rollout(max_steps, break_when_any_done=False)
77
+
78
+
79
def test_simple(benchmark):
    """Benchmark ``execute_env`` on the env built by ``make_simple_env``."""
    env_args, _ = make_simple_env()
    (env,) = env_args
    benchmark(execute_env, env)
82
+
83
+
84
def test_transformed(benchmark):
    """Benchmark ``execute_env`` on the env built by ``make_transformed_env``."""
    env_args, _ = make_transformed_env()
    (env,) = env_args
    benchmark(execute_env, env)
87
+
88
+
89
def test_serial(benchmark):
    """Benchmark ``execute_env`` on the env built by ``make_serial_env``."""
    env_args, _ = make_serial_env()
    (env,) = env_args
    benchmark(execute_env, env)
92
+
93
+
94
def test_parallel(benchmark):
    """Benchmark ``execute_env`` on the env built by ``make_parallel_env``."""
    env_args, _ = make_parallel_env()
    (env,) = env_args
    benchmark(execute_env, env)
97
+
98
+
99
@pytest.mark.parametrize("nested", [True, False])
@pytest.mark.parametrize("keep_other", [True, False])
@pytest.mark.parametrize("exclude_reward", [True, False])
@pytest.mark.parametrize("exclude_done", [True, False])
@pytest.mark.parametrize("exclude_action", [True, False])
def test_step_mdp_speed(
    benchmark, nested, keep_other, exclude_reward, exclude_done, exclude_action
):
    """Benchmark ``step_mdp`` for every combination of its boolean options.

    With ``nested=True`` the input comes from ``make_nested_td`` and the
    action/reward/done keys are ``("agent", <name>)`` tuples; otherwise the
    input comes from ``make_flat_td`` and the keys are plain strings.
    """

    def _key(name):
        # Nested layouts address entries through the "agent" sub-tensordict.
        return ("agent", name) if nested else name

    td = make_nested_td() if nested else make_flat_td()

    benchmark(
        step_mdp,
        td,
        action_keys=_key("action"),
        reward_keys=_key("reward"),
        done_keys=_key("done"),
        keep_other=keep_other,
        exclude_reward=exclude_reward,
        exclude_done=exclude_done,
        exclude_action=exclude_action,
    )
129
+
130
+
131
if __name__ == "__main__":
    # Arguments the (empty) parser does not recognize are forwarded to pytest.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *unknown])
benchmarks/test_llm.py ADDED
@@ -0,0 +1,101 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ import importlib.util
9
+
10
+ import pytest
11
+ import torch
12
+ from tensordict import set_list_to_stack, TensorDict
13
+ from torchrl.data.llm import History
14
+ from torchrl.modules.llm.policies.common import ChatHistory
15
+ from torchrl.modules.llm.policies.transformers_wrapper import TransformersWrapper
16
+
17
# Probe for the package without importing it: ``importlib.import_module``
# raises when the package is missing, so it can never produce a False flag.
_has_transformers = importlib.util.find_spec("transformers") is not None
18
+
19
# Module-wide marker: every test in this file is skipped when CUDA is unavailable.
pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="GPU not available",
)
23
+
24
+
25
@pytest.fixture(scope="module")
def transformers_wrapper():
    """Module-scoped ``TransformersWrapper`` built from ``Qwen/Qwen2.5-0.5B``.

    The wrapper is constructed under a ``torch.device`` context set to CUDA
    when available and CPU otherwise, with ``pad_model_input=False`` and
    ``generate=False``.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = "Qwen/Qwen2.5-0.5B"
    with torch.device(target):
        wrapper = TransformersWrapper(
            model=checkpoint,
            tokenizer=checkpoint,
            pad_model_input=False,
            generate=False,
        )
    return wrapper
36
+
37
+
38
@pytest.mark.skipif(not _has_transformers, reason="transformers not installed")
class TestWrappers:
    """Benchmarks for the LLM policy wrappers."""

    @pytest.mark.parametrize("packing", [True, False])
    @set_list_to_stack(True)
    def test_packing(self, benchmark, transformers_wrapper, packing: bool):
        """Benchmark a ``TransformersWrapper`` call with and without packing.

        A new wrapper is built around the fixture's model and tokenizer with
        ``pad_model_input=not packing``, then called on a batch of four
        two-turn conversations for 10 rounds after 3 warmup rounds.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        with torch.device(device):
            wrapper = TransformersWrapper(
                model=transformers_wrapper.model,
                tokenizer=transformers_wrapper.tokenizer,
                pad_model_input=not packing,
                generate=False,
                pad_output=False,
            )

        # Two conversation templates of different lengths, alternated below.
        short_turns = [
            "Lorem ipsum dolor sit amet",
            "consectetur adipiscing elit",
        ]
        long_turns = [
            "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua",
            "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat",
        ]
        n_conversations = 4

        history = History(
            role=[["user", "assistant"] for _ in range(n_conversations)],
            content=[
                list(short_turns),
                list(long_turns),
                list(short_turns),
                list(long_turns),
            ],
            batch_size=(n_conversations, 2),
            device=device,
        )
        chat_history = ChatHistory(
            full=history,
            batch_size=(n_conversations,),
            device=device,
        )
        data = TensorDict(
            {"history": chat_history},
            batch_size=(n_conversations,),
            device=device,
        ).to_lazystack()

        def _empty_cuda_cache():
            # Passed as ``setup`` to ``benchmark.pedantic`` below.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        benchmark.pedantic(
            wrapper,
            (data,),
            rounds=10,
            warmup_rounds=3,
            setup=_empty_cuda_cache,
        )
benchmarks/test_non_tensor_env_benchmark.py ADDED
@@ -0,0 +1,70 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import gc
8
+ import time
9
+
10
+ import pytest
11
+ from tensordict import set_capture_non_tensor_stack
12
+ from torchrl.envs import ParallelEnv, SerialEnv
13
+ from torchrl.testing.mocking_classes import EnvWithMetadata
14
+
15
+
16
def _rollout(env, n_steps: int, break_when_any_done: bool) -> None:
    """Run one rollout of ``n_steps`` steps on ``env`` and discard the result.

    Args:
        env: object exposing a ``rollout`` method.
        n_steps: forwarded as the first positional argument of ``rollout``.
        break_when_any_done: forwarded to ``rollout`` as a keyword argument.
    """
    rollout_kwargs = {"break_when_any_done": break_when_any_done}
    env.rollout(n_steps, **rollout_kwargs)
18
+
19
+
20
@pytest.mark.parametrize("break_when_any_done", [True, False])
@pytest.mark.parametrize(
    "kind,use_buffers",
    [
        pytest.param("single", None, id="single"),
        pytest.param("serial", False, id="serial-no-buffers"),
        pytest.param("serial", True, id="serial-buffers"),
        pytest.param("parallel", False, id="parallel-no-buffers"),
        pytest.param("parallel", True, id="parallel-buffers"),
    ],
)
@pytest.mark.parametrize("n_steps", [1000])
def test_non_tensor_env_rollout_speed(
    benchmark,
    break_when_any_done: bool,
    kind: str,
    use_buffers: bool | None,
    n_steps: int,
):
    """Benchmarks a single rollout, after a warmup rollout, for non-tensor stacking envs.

    Mirrors `test/test_envs.py::TestNonTensorEnv`'s option matrix (single/serial/parallel,
    break_when_any_done, use_buffers).

    Args:
        benchmark: the pytest-benchmark fixture.
        break_when_any_done: forwarded to every rollout.
        kind: one of ``"single"``, ``"serial"`` or ``"parallel"``.
        use_buffers: forwarded to ``SerialEnv`` / ``ParallelEnv``; unused for
            ``"single"``.
        n_steps: number of steps per rollout.

    Raises:
        RuntimeError: if ``kind`` is not one of the supported values.
    """
    with set_capture_non_tensor_stack(False):
        if kind == "single":
            env = EnvWithMetadata()
        elif kind == "serial":
            env = SerialEnv(2, EnvWithMetadata, use_buffers=use_buffers)
        elif kind == "parallel":
            env = ParallelEnv(2, EnvWithMetadata, use_buffers=use_buffers)
        else:
            raise RuntimeError(f"Unknown kind={kind}")

        try:
            # Seeding and resetting happen inside the ``try`` so that the env
            # is still closed if either of them raises.
            env.set_seed(0)
            env.reset()

            # Warmup run (not timed)
            _rollout(env, n_steps=n_steps, break_when_any_done=break_when_any_done)

            # Timed run(s)
            benchmark(
                _rollout, env, n_steps=n_steps, break_when_any_done=break_when_any_done
            )
        finally:
            env.close(raise_if_closed=False)
            del env
            # Give multiprocessing envs a brief chance to terminate cleanly.
            time.sleep(0.05)
            gc.collect()
+ gc.collect()