torchrl 0.11.0__cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,88 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """Constants and helper classes for collectors."""
6
+ from __future__ import annotations
7
+
8
+ import os
9
+ import sys
10
+ from multiprocessing.managers import SyncManager
11
+
12
+ import torch
13
+ from torch import multiprocessing as mp
14
+
15
+ from torchrl.envs.utils import ExplorationType
16
+
17
try:
    # Prefer the real implementation shipped with recent PyTorch builds.
    from torch.compiler import cudagraph_mark_step_begin
except ImportError:
    # Older torch: expose a stub under the same name that fails loudly if called.
    def cudagraph_mark_step_begin():
        """Stub used when ``torch.compiler`` does not provide this helper."""
        raise NotImplementedError("cudagraph_mark_step_begin not implemented.")
26
# Public surface of this module: timeout constants, platform flags and the
# interruptor helpers shared by the multiprocessed collectors. Underscored
# names are exported deliberately — sibling collector modules import them.
__all__ = [
    "_TIMEOUT",
    "INSTANTIATE_TIMEOUT",
    "_MIN_TIMEOUT",
    "_MAX_IDLE_COUNT",
    "WEIGHT_SYNC_TIMEOUT",
    "DEFAULT_EXPLORATION_TYPE",
    "_is_osx",
    "_Interruptor",
    "_InterruptorManager",
    "cudagraph_mark_step_begin",
]
38
+
39
# Polling timeout (seconds) — presumably used when reading from worker
# queues/pipes; TODO confirm against the collector modules that import it.
_TIMEOUT = 1.0
# Seconds granted for instantiation steps (e.g. worker startup) — verify usage at call sites.
INSTANTIATE_TIMEOUT = 20
_MIN_TIMEOUT = 1e-3  # should be several orders of magnitude inferior wrt time spent collecting a trajectory
# Timeout for weight synchronization during collector init.
# Increase this when using many collectors across different CUDA devices.
WEIGHT_SYNC_TIMEOUT = float(os.environ.get("TORCHRL_WEIGHT_SYNC_TIMEOUT", 120.0))
# Maximum number of times a worker can time out on its queue before giving up.
# Defaults to int64 max (effectively unbounded) unless MAX_IDLE_COUNT is set.
_MAX_IDLE_COUNT = int(os.environ.get("MAX_IDLE_COUNT", torch.iinfo(torch.int64).max))

# Exploration mode collectors fall back to when none is specified.
DEFAULT_EXPLORATION_TYPE: ExplorationType = ExplorationType.RANDOM

# True when running on macOS (sys.platform reports "darwin").
_is_osx = sys.platform.startswith("darwin")
51
+
52
+
53
+ class _Interruptor:
54
+ """A class for managing the collection state of a process.
55
+
56
+ This class provides methods to start and stop collection, and to check
57
+ whether collection has been stopped. The collection state is protected
58
+ by a lock to ensure thread-safety.
59
+ """
60
+
61
+ # interrupter vs interruptor: google trends seems to indicate that "or" is more
62
+ # widely used than "er" even if my IDE complains about that...
63
+ def __init__(self):
64
+ self._collect = True
65
+ self._lock = mp.Lock()
66
+
67
+ def start_collection(self):
68
+ with self._lock:
69
+ self._collect = True
70
+
71
+ def stop_collection(self):
72
+ with self._lock:
73
+ self._collect = False
74
+
75
+ def collection_stopped(self):
76
+ with self._lock:
77
+ return self._collect is False
78
+
79
+
80
class _InterruptorManager(SyncManager):
    """A custom SyncManager for managing the collection state of a process.

    This class extends the SyncManager class and allows to share an Interruptor object
    between processes.
    """


