torchrl 0.11.0 (torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
@@ -0,0 +1,731 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ r"""Generic distributed data-collector using torch.distributed backend."""
7
+ from __future__ import annotations
8
+
9
+ import os
10
+ import socket
11
+ import warnings
12
+ from collections import OrderedDict
13
+ from collections.abc import Callable, Sequence
14
+ from copy import copy, deepcopy
15
+ from datetime import timedelta
16
+ from typing import Any, Literal
17
+
18
+ import torch.cuda
19
+ from tensordict import TensorDict, TensorDictBase
20
+ from torch import nn
21
+ from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE
22
+ from torchrl.collectors._base import _LegacyCollectorMeta, BaseCollector
23
+ from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE
24
+ from torchrl.collectors._multi_async import MultiAsyncCollector
25
+ from torchrl.collectors._multi_sync import MultiSyncCollector
26
+ from torchrl.collectors._single import Collector
27
+ from torchrl.collectors.distributed.default_configs import (
28
+ DEFAULT_SLURM_CONF,
29
+ MAX_TIME_TO_CONNECT,
30
+ )
31
+ from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories
32
+ from torchrl.data.utils import CloudpickleWrapper
33
+ from torchrl.envs.common import EnvBase
34
+ from torchrl.envs.env_creator import EnvCreator
35
+
36
+ SUBMITIT_ERR = None
37
+ try:
38
+ import submitit
39
+
40
+ _has_submitit = True
41
+ except ModuleNotFoundError as err:
42
+ _has_submitit = False
43
+ SUBMITIT_ERR = err
44
+
45
+
46
def _distributed_init_collection_node(
    *,
    rank,
    rank0_ip,
    tcpport,
    world_size,
    backend,
    collector_class,
    num_workers,
    env_make,
    policy,
    policy_factory,
    frames_per_batch,
    collector_kwargs,
    update_interval,
    total_frames,
    weight_sync_schemes=None,
    verbose=VERBOSE,
):
    """Run one remote collection node of a distributed sync collector.

    Exports ``MASTER_ADDR``/``MASTER_PORT``, joins the torch.distributed
    process group (directly, or through the weight-sync schemes), builds a
    local collector and sends every batch it yields to rank 0 with
    ``data.isend(dst=0)``. Schemes and the collector are shut down when the
    loop ends, including when it ends with an exception.

    Keyword Args:
        rank (int): rank of this node in the process group. Also forwarded
            to the local collector as ``worker_idx``.
        rank0_ip: address of the rank-0 node, exported (stringified) as
            ``MASTER_ADDR``.
        tcpport: port of the rank-0 node, exported (stringified) as
            ``MASTER_PORT``.
        world_size (int): total number of processes in the group.
        backend (str): torch.distributed backend passed to
            ``init_process_group``.
        collector_class (type): class of the local collector. Subclasses of
            ``Collector`` run a single env (``num_workers`` must be 1) and
            get ``return_same_td=True``. Other classes receive
            ``num_workers`` copies of ``env_make``.
        num_workers (int): number of env constructors for multi-env
            collectors.
        env_make (Callable): environment constructor.
        policy: policy given to the local collector. It is replaced by
            ``None`` when ``policy_factory`` is provided.
        policy_factory (Callable or None): forwarded to the local collector.
        frames_per_batch (int): batch size of the local collector.
        collector_kwargs (dict or None): extra keyword arguments for the
            local collector. The mapping is copied and never mutated.
        update_interval: not used by this function; kept for signature
            compatibility with the launcher.
        total_frames (int): total frames of the local collector.
        weight_sync_schemes (dict, optional): forwarded to the collector as
            ``weight_recv_schemes``. When ``None``, the process group is
            initialized here instead.
        verbose (bool): emit additional debug logs.

    Raises:
        RuntimeError: if ``collector_class`` is a ``Collector`` subclass and
            ``num_workers != 1``.
    """
    os.environ["MASTER_ADDR"] = str(rank0_ip)
    os.environ["MASTER_PORT"] = str(tcpport)

    # Private copy: keys are injected below and must not leak back into the
    # caller's mapping.
    collector_kwargs = dict(collector_kwargs) if collector_kwargs is not None else {}

    if verbose:
        torchrl_logger.debug(
            f"node with rank {rank} -- creating collector of type {collector_class}"
        )
    if not issubclass(collector_class, Collector):
        # Multi-env collectors take one constructor per sub-worker.
        env_make = [env_make] * num_workers
    else:
        collector_kwargs["return_same_td"] = True
        if num_workers != 1:
            raise RuntimeError(
                "Collector and subclasses can only support a single environment."
            )

    torchrl_logger.debug(f"IP address: {rank0_ip} \ttcp port: {tcpport}")

    if weight_sync_schemes is not None:
        # The collector handles init_on_receiver/connect for these schemes.
        # connect() calls init_process_group as a collective operation.
        collector_kwargs["weight_recv_schemes"] = weight_sync_schemes
    else:
        # No schemes: join the process group here so that data.isend works.
        if verbose:
            torchrl_logger.debug(
                f"node with rank {rank} -- launching distributed (no weight schemes)"
            )
        torch.distributed.init_process_group(
            backend,
            rank=rank,
            world_size=world_size,
            # NOTE(review): timedelta's first positional field is *days*, so the
            # original `timedelta(MAX_TIME_TO_CONNECT)` meant this many days.
            # The unit is spelled out to keep that exact timeout; confirm
            # whether seconds were intended.
            timeout=timedelta(days=MAX_TIME_TO_CONNECT),
        )

    # With a policy_factory the child collector builds its own policy. In
    # that case `policy` only serves as a weight source for the parent.
    collector = collector_class(
        env_make,
        policy if policy_factory is None else None,
        frames_per_batch=frames_per_batch,
        split_trajs=False,
        total_frames=total_frames,
        policy_factory=policy_factory,
        worker_idx=rank,
        **collector_kwargs,
    )

    if verbose:
        torchrl_logger.debug(f"node with rank {rank} -- loop")

    try:
        # Weight updates are handled by the scheme's background thread.
        # This loop only ships batches to rank 0.
        for i, data in enumerate(collector):
            torchrl_logger.debug(
                f"Sending batch {i} from sync distributed collector on rank {rank}"
            )
            data.isend(dst=0)
            torchrl_logger.debug(
                f"Sent batch {i} from distributed collector on rank {rank}"
            )
    finally:
        # Release the schemes and the collector even if collection raised.
        if weight_sync_schemes is not None:
            for scheme in weight_sync_schemes.values():
                scheme.shutdown()
        if not collector.closed:
            collector.shutdown()
134
+
135
+
136
+ class DistributedSyncCollector(BaseCollector):
137
+ """A distributed synchronous data collector with torch.distributed backend.
138
+
139
+ Args:
140
+ create_env_fn (Callable or List[Callabled]): list of Callables, each returning an
141
+ instance of :class:`~torchrl.envs.EnvBase`.
142
+ policy (Callable): Policy to be executed in the environment.
143
+ Must accept :class:`tensordict.tensordict.TensorDictBase` object as input.
144
+ If ``None`` is provided, the policy used will be a
145
+ :class:`~torchrl.collectors.RandomPolicy` instance with the environment
146
+ ``action_spec``.
147
+ Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`.
148
+ This is the recommended usage of the collector.
149
+ Other callables are accepted too:
150
+ If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module`
151
+ instances) it will be wrapped in a `nn.Module` first.
152
+ Then, the collector will try to assess if these
153
+ modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
154
+
155
+ - If the policy forward signature matches any of ``forward(self, tensordict)``,
156
+ ``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
157
+ any typing with a single argument typed as a subclass of ``TensorDictBase``)
158
+ then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
159
+
160
+ - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
161
+
162
+ .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
163
+ pickled directly), the ``policy_factory`` should be used instead.
164
+
165
+ Keyword Args:
166
+ policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable
167
+ (or list of callables) that returns a policy instance. This is exclusive with the `policy` argument.
168
+
169
+ .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
170
+
171
+ frames_per_batch (int): A keyword-only argument representing the total
172
+ number of elements in a batch.
173
+ total_frames (int): A keyword-only argument representing the total
174
+ number of frames returned by the collector
175
+ during its lifespan. If the ``total_frames`` is not divisible by
176
+ ``frames_per_batch``, an exception is raised.
177
+ Endless collectors can be created by passing ``total_frames=-1``.
178
+ Defaults to ``-1`` (endless collector).
179
+ device (int, str or torch.device, optional): The generic device of the
180
+ collector. The ``device`` args fills any non-specified device: if
181
+ ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
182
+ ``env_device`` is not specified, its value will be set to ``device``.
183
+ Defaults to ``None`` (No default device).
184
+ Lists of devices are supported.
185
+ storing_device (int, str or torch.device, optional): The *remote* device on which
186
+ the output :class:`~tensordict.TensorDict` will be stored.
187
+ If ``device`` is passed and ``storing_device`` is ``None``, it will
188
+ default to the value indicated by ``device``.
189
+ For long trajectories, it may be necessary to store the data on a different
190
+ device than the one where the policy and env are executed.
191
+ Defaults to ``None`` (the output tensordict isn't on a specific device,
192
+ leaf tensors sit on the device where they were created).
193
+ Lists of devices are supported.
194
+ env_device (int, str or torch.device, optional): The *remote* device on which
195
+ the environment should be cast (or executed if that functionality is
196
+ supported). If not specified and the env has a non-``None`` device,
197
+ ``env_device`` will default to that value. If ``device`` is passed
198
+ and ``env_device=None``, it will default to ``device``. If the value
199
+ as such specified of ``env_device`` differs from ``policy_device``
200
+ and one of them is not ``None``, the data will be cast to ``env_device``
201
+ before being passed to the env (i.e., passing different devices to
202
+ policy and env is supported). Defaults to ``None``.
203
+ Lists of devices are supported.
204
+ policy_device (int, str or torch.device, optional): The *remote* device on which
205
+ the policy should be cast.
206
+ If ``device`` is passed and ``policy_device=None``, it will default
207
+ to ``device``. If the value as such specified of ``policy_device``
208
+ differs from ``env_device`` and one of them is not ``None``,
209
+ the data will be cast to ``policy_device`` before being passed to
210
+ the policy (i.e., passing different devices to policy and env is
211
+ supported). Defaults to ``None``.
212
+ Lists of devices are supported.
213
+ max_frames_per_traj (int, optional): Maximum steps per trajectory.
214
+ Note that a trajectory can span across multiple batches (unless
215
+ ``reset_at_each_iter`` is set to ``True``, see below).
216
+ Once a trajectory reaches ``n_steps``, the environment is reset.
217
+ If the environment wraps multiple environments together, the number
218
+ of steps is tracked for each environment independently. Negative
219
+ values are allowed, in which case this argument is ignored.
220
+ Defaults to ``None`` (i.e., no maximum number of steps).
221
+ init_random_frames (int, optional): Number of frames for which the
222
+ policy is ignored before it is called. This feature is mainly
223
+ intended to be used in offline/model-based settings, where a
224
+ batch of random trajectories can be used to initialize training.
225
+ If provided, it will be rounded up to the closest multiple of frames_per_batch.
226
+ Defaults to ``None`` (i.e. no random frames).
227
+ reset_at_each_iter (bool, optional): Whether environments should be reset
228
+ at the beginning of a batch collection.
229
+ Defaults to ``False``.
230
+ postproc (Callable, optional): A post-processing transform, such as
231
+ a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
232
+ instance.
233
+ Defaults to ``None``.
234
+ split_trajs (bool, optional): Boolean indicating whether the resulting
235
+ TensorDict should be split according to the trajectories.
236
+ See :func:`~torchrl.collectors.utils.split_trajectories` for more
237
+ information.
238
+ Defaults to ``False``.
239
+ exploration_type (ExplorationType, optional): interaction mode to be used when
240
+ collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
241
+ ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
242
+ or ``torchrl.envs.utils.ExplorationType.MEAN``.
243
+ collector_class (Type or str, optional): a collector class for the remote node. Can be
244
+ :class:`~torchrl.collectors.Collector`,
245
+ :class:`~torchrl.collectors.MultiSyncCollector`,
246
+ :class:`~torchrl.collectors.MultiAsyncCollector`
247
+ or a derived class of these. The strings "single", "sync" and
248
+ "async" correspond to respective class.
249
+ Defaults to :class:`~torchrl.collectors.Collector`.
250
+ collector_kwargs (dict or list, optional): a dictionary of parameters to be passed to the
251
+ remote data-collector. If a list is provided, each element will
252
+ correspond to an individual set of keyword arguments for the
253
+ dedicated collector.
254
+ num_workers_per_collector (int, optional): the number of copies of the
255
+ env constructor that is to be used on the remote nodes.
256
+ Defaults to 1 (a single env per collector).
257
+ On a single worker node all the sub-workers will be
258
+ executing the same environment. If different environments need to
259
+ be executed, they should be dispatched across worker nodes, not
260
+ subnodes.
261
+ slurm_kwargs (dict): a dictionary of parameters to be passed to the
262
+ submitit executor.
263
+ backend (str, optional): must be a string "<distributed_backend>" where
264
+ <distributed_backend> is one of ``"gloo"``, ``"mpi"``, ``"nccl"`` or ``"ucc"``. See
265
+ the torch.distributed documentation for more information.
266
+ Defaults to ``"gloo"``.
267
+ max_weight_update_interval (int, optional): the maximum number of
268
+ batches that can be collected before the policy weights of a worker
269
+ are updated.
270
+ For sync collections, this parameter is overwritten by ``update_after_each_batch``.
271
+ For async collections, it may be that one worker has not seen its
272
+ parameters being updated for a certain time even if ``update_after_each_batch``
273
+ is turned on.
274
+ Defaults to -1 (no forced update).
275
+ update_interval (int, optional): the frequency at which the policy is
276
+ updated. Defaults to 1.
277
+ launcher (str, optional): how jobs should be launched.
278
+ Can be one of "submitit" or "mp" for multiprocessing. The former
279
+ can launch jobs across multiple nodes, whilst the latter will only
280
+ launch jobs on a single machine. "submitit" requires the homonymous
281
+ library to be installed.
282
+ To find out more about submitit, visit
283
+ https://github.com/facebookincubator/submitit
284
+ Defaults to "submitit".
285
+ tcp_port (int, optional): the TCP port to be used. Defaults to the ``TCP_PORT`` environment variable if set, otherwise 10003.
286
+ """
287
+
288
def __init__(
    self,
    create_env_fn,
    policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
    *,
    policy_factory: Callable[[], Callable]
    | list[Callable[[], Callable]]
    | None = None,
    frames_per_batch: int,
    total_frames: int = -1,
    device: torch.device | list[torch.device] | None = None,
    storing_device: torch.device | list[torch.device] | None = None,
    env_device: torch.device | list[torch.device] | None = None,
    policy_device: torch.device | list[torch.device] | None = None,
    max_frames_per_traj: int = -1,
    init_random_frames: int = -1,
    reset_at_each_iter: bool = False,
    postproc: Callable | None = None,
    split_trajs: bool = False,
    exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,  # noqa
    collector_class: type | Callable[[], BaseCollector] = Collector,
    collector_kwargs: dict[str, Any] | list[dict[str, Any]] | None = None,
    num_workers_per_collector: int = 1,
    slurm_kwargs: dict[str, Any] | None = None,
    backend: Literal["gloo", "nccl"] = "gloo",
    max_weight_update_interval: int = -1,
    update_interval: int = 1,
    launcher: str = "submitit",
    tcp_port: str | None = None,
):
    # Resolve the string shortcuts accepted for ``collector_class``.
    if collector_class == "async":
        collector_class = MultiAsyncCollector
    elif collector_class == "sync":
        collector_class = MultiSyncCollector
    elif collector_class == "single":
        collector_class = Collector
    self.collector_class = collector_class
    self.env_constructors = create_env_fn
    self.policy = policy
    # One remote collector (worker) is launched per env constructor.
    num_workers = len(create_env_fn)

    collector_kwargs = collector_kwargs if collector_kwargs is not None else {}
    per_worker_kwargs = isinstance(collector_kwargs, (list, tuple))
    if per_worker_kwargs and len(collector_kwargs) != num_workers:
        raise RuntimeError(
            f"Got {len(collector_kwargs)} collector_kwargs entries for "
            f"{num_workers} workers; a list of collector_kwargs must contain "
            f"exactly one dict per worker."
        )
    # ``collector_kwargs`` may be one dict shared by all workers or one dict
    # per worker: inspect every entry instead of calling ``.get`` on a list.
    kwargs_entries = collector_kwargs if per_worker_kwargs else [collector_kwargs]
    has_weight_updater = all(
        entry.get("weight_updater") is not None for entry in kwargs_entries
    )

    if isinstance(policy, nn.Module):
        policy_weights = TensorDict.from_module(policy)
        policy_weights = policy_weights.data.lock_()
    else:
        # Non-module policies have no weights to extract; warn unless some
        # other mechanism (weight updater / policy factory) is provided.
        if not has_weight_updater and (
            policy_factory is None
            or (isinstance(policy_factory, Sequence) and not any(policy_factory))
        ):
            warnings.warn(_NON_NN_POLICY_WEIGHTS)
        policy_weights = TensorDict(lock=True)

    if not isinstance(policy_factory, Sequence):
        policy_factory = [policy_factory] * num_workers
    self.policy_factory = policy_factory
    # The policy object itself is only shipped when no worker builds its own.
    self._policy_to_send = policy if not any(policy_factory) else None
    self.policy_weights = policy_weights
    self.num_workers = num_workers
    self.frames_per_batch = frames_per_batch
    self.requested_frames_per_batch = frames_per_batch

    # The device setters broadcast/validate against ``self.num_workers``,
    # so it must be set before these assignments.
    self.device = device
    self.storing_device = storing_device
    self.env_device = env_device
    self.policy_device = policy_device

    self.update_interval = update_interval
    if total_frames < 0:
        # Endless collection: each worker is also run without a frame budget
        # (previously -1 // num_workers failed the divisibility check below
        # as soon as there was more than one worker).
        self.total_frames_per_collector = -1
    else:
        self.total_frames_per_collector = total_frames // self.num_workers
        if self.total_frames_per_collector * self.num_workers != total_frames:
            raise RuntimeError(
                f"Cannot dispatch {total_frames} frames across {self.num_workers}. "
                f"Consider using a number of frames that is divisible by the number of workers."
            )
    self.max_weight_update_interval = max_weight_update_interval
    self.launcher = launcher
    self._batches_since_weight_update = [0 for _ in range(self.num_workers)]
    if tcp_port is None:
        self.tcp_port = os.environ.get("TCP_PORT", "10003")
    else:
        self.tcp_port = str(tcp_port)

    if self.frames_per_batch % self.num_workers != 0:
        raise RuntimeError(
            f"Cannot dispatch {self.frames_per_batch} frames across {self.num_workers}. "
            f"Consider using a number of frames per batch that is divisible by the number of workers."
        )
    self._frames_per_batch_corrected = self.frames_per_batch // self.num_workers

    self.num_workers_per_collector = num_workers_per_collector
    self.total_frames = total_frames
    self.slurm_kwargs = copy(DEFAULT_SLURM_CONF)
    if slurm_kwargs is not None:
        self.slurm_kwargs.update(slurm_kwargs)
    # Give every worker its own kwargs dict so that the per-worker entries
    # written below never leak between workers.
    self.collector_kwargs = (
        deepcopy(collector_kwargs)
        if per_worker_kwargs
        else [copy(collector_kwargs) for _ in range(self.num_workers)]
    )

    # Inject the collection settings (and per-worker devices) into each
    # worker's collector kwargs.
    for i, collector_kwarg in enumerate(self.collector_kwargs):
        collector_kwarg["max_frames_per_traj"] = max_frames_per_traj
        collector_kwarg["init_random_frames"] = (
            init_random_frames // self.num_workers
        )
        collector_kwarg["reset_at_each_iter"] = reset_at_each_iter
        collector_kwarg["exploration_type"] = exploration_type
        collector_kwarg["device"] = self.device[i]
        collector_kwarg["storing_device"] = self.storing_device[i]
        collector_kwarg["env_device"] = self.env_device[i]
        collector_kwarg["policy_device"] = self.policy_device[i]

    self.postproc = postproc
    self.split_trajs = split_trajs

    self.backend = backend

    # Create weight sync schemes for distributed weight updates
    # The scheme creates its own TCPStore for coordination
    self._weight_sync_schemes = None
    if isinstance(policy, nn.Module):
        from torchrl.weight_update import DistributedWeightSyncScheme

        self._weight_sync_schemes = {
            "policy": DistributedWeightSyncScheme(backend=backend, sync=False)
        }
        # Initialize schemes on sender BEFORE starting workers so the store
        # exists when workers try to connect
        for model_id, scheme in self._weight_sync_schemes.items():
            torchrl_logger.debug(
                f"DistributedSyncCollector: Initializing scheme for '{model_id}' on sender"
            )
            scheme.init_on_sender(
                model_id=model_id,
                context=self,
                num_workers=self.num_workers,
                model=policy,
            )

    self._init_workers()
    self._make_container()
434
+
435
@property
def device(self) -> list[torch.device]:
    """Per-worker ``device`` forwarded to each remote collector."""
    return self._device

@property
def storing_device(self) -> list[torch.device]:
    """Per-worker ``storing_device`` forwarded to each remote collector."""
    return self._storing_device

@property
def env_device(self) -> list[torch.device]:
    """Per-worker ``env_device`` forwarded to each remote collector."""
    return self._env_device

@property
def policy_device(self) -> list[torch.device]:
    """Per-worker ``policy_device`` forwarded to each remote collector."""
    return self._policy_device

def _per_worker_devices(self, value):
    """Return one device entry per worker.

    A list or tuple must contain exactly ``self.num_workers`` entries and is
    returned unchanged; any other value (a single device, a device string or
    ``None``) is repeated ``self.num_workers`` times.

    Raises:
        RuntimeError: if a list/tuple does not have one entry per worker.
    """
    if not isinstance(value, (tuple, list)):
        return [value] * self.num_workers
    if len(value) != self.num_workers:
        raise RuntimeError(
            "The number of devices passed to the collector must match the number of workers."
        )
    return value

@device.setter
def device(self, value):
    self._device = self._per_worker_devices(value)

@storing_device.setter
def storing_device(self, value):
    self._storing_device = self._per_worker_devices(value)

@env_device.setter
def env_device(self, value):
    self._env_device = self._per_worker_devices(value)

@policy_device.setter
def policy_device(self, value):
    self._policy_device = self._per_worker_devices(value)
494
+
495
def _init_master_dist(
    self,
    world_size,
    backend,
):
    """Join the default ``torch.distributed`` process group as rank 0.

    Used when no weight sync scheme takes care of initializing the group.
    Workers rendezvous at ``tcp://<self.IPAddr>:<self.tcp_port>``.

    Args:
        world_size: total number of ranks (workers + this main process).
        backend: the ``torch.distributed`` backend name.
    """
    port = self.tcp_port
    torchrl_logger.debug("init master...")
    # NOTE(review): ``timedelta``'s first positional field is *days*, so the
    # historical ``timedelta(MAX_TIME_TO_CONNECT)`` means that many days.
    # Spelled out explicitly here with unchanged behaviour; confirm whether
    # seconds were intended.
    connect_timeout = timedelta(days=MAX_TIME_TO_CONNECT)
    rendezvous_url = f"tcp://{self.IPAddr}:{port}"
    torch.distributed.init_process_group(
        backend,
        rank=0,
        world_size=world_size,
        timeout=connect_timeout,
        init_method=rendezvous_url,
    )
    torchrl_logger.debug("done")
510
+
511
def _make_container(self):
    """Allocate the buffer into which the workers' batches are received.

    A temporary local :class:`Collector`, built from the first env
    constructor and ``self.policy``, collects one batch of
    ``self._frames_per_batch_corrected`` frames. That batch is the template
    for ``self._tensordict_out`` (leading dimension ``self.num_workers``),
    whose per-worker slices are cached in ``self._single_tds`` and filled by
    ``irecv`` during iteration.

    Raises:
        RuntimeError: if the temporary collector yields no batch.
    """
    env_constructor = self.env_constructors[0]
    pseudo_collector = Collector(
        env_constructor,
        self.policy,
        frames_per_batch=self._frames_per_batch_corrected,
        total_frames=self.total_frames,
        split_trajs=False,
    )
    try:
        template = next(iter(pseudo_collector), None)
        if template is None:
            raise RuntimeError(
                "The template collector did not yield any batch; cannot "
                "allocate the output buffer."
            )
        # ``expand`` is expected to behave like ``torch.Tensor.expand``:
        # stride-0 views that all alias the template's storage. Receiving
        # every worker into such views would make the slots overwrite one
        # another, so materialize one independent slot per worker.
        self._tensordict_out = template.expand(
            (self.num_workers, *template.shape)
        ).clone()
        self._single_tds = self._tensordict_out.unbind(0)
        self._tensordict_out.lock_()
    finally:
        # Release the temporary collector even if allocation failed.
        pseudo_collector.shutdown()
    del pseudo_collector
527
+
528
def _collection_node_kwargs(self, i):
    """Build the keyword arguments of ``_distributed_init_collection_node`` for worker ``i``.

    Worker ``i`` joins the process group as rank ``i + 1`` (rank 0 is this
    main process) in a world of ``self.num_workers + 1`` ranks. Env
    constructors that are not ``EnvBase``/``EnvCreator`` instances are
    wrapped in ``CloudpickleWrapper`` before being shipped.
    """
    env_make = self.env_constructors[i]
    if not isinstance(env_make, (EnvBase, EnvCreator)):
        # Presumably so that arbitrary callables (e.g. lambdas) can be
        # serialized to the remote process.
        env_make = CloudpickleWrapper(env_make)
    return dict(  # noqa: C408
        rank=i + 1,
        rank0_ip=self.IPAddr,
        tcpport=int(self.tcp_port),
        world_size=self.num_workers + 1,
        backend=self.backend,
        collector_class=self.collector_class,
        num_workers=self.num_workers_per_collector,
        env_make=env_make,
        policy=self._policy_to_send,
        policy_factory=self.policy_factory[i],
        frames_per_batch=self._frames_per_batch_corrected,
        collector_kwargs=self.collector_kwargs[i],
        update_interval=self.update_interval,
        total_frames=self.total_frames_per_collector,
        weight_sync_schemes=self._weight_sync_schemes,
        verbose=VERBOSE,
    )

def _init_worker_dist_submitit(self, executor, i):
    """Submit worker ``i`` as a submitit job and return the job handle."""
    return executor.submit(
        _distributed_init_collection_node,
        **self._collection_node_kwargs(i),
    )

def _init_worker_dist_mp(self, i):
    """Start worker ``i`` in a local subprocess and return the started process."""
    job = _ProcessNoWarn(
        target=_distributed_init_collection_node,
        kwargs=self._collection_node_kwargs(i),
    )
    job.start()
    return job
582
+
583
def _init_workers(self):
    """Launch one collection node per worker and join the process group as rank 0.

    Exports ``MASTER_ADDR``/``MASTER_PORT`` (this host's IP and
    ``self.tcp_port``), launches every worker with the configured launcher
    and stores the handles in ``self.jobs``. The process group is then
    initialized through the weight sync schemes' ``connect()`` when there
    are any, or directly via ``_init_master_dist`` otherwise.

    Raises:
        ValueError: if ``self.launcher`` is neither ``"submitit"`` nor ``"mp"``.
        ImportError: if ``"submitit"`` is requested but not installed.
    """
    # Fail fast: an unknown launcher used to surface as an UnboundLocalError
    # on ``job`` inside the launch loop.
    if self.launcher not in ("submitit", "mp"):
        raise ValueError(
            f"Unknown launcher {self.launcher!r}; expected 'submitit' or 'mp'."
        )

    hostname = socket.gethostname()
    ip_addr = socket.gethostbyname(hostname)
    torchrl_logger.debug(f"Server IP address: {ip_addr}")
    self.IPAddr = ip_addr
    os.environ["MASTER_ADDR"] = str(self.IPAddr)
    os.environ["MASTER_PORT"] = str(self.tcp_port)

    self.jobs = []
    if self.launcher == "submitit":
        if not _has_submitit:
            raise ImportError("submitit not found.") from SUBMITIT_ERR
        executor = submitit.AutoExecutor(folder="log_test")
        executor.update_parameters(**self.slurm_kwargs)
    for i in range(self.num_workers):
        torchrl_logger.debug("Submitting job")
        if self.launcher == "submitit":
            job = self._init_worker_dist_submitit(executor, i)
            torchrl_logger.debug(f"job id {job.job_id}")  # ID of your job
        else:  # "mp", validated above
            job = self._init_worker_dist_mp(i)
            torchrl_logger.debug("job launched")
        self.jobs.append(job)

    # Initialize process group and weight sync
    # If we have schemes, they handle init_process_group in connect()
    # Otherwise, we need to init manually for data.irecv to work
    if self._weight_sync_schemes is not None:
        for model_id, scheme in self._weight_sync_schemes.items():
            torchrl_logger.debug(
                f"DistributedSyncCollector: Connecting scheme '{model_id}' (will init process group)"
            )
            scheme.connect()
            torchrl_logger.debug(
                "DistributedSyncCollector: Initial weight sync completed"
            )
    else:
        # No schemes - init process group manually
        self._init_master_dist(self.num_workers + 1, self.backend)
628
+
629
def iterator(self):
    """Generate the collected batches; thin wrapper around ``_iterator_dist``."""
    batches = self._iterator_dist()
    yield from batches
631
+
632
def _iterator_dist(self):
    """Yield one synchronized batch per iteration, gathered from all workers.

    Every ``update_interval`` iterations (starting with the first), fresh
    policy weights are pushed through the weight sync schemes, if any. One
    non-blocking ``irecv`` per worker (rank ``i + 1``) is then posted into
    that worker's slot of ``self._tensordict_out``. Once all of them have
    completed, a clone of the buffer is yielded, after the trajectory ids
    are made unique across workers and the optional ``split_trajs`` and
    ``postproc`` steps are applied.

    Iteration stops once at least ``self.total_frames`` frames have been
    received; a negative ``self.total_frames`` means the collector never
    stops.
    """
    collected_frames = 0
    j = -1
    while self.total_frames < 0 or collected_frames < self.total_frames:
        j += 1
        if j % self.update_interval == 0 and self._weight_sync_schemes is not None:
            # Send weight updates via the schemes
            # Each scheme handles extracting weights from the policy and sending
            for scheme in self._weight_sync_schemes.values():
                scheme.send()

        trackers = []
        for i in range(self.num_workers):
            rank = i + 1
            torchrl_logger.debug(f"Receiving from rank {rank} on main")
            trackers.append(
                self._single_tds[i].irecv(src=rank, return_premature=True)
            )
            torchrl_logger.debug(f"Received from rank {rank} on main")
        for tracker in trackers:
            for _tracker in tracker:
                _tracker.wait()

        data = self._tensordict_out.clone()
        traj_ids = data.get(("collector", "traj_ids"), None)
        if traj_ids is not None:
            self._offset_traj_ids(traj_ids)
            data.set_(("collector", "traj_ids"), traj_ids)
        collected_frames += data.numel()
        if self.split_trajs:
            data = split_trajectories(data)
        if self.postproc is not None:
            data = self.postproc(data)
        yield data

@staticmethod
def _offset_traj_ids(traj_ids):
    """Make trajectory ids unique across workers, in place.

    ``traj_ids`` is indexed by worker along its first dimension. Each
    worker's ids are shifted so that they start strictly after the largest
    id of all preceding workers (the previous ``+= max()`` reused that max
    id and made two trajectories collide). Entries equal to ``-1`` are left
    unchanged; presumably they mark invalid/padded steps — confirm.
    """
    next_id = 0
    # Iterating over dim 0 yields views, so in-place updates hit ``traj_ids``.
    for worker_ids in traj_ids:
        if worker_ids.numel() == 0:
            continue
        worker_ids[worker_ids != -1] += next_id
        next_id = max(next_id, int(worker_ids.max()) + 1)
668
+
669
def update_policy_weights_(
    self,
    policy_or_weights: TensorDictBase | None = None,
    *,
    worker_ids=None,
    wait=True,
    **kwargs,
) -> None:
    """Explicit weight pushes are not supported by this collector.

    When the policy is an ``nn.Module``, weights are sent to the workers
    from within the iterator, through the weight sync schemes, every
    ``update_interval`` batches.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()

def set_seed(self, seed: int, static_seed: bool = False) -> int:
    """Seeding is not supported by this collector.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()

def state_dict(self) -> OrderedDict:
    """Checkpointing is not supported by this collector.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()

def load_state_dict(self, state_dict: OrderedDict) -> None:
    """Restoring from a checkpoint is not supported by this collector.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
687
+
688
def shutdown(self, timeout: float | None = None) -> None:
    """Stop the collector: wait for the workers, then tear down communication.

    Only the first call does any work; later calls return immediately. The
    function is tolerant of a partially constructed collector (missing
    ``jobs`` or ``_weight_sync_schemes`` attributes).

    Args:
        timeout: for the ``"mp"`` launcher, the maximum number of seconds to
            wait for each worker process to exit (defaults to 10). Not used
            for ``"submitit"``, whose jobs are awaited with ``job.result()``.
    """
    # Prevent double shutdown
    if getattr(self, "_shutdown", False):
        return
    self._shutdown = True

    # Wait for workers to exit
    for job in getattr(self, "jobs", ()):
        if self.launcher == "mp":
            if hasattr(job, "is_alive") and job.is_alive():
                job.join(timeout=timeout if timeout is not None else 10)
                if job.is_alive():
                    torchrl_logger.warning(
                        f"Worker process {getattr(job, 'pid', None)} did not "
                        f"exit within the shutdown timeout."
                    )
        elif self.launcher == "submitit":
            try:
                job.result()
            except Exception as e:
                # Best effort: a failed worker must not abort the teardown,
                # but it should not go unnoticed either.
                torchrl_logger.warning(
                    f"Submitit job {getattr(job, 'job_id', None)} ended with an error: {e}"
                )

    # Clean up weight sync schemes AFTER workers have exited
    schemes = getattr(self, "_weight_sync_schemes", None)
    if schemes is not None:
        torchrl_logger.debug("shutting down weight sync schemes")
        for scheme in schemes.values():
            try:
                scheme.shutdown()
            except Exception as e:
                torchrl_logger.warning(
                    f"Error shutting down weight sync scheme: {e}"
                )
        self._weight_sync_schemes = None

    # Destroy torch.distributed process group
    if torch.distributed.is_initialized():
        torchrl_logger.debug("destroying process group")
        torch.distributed.destroy_process_group()

    torchrl_logger.debug("collector shut down")
724
+
725
+
726
class DistributedSyncDataCollector(DistributedSyncCollector, metaclass=_LegacyCollectorMeta):
    """Deprecated alias of :class:`~torchrl.collectors.distributed.DistributedSyncCollector`.

    Kept for backward compatibility; new code should use
    ``DistributedSyncCollector`` directly.
    """