torchrl 0.11.0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394) hide show
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,259 @@
1
+ from __future__ import annotations
2
+
3
+ from collections import OrderedDict
4
+ from collections.abc import Callable, Sequence
5
+ from typing import Any
6
+
7
+ from tensordict import TensorDictBase
8
+ from tensordict.nn import TensorDictModule
9
+
10
+ from torchrl._utils import accept_remote_rref_udf_invocation
11
+ from torchrl.collectors._base import _make_legacy_metaclass
12
+ from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE, ExplorationType
13
+ from torchrl.collectors._multi_async import MultiAsyncCollector
14
+ from torchrl.collectors._multi_base import _MultiCollectorMeta
15
+ from torchrl.data.utils import DEVICE_TYPING
16
+ from torchrl.envs import EnvBase
17
+
18
+
19
@accept_remote_rref_udf_invocation
class AsyncCollector(MultiAsyncCollector):
    """Runs a single DataCollector on a separate process.

    This is mostly useful for offline RL paradigms where the policy being
    trained can differ from the policy used to collect data. In online
    settings, a regular DataCollector should be preferred. This class is
    merely a wrapper around a MultiAsyncCollector where a single process
    is being created.

    Args:
        create_env_fn (Callable): Callable returning an instance of EnvBase
        policy (Callable): Policy to be executed in the environment.
            Must accept :class:`tensordict.tensordict.TensorDictBase` object as input.
            If ``None`` is provided, the policy used will be a
            :class:`~torchrl.collectors.RandomPolicy` instance with the environment
            ``action_spec``.
            Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`.
            This is the recommended usage of the collector.
            Other callables are accepted too:
            If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module`
            instances) it will be wrapped in a `nn.Module` first.
            Then, the collector will try to assess if these
            modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.

            - If the policy forward signature matches any of ``forward(self, tensordict)``,
              ``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
              any typing with a single argument typed as a subclass of ``TensorDictBase``)
              then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.

            - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

        .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
            pickled directly), the ``policy_factory`` should be used instead.

    Keyword Args:
        policy_factory (Callable[[], Callable], optional): a callable that returns
            a policy instance. This is exclusive with the `policy` argument.

            .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.

        frames_per_batch (int): A keyword-only argument representing the
            total number of elements in a batch.
        total_frames (int, optional): A keyword-only argument representing the
            total number of frames returned by the collector
            during its lifespan. If the ``total_frames`` is not divisible by
            ``frames_per_batch``, an exception is raised.
            Endless collectors can be created by passing ``total_frames=-1``.
            Defaults to ``-1`` (never ending collector).
        device (int, str or torch.device, optional): The generic device of the
            collector. The ``device`` args fills any non-specified device: if
            ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
            ``env_device`` is not specified, its value will be set to ``device``.
            Defaults to ``None`` (No default device).
            Supports a list of devices if one wishes to indicate a different device
            for each worker. The list must be as long as the number of workers.
        storing_device (int, str or torch.device, optional): The device on which
            the output :class:`~tensordict.TensorDict` will be stored.
            If ``device`` is passed and ``storing_device`` is ``None``, it will
            default to the value indicated by ``device``.
            For long trajectories, it may be necessary to store the data on a different
            device than the one where the policy and env are executed.
            Defaults to ``None`` (the output tensordict isn't on a specific device,
            leaf tensors sit on the device where they were created).
            Supports a list of devices if one wishes to indicate a different device
            for each worker. The list must be as long as the number of workers.
        env_device (int, str or torch.device, optional): The device on which
            the environment should be cast (or executed if that functionality is
            supported). If not specified and the env has a non-``None`` device,
            ``env_device`` will default to that value. If ``device`` is passed
            and ``env_device=None``, it will default to ``device``. If the value
            as such specified of ``env_device`` differs from ``policy_device``
            and one of them is not ``None``, the data will be cast to ``env_device``
            before being passed to the env (i.e., passing different devices to
            policy and env is supported). Defaults to ``None``.
            Supports a list of devices if one wishes to indicate a different device
            for each worker. The list must be as long as the number of workers.
        policy_device (int, str or torch.device, optional): The device on which
            the policy should be cast.
            If ``device`` is passed and ``policy_device=None``, it will default
            to ``device``. If the value as such specified of ``policy_device``
            differs from ``env_device`` and one of them is not ``None``,
            the data will be cast to ``policy_device`` before being passed to
            the policy (i.e., passing different devices to policy and env is
            supported). Defaults to ``None``.
            Supports a list of devices if one wishes to indicate a different device
            for each worker. The list must be as long as the number of workers.
        create_env_kwargs (dict, optional): A dictionary with the
            keyword arguments used to create an environment. If a list is
            provided, each of its elements will be assigned to a sub-collector.
        max_frames_per_traj (int, optional): Maximum steps per trajectory.
            Note that a trajectory can span across multiple batches (unless
            ``reset_at_each_iter`` is set to ``True``, see below).
            Once a trajectory reaches ``n_steps``, the environment is reset.
            If the environment wraps multiple environments together, the number
            of steps is tracked for each environment independently. Negative
            values are allowed, in which case this argument is ignored.
            Defaults to ``None`` (i.e. no maximum number of steps).
        init_random_frames (int, optional): Number of frames for which the
            policy is ignored before it is called. This feature is mainly
            intended to be used in offline/model-based settings, where a
            batch of random trajectories can be used to initialize training.
            If provided, it will be rounded up to the closest multiple of frames_per_batch.
            Defaults to ``None`` (i.e. no random frames).
        reset_at_each_iter (bool, optional): Whether environments should be reset
            at the beginning of a batch collection.
            Defaults to ``False``.
        postproc (Callable, optional): A post-processing transform, such as
            a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
            instance.
            Defaults to ``None``.
        split_trajs (bool, optional): Boolean indicating whether the resulting
            TensorDict should be split according to the trajectories.
            See :func:`~torchrl.collectors.utils.split_trajectories` for more
            information.
            Defaults to ``False``.
        exploration_type (ExplorationType, optional): interaction mode to be used when
            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
            ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
            or ``torchrl.envs.utils.ExplorationType.MEAN``.
        reset_when_done (bool, optional): if ``True`` (default), an environment
            that returns a ``True`` value in its ``"done"`` or ``"truncated"``
            entry will be reset at the corresponding indices.
        update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weights_()`
            will be called before (sync) or after (async) each data collection.
            Defaults to ``False``.
        preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
            that will be allowed to finish collecting their rollout before the rest are forced to end early.
        num_threads (int, optional): number of threads for this process.
            Defaults to the number of workers.
        num_sub_threads (int, optional): number of threads of the subprocesses.
            Should be equal to one plus the number of processes launched within
            each subprocess (or one if a single process is launched).
            Defaults to 1 for safety: if none is indicated, launching multiple
            workers may charge the cpu load too much and harm performance.
        set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding
            ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of
            a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
            Truncated keys can be set through ``env.add_truncated_keys``.
            Defaults to ``False``.
        track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
            This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
            Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
            the policy version.
            Defaults to `False`.

    """

    def __init__(
        self,
        create_env_fn: Callable[[], EnvBase],
        policy: None
        | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None,
        *,
        policy_factory: Callable[[], Callable] | None = None,
        frames_per_batch: int,
        total_frames: int | None = -1,
        device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
        storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
        env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
        policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
        create_env_kwargs: Sequence[dict[str, Any]] | None = None,
        max_frames_per_traj: int | None = None,
        init_random_frames: int | None = None,
        reset_at_each_iter: bool = False,
        postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
        split_trajs: bool | None = None,
        exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
        reset_when_done: bool = True,
        update_at_each_batch: bool = False,
        preemptive_threshold: float | None = None,
        num_threads: int | None = None,
        num_sub_threads: int = 1,
        set_truncated: bool = False,
        track_policy_version: bool = False,
        **kwargs,
    ):
        # Delegate to the multi-collector base with singleton lists so that
        # exactly one worker process is created.
        super().__init__(
            create_env_fn=[create_env_fn],
            policy=policy,
            policy_factory=policy_factory,
            total_frames=total_frames,
            # Only wrap in a list when kwargs were actually provided; a falsy
            # value (None/empty) is forwarded as-is.
            create_env_kwargs=[create_env_kwargs]
            if create_env_kwargs
            else create_env_kwargs,
            max_frames_per_traj=max_frames_per_traj,
            frames_per_batch=frames_per_batch,
            reset_at_each_iter=reset_at_each_iter,
            init_random_frames=init_random_frames,
            postproc=postproc,
            split_trajs=split_trajs,
            device=device,
            policy_device=policy_device,
            env_device=env_device,
            storing_device=storing_device,
            exploration_type=exploration_type,
            reset_when_done=reset_when_done,
            update_at_each_batch=update_at_each_batch,
            preemptive_threshold=preemptive_threshold,
            num_threads=num_threads,
            num_sub_threads=num_sub_threads,
            set_truncated=set_truncated,
            track_policy_version=track_policy_version,
            **kwargs,
        )

    # The following thin overrides exist so that the methods are defined on
    # this class and can be invoked remotely via RPC (see the
    # `accept_remote_rref_udf_invocation` decorator).

    # for RPC
    def next(self):
        return super().next()

    # for RPC
    def shutdown(
        self,
        timeout: float | None = None,
        close_env: bool = True,
        raise_on_error: bool = True,
    ) -> None:
        return super().shutdown(
            timeout=timeout, close_env=close_env, raise_on_error=raise_on_error
        )

    # for RPC
    def set_seed(self, seed: int, static_seed: bool = False) -> int:
        return super().set_seed(seed, static_seed)

    # for RPC
    def state_dict(self) -> OrderedDict:
        return super().state_dict()

    # for RPC
    def load_state_dict(self, state_dict: OrderedDict) -> None:
        return super().load_state_dict(state_dict)