# Register ``_Interruptor`` with the manager so proxy objects can be created
# and handed to worker processes.
_InterruptorManager.register("_Interruptor", _Interruptor)
@@ -0,0 +1,324 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ import warnings
5
+ from collections import defaultdict, OrderedDict
6
+ from collections.abc import Iterator, Sequence
7
+ from copy import deepcopy
8
+ from queue import Empty
9
+
10
+ import torch
11
+
12
+ from tensordict import TensorDictBase
13
+ from tensordict.nn import TensorDictModuleBase
14
+ from torchrl._utils import (
15
+ _check_for_faulty_process,
16
+ accept_remote_rref_udf_invocation,
17
+ logger as torchrl_logger,
18
+ )
19
+ from torchrl.collectors._base import _make_legacy_metaclass
20
+ from torchrl.collectors._constants import _MAX_IDLE_COUNT, _TIMEOUT
21
+ from torchrl.collectors._multi_base import _MultiCollectorMeta, MultiCollector
22
+ from torchrl.collectors.utils import split_trajectories
23
+
24
+
25
+ @accept_remote_rref_udf_invocation
26
+ class MultiAsyncCollector(MultiCollector):
27
+ """Runs a given number of DataCollectors on separate processes asynchronously.
28
+
29
+ .. aafig::
30
+
31
+
32
+ +----------------------------------------------------------------------+
33
+ | "MultiConcurrentCollector" | |
34
+ |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| |
35
+ | "Collector 1" | "Collector 2" | "Collector 3" | "Main" |
36
+ |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~|
37
+ | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | |
38
+ |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~|
39
+ |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | |
40
+ | | | | | | | |
41
+ | "actor" | | | "actor" | |
42
+ | | | | | |
43
+ | "step" | "step" | "actor" | | |
44
+ | | | | | |
45
+ | | | | "step" | "step" | |
46
+ | | | | | | |
47
+ | "actor | "step" | "step" | "actor" | |
48
+ | | | | | |
49
+ | "yield batch 1" | "actor" | |"collect, train"|
50
+ | | | | |
51
+ | "step" | "step" | | "yield batch 2" |"collect, train"|
52
+ | | | | | |
53
+ | | | "yield batch 3" | |"collect, train"|
54
+ | | | | | |
55
+ +----------------------------------------------------------------------+
56
+
57
+ Environment types can be identical or different.
58
+
59
+ The collection keeps on occurring on all processes even between the time
60
+ the batch of rollouts is collected and the next call to the iterator.
61
+ This class can be safely used with offline RL sota-implementations.
62
+
63
+ .. note:: Python requires multiprocessed code to be instantiated within a main guard:
64
+
65
+ >>> from torchrl.collectors import MultiAsyncCollector
66
+ >>> if __name__ == "__main__":
67
+ ... # Create your collector here
68
+
69
+ See https://docs.python.org/3/library/multiprocessing.html for more info.
70
+
71
+ Examples:
72
+ >>> from torchrl.envs.libs.gym import GymEnv
73
+ >>> from tensordict.nn import TensorDictModule
74
+ >>> from torch import nn
75
+ >>> from torchrl.collectors import MultiAsyncCollector
76
+ >>> if __name__ == "__main__":
77
+ ... env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
78
+ ... policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
79
+ ... collector = MultiAsyncCollector(
80
+ ... create_env_fn=[env_maker, env_maker],
81
+ ... policy=policy,
82
+ ... total_frames=2000,
83
+ ... max_frames_per_traj=50,
84
+ ... frames_per_batch=200,
85
+ ... init_random_frames=-1,
86
+ ... reset_at_each_iter=False,
87
+ ... device="cpu",
88
+ ... storing_device="cpu",
89
+ ... cat_results="stack",
90
+ ... )
91
+ ... for i, data in enumerate(collector):
92
+ ... if i == 2:
93
+ ... print(data)
94
+ ... break
95
+ ... collector.shutdown()
96
+ ... del collector
97
+ TensorDict(
98
+ fields={
99
+ action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
100
+ collector: TensorDict(
101
+ fields={
102
+ traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
103
+ batch_size=torch.Size([200]),
104
+ device=cpu,
105
+ is_shared=False),
106
+ done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
107
+ next: TensorDict(
108
+ fields={
109
+ done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
110
+ observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
111
+ reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
112
+ step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
113
+ truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
114
+ batch_size=torch.Size([200]),
115
+ device=cpu,
116
+ is_shared=False),
117
+ observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
118
+ step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
119
+ truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
120
+ batch_size=torch.Size([200]),
121
+ device=cpu,
122
+ is_shared=False)
123
+
124
+ """
125
+
126
+ __doc__ += MultiCollector.__doc__
127
+
128
    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent collector and set up async state."""
        super().__init__(*args, **kwargs)
        # Most recent batch received from each worker, keyed by worker index;
        # missing keys resolve to None.
        self.out_tensordicts = defaultdict(lambda: None)
        self.running = False

        if self.postprocs is not None and self.replay_buffer is None:
            # Keep one copy of the post-processing transform per storing device,
            # since batches may land on different devices.
            postproc = self.postprocs
            self.postprocs = {}
            for _device in self.storing_device:
                if _device not in self.postprocs:
                    if hasattr(postproc, "to"):
                        # NOTE(review): each copy is deep-copied from the previous
                        # (already moved) copy rather than from the original
                        # transform — confirm this chaining is intentional.
                        postproc = deepcopy(postproc).to(_device)
                    self.postprocs[_device] = postproc
141
+
142
    # for RPC
    def next(self):
        """Return the next collected batch; redeclared so RPC can invoke it remotely."""
        return super().next()
145
+
146
+ # for RPC
147
+ def shutdown(
148
+ self,
149
+ timeout: float | None = None,
150
+ close_env: bool = True,
151
+ raise_on_error: bool = True,
152
+ ) -> None:
153
+ if hasattr(self, "out_tensordicts"):
154
+ del self.out_tensordicts
155
+ if not close_env:
156
+ raise RuntimeError(
157
+ f"Cannot shutdown {type(self).__name__} collector without environment being closed."
158
+ )
159
+ return super().shutdown(timeout=timeout, raise_on_error=raise_on_error)
160
+
161
    # for RPC
    def set_seed(self, seed: int, static_seed: bool = False) -> int:
        """Seed the workers' environments; redeclared for remote RPC invocation."""
        return super().set_seed(seed, static_seed)
164
+
165
    # for RPC
    def state_dict(self) -> OrderedDict:
        """Return the collector state; redeclared for remote RPC invocation."""
        return super().state_dict()
168
+
169
    # for RPC
    def load_state_dict(self, state_dict: OrderedDict) -> None:
        """Restore the collector state; redeclared for remote RPC invocation."""
        return super().load_state_dict(state_dict)
172
+
173
+ # for RPC
174
+ def update_policy_weights_(
175
+ self,
176
+ policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
177
+ *,
178
+ worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
179
+ **kwargs,
180
+ ) -> None:
181
+ if "policy_weights" in kwargs:
182
+ warnings.warn(
183
+ "`policy_weights` is deprecated. Use `policy_or_weights` instead.",
184
+ DeprecationWarning,
185
+ )
186
+ policy_or_weights = kwargs.pop("policy_weights")
187
+
188
+ super().update_policy_weights_(
189
+ policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
190
+ )
191
+
192
    def frames_per_batch_worker(self, *, worker_idx: int | None = None) -> int:
        """Return the number of frames a worker contributes per batch.

        The value is the requested ``frames_per_batch`` and is independent of
        ``worker_idx`` (the parameter is accepted for interface compatibility).
        """
        return self.requested_frames_per_batch
194
+
195
    def _get_from_queue(self, timeout=None) -> tuple[int, int, TensorDictBase]:
        """Pop one worker result from the output queue.

        Returns:
            ``(idx, j, out)`` where ``idx`` identifies the worker, ``j`` is the
            worker-side iteration counter and ``out`` is the collected batch.

        Raises:
            queue.Empty: when nothing arrives within ``timeout``.
        """
        new_data, j = self.queue_out.get(timeout=timeout)
        use_buffers = self._use_buffers
        if self.replay_buffer is not None:
            # Data was absorbed by the replay buffer; workers only send their index.
            idx = new_data
        elif j == 0 or not use_buffers:
            try:
                # Non-buffered protocol: the worker ships (tensordict, idx).
                data, idx = new_data
                self.out_tensordicts[idx] = data
                if use_buffers is None and j > 0:
                    use_buffers = self._use_buffers = False
            except TypeError:
                # Payload was a bare index: the worker writes into a shared
                # buffer instead of sending data through the queue. Lock in
                # buffered mode on first detection; otherwise re-raise.
                if use_buffers is None:
                    use_buffers = self._use_buffers = True
                    idx = new_data
                else:
                    raise
        else:
            # Buffered steady state: only the worker index travels on the queue.
            idx = new_data
        out = self.out_tensordicts[idx]
        if not self.replay_buffer and (j == 0 or use_buffers):
            # we clone the data to make sure that we'll be working with a fixed copy
            out = out.clone()
        return idx, j, out
219
+
220
    @property
    def _queue_len(self) -> int:
        # Depth of the workers' output queue: async collection consumes one
        # result at a time, so a single slot suffices.
        return 1