251
+
252
+
253
+ _LegacyAsyncCollectorMeta = _make_legacy_metaclass(_MultiCollectorMeta)
254
+
255
+
256
class aSyncDataCollector(AsyncCollector, metaclass=_LegacyAsyncCollectorMeta):
    """Deprecated version of :class:`~torchrl.collectors.AsyncCollector`.

    Kept as a backward-compatible alias; prefer ``AsyncCollector`` in new
    code. Behavior is identical to the parent class.
    """

    ...
@@ -0,0 +1,62 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """Re-exports of collector classes for backward compatibility."""
6
+ from __future__ import annotations
7
+
8
+ from torchrl.collectors._base import BaseCollector, DataCollectorBase
9
+
10
+ # Re-export constants for backward compatibility
11
+ from torchrl.collectors._constants import (
12
+ _Interruptor,
13
+ _InterruptorManager,
14
+ _is_osx,
15
+ _MAX_IDLE_COUNT,
16
+ _MIN_TIMEOUT,
17
+ _TIMEOUT,
18
+ cudagraph_mark_step_begin,
19
+ DEFAULT_EXPLORATION_TYPE,
20
+ INSTANTIATE_TIMEOUT,
21
+ WEIGHT_SYNC_TIMEOUT,
22
+ )
23
+
24
+ from torchrl.collectors._multi_async import MultiAsyncCollector, MultiaSyncDataCollector
25
+ from torchrl.collectors._multi_base import (
26
+ MultiCollector,
27
+ MultiCollector as _MultiDataCollector,
28
+ )
29
+ from torchrl.collectors._multi_sync import MultiSyncCollector, MultiSyncDataCollector
30
+ from torchrl.collectors._runner import _main_async_collector
31
+ from torchrl.collectors._single import Collector, SyncDataCollector
32
+ from torchrl.collectors._single_async import AsyncCollector, aSyncDataCollector
33
+
34
# Explicit public API of this compatibility shim: both the new canonical
# collector names and their legacy aliases are exported so that
# `from ... import *` and pre-existing user imports keep working unchanged.
# Order is preserved deliberately; do not sort.
__all__ = [
    # New canonical names (preferred)
    "BaseCollector",
    "Collector",
    "AsyncCollector",
    "MultiCollector",
    "MultiSyncCollector",
    "MultiAsyncCollector",
    # Legacy names (backward-compatible aliases)
    "DataCollectorBase",
    "SyncDataCollector",
    "aSyncDataCollector",
    "_MultiDataCollector",
    "MultiSyncDataCollector",
    "MultiaSyncDataCollector",
    # Other exports
    "_main_async_collector",
    # Constants
    "_TIMEOUT",
    "INSTANTIATE_TIMEOUT",
    "WEIGHT_SYNC_TIMEOUT",
    "_MIN_TIMEOUT",
    "_MAX_IDLE_COUNT",
    "DEFAULT_EXPLORATION_TYPE",
    "_is_osx",
    "_Interruptor",
    "_InterruptorManager",
    "cudagraph_mark_step_begin",
]
@@ -0,0 +1,32 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .generic import (
7
+ DEFAULT_SLURM_CONF,
8
+ DistributedCollector,
9
+ DistributedDataCollector,
10
+ DistributedWeightUpdater,
11
+ )
12
+ from .ray import RayCollector
13
+ from .rpc import RPCCollector, RPCDataCollector, RPCWeightUpdater
14
+ from .sync import DistributedSyncCollector, DistributedSyncDataCollector
15
+ from .utils import submitit_delayed_launcher
16
+
17
# Explicit public API of the distributed collectors subpackage; both the new
# canonical names and the legacy aliases are exported so existing imports keep
# working. Order is preserved deliberately; do not sort.
__all__ = [
    "DEFAULT_SLURM_CONF",
    # New canonical names (preferred)
    "DistributedCollector",
    "DistributedSyncCollector",
    "RPCCollector",
    # Legacy names (backward-compatible aliases)
    "DistributedDataCollector",
    "DistributedSyncDataCollector",
    "RPCDataCollector",
    # Other exports
    "DistributedWeightUpdater",
    "RPCWeightUpdater",
    "RayCollector",
    "submitit_delayed_launcher",
]
@@ -0,0 +1,133 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
import os
import random
import socket
import time
from datetime import timedelta

import torch.distributed

from torchrl._utils import logger as torchrl_logger
15
+
16
# Default TCP port for the distributed store; kept as a *string* (it is read
# from the environment, which only yields strings).
TCP_PORT = os.environ.get("TCP_PORT", "10003")
# NOTE(review): the env var is spelled "RCP_IDLE_TIMEOUT" — likely a typo for
# "RPC", but renaming it would break existing deployments; confirm before
# changing. Also note the value is a str when read from the environment but an
# int (10) when defaulted — consumers must handle both.
IDLE_TIMEOUT = os.environ.get("RCP_IDLE_TIMEOUT", 10)

# Upper bound on connection wait — presumably in seconds; TODO confirm units.
MAX_TIME_TO_CONNECT = 1000

# Polling interval (seconds) for busy-wait loops.
SLEEP_INTERVAL = 1e-6

DEFAULT_SLURM_CONF = {
    "timeout_min": 10,
    "slurm_partition": "train",
    "slurm_cpus_per_task": 32,
    "slurm_gpus_per_node": 0,
}  #: Default value of the SLURM jobs

DEFAULT_SLURM_CONF_MAIN = {
    "timeout_min": 10,
    "slurm_partition": "train",
    "slurm_cpus_per_task": 32,
    "slurm_gpus_per_node": 1,
}  #: Default value of the SLURM main job