223
+
224
    def iterator(self) -> Iterator[TensorDictBase]:
        """Yield batches as soon as any worker delivers one (first-come-first-served)."""
        if self.update_at_each_batch:
            self.update_policy_weights_()

        # Kick off every worker; results are consumed in completion order below.
        for i in range(self.num_workers):
            if self._should_use_random_frames():
                self.pipes[i].send((None, "continue_random"))
            else:
                self.pipes[i].send((None, "continue"))
        self.running = True

        workers_frames = [0 for _ in range(self.num_workers)]
        _iter_start_time = time.time()
        while self._frames < self.total_frames:
            self._iter += 1
            counter = 0
            # Poll the output queue, checking worker health on every timeout.
            while True:
                try:
                    idx, j, out = self._get_from_queue(timeout=_TIMEOUT)
                    break
                except (TimeoutError, Empty):
                    counter += _TIMEOUT
                    _check_for_faulty_process(self.procs)
                    # Debug logging for queue timeout
                    # NOTE(review): float modulo — this only hits exactly 0
                    # because _TIMEOUT is 1.0; fractional timeouts would never log.
                    if counter % (10 * _TIMEOUT) == 0:  # Log every 10 timeouts
                        _elapsed = time.time() - _iter_start_time
                        torchrl_logger.debug(
                            f"MultiAsyncCollector.iterator: Queue timeout, counter={counter:.1f}s, "
                            f"iter={self._iter}, frames={self._frames}, elapsed={_elapsed:.1f}s"
                        )
                    if counter > (_TIMEOUT * _MAX_IDLE_COUNT):
                        _elapsed = time.time() - _iter_start_time
                        torchrl_logger.debug(
                            f"MultiAsyncCollector.iterator: CRITICAL - Max idle exceeded, "
                            f"counter={counter:.1f}s, iter={self._iter}, frames={self._frames}, elapsed={_elapsed:.1f}s"
                        )
                        raise RuntimeError(
                            f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} seconds. "
                            f"Increase the MAX_IDLE_COUNT environment variable to bypass this error."
                        )
            if self.replay_buffer is None:
                worker_frames = out.numel()
                if self.split_trajs:
                    out = split_trajectories(out, prefix="collector")
            else:
                # Data went straight to the replay buffer; count the requested size.
                worker_frames = self.frames_per_batch_worker()
            self._frames += worker_frames
            workers_frames[idx] = workers_frames[idx] + worker_frames
            if out is not None and self.postprocs:
                # Pick the transform copy that lives on the batch's device.
                out = self.postprocs[out.device](out)

            # the function blocks here until the next item is asked, hence we send the message to the
            # worker to keep on working in the meantime before the yield statement
            if self._should_use_random_frames():
                msg = "continue_random"
            else:
                msg = "continue"
            self.pipes[idx].send((idx, msg))
            if out is not None and self._exclude_private_keys:
                excluded_keys = [key for key in out.keys() if key.startswith("_")]
                out = out.exclude(*excluded_keys)
            yield out

        # We don't want to shutdown yet, the user may want to call state_dict before
        # self._shutdown_main()
        self.running = False
290
+
291
    def _shutdown_main(self, *args, **kwargs) -> None:
        """Release cached worker outputs, then defer to the parent shutdown logic."""
        if hasattr(self, "out_tensordicts"):
            del self.out_tensordicts
        return super()._shutdown_main(*args, **kwargs)
295
+
296
+ def reset(self, reset_idx: Sequence[bool] | None = None) -> None:
297
+ super().reset(reset_idx)
298
+ if self.queue_out.full():
299
+ time.sleep(_TIMEOUT) # wait until queue is empty
300
+ if self.queue_out.full():
301
+ raise Exception("self.queue_out is full")
302
+ if self.running:
303
+ for idx in range(self.num_workers):
304
+ if self._should_use_random_frames():
305
+ self.pipes[idx].send((idx, "continue_random"))
306
+ else:
307
+ self.pipes[idx].send((idx, "continue"))
308
+
309
    # for RPC
    def _receive_weights_scheme(self):
        """Redeclared so the weight-reception scheme can be queried over RPC."""
        return super()._receive_weights_scheme()
312
+
313
    # for RPC
    def receive_weights(self, policy_or_weights: TensorDictBase | None = None):
        """Redeclared so weights can be pushed to this collector over RPC."""
        return super().receive_weights(policy_or_weights)
316
+
317
+
318
# Metaclass backing the deprecated alias below (see its docstring); built from
# the shared legacy-metaclass factory so the deprecation behavior stays uniform.
_LegacyMultiAsyncMeta = _make_legacy_metaclass(_MultiCollectorMeta)
319
+
320
+
321
class MultiaSyncDataCollector(MultiAsyncCollector, metaclass=_LegacyMultiAsyncMeta):
    """Deprecated version of :class:`~torchrl.collectors.MultiAsyncCollector`."""