# Default options for the TensorPipe RPC backend; "_transports": ["uv"]
# presumably restricts tensorpipe to its libuv TCP transport — confirm against
# the torch.distributed.rpc documentation.
DEFAULT_TENSORPIPE_OPTIONS = {
    "num_worker_threads": 16,
    "rpc_timeout": 10_000,
    "_transports": ["uv"],
}
42
+
43
+
44
+ def _find_free_port() -> int:
45
+ """Find a free port by binding to port 0 and letting the OS choose."""
46
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
47
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
48
+ s.bind(("", 0))
49
+ return s.getsockname()[1]
50
+
51
+
52
def _create_tcpstore_with_retry(
    host_name: str,
    port: int | None,
    world_size: int,
    is_master: bool,
    timeout: float = 10.0,
    max_retries: int = 10,
    wait_for_workers: bool = True,
) -> tuple[torch.distributed.TCPStore, int]:
    """Create a TCPStore with retry logic for handling port conflicts.

    This function attempts to create a TCPStore, and if the port is already in use,
    it will retry with different random ports up to max_retries times.

    Args:
        host_name: The hostname for the TCPStore.
        port: The initial port to try. If None, a random port will be chosen.
        world_size: The world size for the TCPStore.
        is_master: Whether this is the master (server) process.
        timeout: Timeout in seconds for the TCPStore.
        max_retries: Maximum number of retry attempts.
        wait_for_workers: Whether the master should wait for workers.
            Only used when is_master=True.

    Returns:
        A tuple of (TCPStore, actual_port) where actual_port is the port
        that was successfully bound.

    Raises:
        RuntimeError: If unable to create a TCPStore after max_retries attempts.
    """
    last_error = None

    for attempt in range(max_retries):
        if port is None or attempt > 0:
            # First attempt honors the caller-provided port; retries (or an
            # unspecified port) pick a fresh OS-assigned free port.
            # NOTE(review): on retries a *worker* (is_master=False) also moves
            # to a new random local port where no master is listening — confirm
            # whether worker-side retries are ever expected to succeed.
            current_port = _find_free_port()
        else:
            current_port = int(port)

        try:
            store = _open_tcpstore(
                host_name,
                current_port,
                world_size,
                is_master,
                timeout,
                wait_for_workers,
            )
        except (RuntimeError, OSError) as e:
            if not _is_port_in_use_error(e):
                # Unrelated failure: surface it immediately.
                raise
            torchrl_logger.debug(
                f"Port {current_port} already in use, "
                f"retrying ({attempt + 1}/{max_retries})..."
            )
            last_error = e
            # Small random backoff to reduce collision probability when many
            # processes race for ports.
            time.sleep(random.uniform(0.01, 0.1))
            continue

        torchrl_logger.debug(
            f"TCPStore created successfully on {host_name}:{current_port} "
            f"(attempt {attempt + 1}/{max_retries})"
        )
        return store, current_port

    raise RuntimeError(
        f"Failed to create TCPStore after {max_retries} attempts. Last error: {last_error}"
    )


def _is_port_in_use_error(error: Exception) -> bool:
    """Return True if *error* looks like an 'address already in use' bind failure."""
    error_msg = str(error).lower()
    return "address already in use" in error_msg or "eaddrinuse" in error_msg


def _open_tcpstore(
    host_name: str,
    port: int,
    world_size: int,
    is_master: bool,
    timeout: float,
    wait_for_workers: bool,
) -> torch.distributed.TCPStore:
    """Create a single TCPStore: a server when is_master, a client otherwise."""
    if is_master:
        return torch.distributed.TCPStore(
            host_name=host_name,
            port=port,
            world_size=world_size,
            is_master=True,
            timeout=timedelta(seconds=timeout),
            wait_for_workers=wait_for_workers,
        )
    return torch.distributed.TCPStore(
        host_name=host_name,
        port=port,
        is_master=False,
        timeout=timedelta(seconds=timeout),
    )
+ )