torchrl-0.11.0-cp314-cp314t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
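
The diff body below corresponds to entry 114 in the list above, `torchrl/collectors/distributed/rpc.py` (+1006 lines); it is the only file whose content is reproduced on this page.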
@@ -0,0 +1,1006 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ r"""Generic distributed data-collector using torch.distributed.rpc backend."""
+ from __future__ import annotations
+
+ import collections
+ import os
+ import socket
+ import time
+ import warnings
+ from collections import OrderedDict
+ from collections.abc import Callable, Sequence
+ from copy import copy, deepcopy
+ from typing import Any
+
+ import torch.cuda
+
+ from tensordict import TensorDict, TensorDictBase
+ from torch import nn
+
+ from torch.distributed import rpc
+ from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE
+ from torchrl.collectors._base import _LegacyCollectorMeta, BaseCollector
+
+ from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE
+ from torchrl.collectors._multi_async import MultiAsyncCollector
+ from torchrl.collectors._multi_sync import MultiSyncCollector
+ from torchrl.collectors._single import Collector
+ from torchrl.collectors.distributed.default_configs import (
+     DEFAULT_SLURM_CONF,
+     DEFAULT_TENSORPIPE_OPTIONS,
+     IDLE_TIMEOUT,
+     TCP_PORT,
+ )
+ from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories
+ from torchrl.collectors.weight_update import WeightUpdaterBase
+ from torchrl.data.utils import CloudpickleWrapper
+ from torchrl.envs.common import EnvBase
+ from torchrl.envs.env_creator import EnvCreator
+ from torchrl.weight_update.weight_sync_schemes import WeightSyncScheme
+
+ SUBMITIT_ERR = None
+ try:
+     import submitit
+
+     _has_submitit = True
+ except ModuleNotFoundError as err:
+     _has_submitit = False
+     SUBMITIT_ERR = err
+
+
+ def _rpc_init_collection_node(
+     rank,
+     rank0_ip,
+     tcp_port,
+     world_size,
+     visible_device,
+     tensorpipe_options,
+     backend="gloo",
+     verbose=VERBOSE,
+ ):
+     os.environ["MASTER_ADDR"] = str(rank0_ip)
+     os.environ["MASTER_PORT"] = str(tcp_port)
+
+     # Initialize torch.distributed process group for efficient weight transfer
+     if verbose:
+         torchrl_logger.debug(
+             f"init distributed with rank={rank}, world_size={world_size}, backend={backend}"
+         )
+     torch.distributed.init_process_group(
+         backend=backend,
+         rank=rank,
+         world_size=world_size,
+     )
+
+     if isinstance(visible_device, list):
+         pass
+     elif isinstance(visible_device, (str, int, torch.device)):
+         visible_device = [visible_device]
+     elif visible_device is None:
+         pass
+     else:
+         raise RuntimeError(f"unrecognised type {type(visible_device)}")
+
+     options = rpc.TensorPipeRpcBackendOptions(
+         devices=visible_device,
+         **tensorpipe_options,
+     )
+     if verbose:
+         torchrl_logger.debug(
+             f"init rpc with master addr: {os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
+         )
+     rpc.init_rpc(
+         f"COLLECTOR_NODE_{rank}",
+         rank=rank,
+         backend=rpc.BackendType.TENSORPIPE,
+         rpc_backend_options=options,
+         world_size=world_size,
+     )
+     rpc.shutdown()
+     torch.distributed.destroy_process_group()
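
For orientation, here is a minimal sketch (an editor's addition, not part of the diff) of how the "mp" launcher path spawns this entry point, mirroring `RPCCollector._init_worker_rpc` further down. `torch.multiprocessing.Process` stands in for torchrl's internal `_ProcessNoWarn`, the port is the documented `tcp_port` default (10003), and a rank-0 trainer process must concurrently run the matching init (see `_init_master_rpc` below) for the handshake to complete:

    import torch.multiprocessing as mp

    if __name__ == "__main__":
        world_size = 2  # rank 0 = TRAINER_NODE, rank 1 = COLLECTOR_NODE_1
        job = mp.Process(
            target=_rpc_init_collection_node,
            # rank, rank0_ip, tcp_port, world_size, visible_device, tensorpipe_options
            args=(1, "localhost", 10003, world_size, None, DEFAULT_TENSORPIPE_OPTIONS),
        )
        job.start()
        # ... rank 0 initializes torch.distributed and RPC here ...
        job.join()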
+
+
+ class RPCCollector(BaseCollector):
+     """An RPC-based distributed data collector.
+
+     Supports sync and async data collection.
+
+     Args:
+         create_env_fn (Callable or List[Callable]): a list of Callables, each returning an
+             instance of :class:`~torchrl.envs.EnvBase`.
+         policy (Callable): Policy to be executed in the environment.
+             Must accept a :class:`tensordict.tensordict.TensorDictBase` object as input.
+             If ``None`` is provided, the policy used will be a
+             :class:`~torchrl.collectors.RandomPolicy` instance with the environment
+             ``action_spec``.
+             Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`.
+             This is the recommended usage of the collector.
+             Other callables are accepted too:
+             If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module`
+             instance) it will be wrapped in a `nn.Module` first.
+             Then, the collector will try to assess if these
+             modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
+
+             - If the policy forward signature matches any of ``forward(self, tensordict)``,
+               ``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
+               any typing with a single argument typed as a subclass of ``TensorDictBase``)
+               then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
+
+             - In all other cases, an attempt to wrap it will be made as such:
+               ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
+
+             .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
+                 pickled directly), the ``policy_factory`` argument should be used instead.
+
+     Keyword Args:
+         policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable
+             (or list of callables) that returns a policy instance. This is exclusive with the ``policy`` argument.
+
+             .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
+
+         frames_per_batch (int): A keyword-only argument representing the total
+             number of elements in a batch.
+         total_frames (int): A keyword-only argument representing the total
+             number of frames returned by the collector
+             during its lifespan. If the ``total_frames`` is not divisible by
+             ``frames_per_batch``, an exception is raised.
+             Endless collectors can be created by passing ``total_frames=-1``.
+             Defaults to ``-1`` (endless collector).
+         device (int, str or torch.device, optional): The generic device of the
+             collector. The ``device`` arg fills any non-specified device: if
+             ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
+             ``env_device`` is not specified, its value will be set to ``device``.
+             Defaults to ``None`` (no default device).
+             Lists of devices are supported.
+         storing_device (int, str or torch.device, optional): The *remote* device on which
+             the output :class:`~tensordict.TensorDict` will be stored.
+             If ``device`` is passed and ``storing_device`` is ``None``, it will
+             default to the value indicated by ``device``.
+             For long trajectories, it may be necessary to store the data on a different
+             device than the one where the policy and env are executed.
+             Defaults to ``None`` (the output tensordict isn't on a specific device,
+             leaf tensors sit on the device where they were created).
+             Lists of devices are supported.
+         env_device (int, str or torch.device, optional): The *remote* device on which
+             the environment should be cast (or executed if that functionality is
+             supported). If not specified and the env has a non-``None`` device,
+             ``env_device`` will default to that value. If ``device`` is passed
+             and ``env_device=None``, it will default to ``device``. If the value
+             of ``env_device`` specified this way differs from ``policy_device``
+             and one of them is not ``None``, the data will be cast to ``env_device``
+             before being passed to the env (i.e., passing different devices to
+             policy and env is supported). Defaults to ``None``.
+             Lists of devices are supported.
+         policy_device (int, str or torch.device, optional): The *remote* device on which
+             the policy should be cast.
+             If ``device`` is passed and ``policy_device=None``, it will default
+             to ``device``. If the value of ``policy_device`` specified this way
+             differs from ``env_device`` and one of them is not ``None``,
+             the data will be cast to ``policy_device`` before being passed to
+             the policy (i.e., passing different devices to policy and env is
+             supported). Defaults to ``None``.
+             Lists of devices are supported.
+         max_frames_per_traj (int, optional): Maximum steps per trajectory.
+             Note that a trajectory can span across multiple batches (unless
+             ``reset_at_each_iter`` is set to ``True``, see below).
+             Once a trajectory reaches ``max_frames_per_traj`` steps, the environment is reset.
+             If the environment wraps multiple environments together, the number
+             of steps is tracked for each environment independently. Negative
+             values are allowed, in which case this argument is ignored.
+             Defaults to ``None`` (i.e., no maximum number of steps).
+         init_random_frames (int, optional): Number of frames for which the
+             policy is ignored before it is called. This feature is mainly
+             intended to be used in offline/model-based settings, where a
+             batch of random trajectories can be used to initialize training.
+             If provided, it will be rounded up to the closest multiple of frames_per_batch.
+             Defaults to ``None`` (i.e., no random frames).
+         reset_at_each_iter (bool, optional): Whether environments should be reset
+             at the beginning of a batch collection.
+             Defaults to ``False``.
+         postproc (Callable, optional): A post-processing transform, such as
+             a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
+             instance.
+             Defaults to ``None``.
+         split_trajs (bool, optional): Boolean indicating whether the resulting
+             TensorDict should be split according to the trajectories.
+             See :func:`~torchrl.collectors.utils.split_trajectories` for more
+             information.
+             Defaults to ``False``.
+         exploration_type (ExplorationType, optional): interaction mode to be used when
+             collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
+             ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
+             or ``torchrl.envs.utils.ExplorationType.MEAN``.
+             Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
+         collector_class (Type or str, optional): a collector class for the remote node. Can be
+             :class:`~torchrl.collectors.Collector`,
+             :class:`~torchrl.collectors.MultiSyncCollector`,
+             :class:`~torchrl.collectors.MultiAsyncCollector`
+             or a derived class of these. The strings "single", "sync" and
+             "async" correspond to the respective classes.
+             Defaults to :class:`~torchrl.collectors.Collector`.
+
+             .. note::
+
+                 Support for :class:`MultiSyncCollector` and :class:`MultiAsyncCollector`
+                 is experimental, and :class:`~torchrl.collectors.Collector`
+                 should always be preferred. If multiple environments need to be
+                 executed simultaneously on a single node, consider using a
+                 :class:`~torchrl.envs.ParallelEnv` instance.
+         collector_kwargs (dict or list, optional): a dictionary of parameters to be passed to the
+             remote data-collector. If a list is provided, each element will
+             correspond to an individual set of keyword arguments for the
+             dedicated collector.
+         num_workers_per_collector (int, optional): the number of copies of the
+             env constructor that is to be used on the remote nodes.
+             Defaults to 1 (a single env per collector).
+             On a single worker node all the sub-workers will be
+             executing the same environment. If different environments need to
+             be executed, they should be dispatched across worker nodes, not
+             subnodes.
+         sync (bool, optional): if ``True``, the resulting tensordict is a stack of all the
+             tensordicts collected on each node. If ``False`` (default), each
+             tensordict results from a separate node in a "first-ready,
+             first-served" fashion.
+         slurm_kwargs (dict): a dictionary of parameters to be passed to the
+             submitit executor.
+         update_after_each_batch (bool, optional): if ``True``, the weights will
+             be updated after each collection. For ``sync=True``, this means that
+             all workers will see their weights updated. For ``sync=False``,
+             only the worker from which the data has been gathered will be
+             updated.
+             Defaults to ``False``, i.e., updates have to be executed manually
+             through
+             :meth:`~torchrl.collectors.distributed.DistributedDataCollector.update_policy_weights_`.
+         max_weight_update_interval (int, optional): the maximum number of
+             batches that can be collected before the policy weights of a worker
+             are updated.
+             For sync collections, this parameter is overwritten by ``update_after_each_batch``.
+             For async collections, it may be that one worker has not seen its
+             parameters being updated for a certain time even if ``update_after_each_batch``
+             is turned on.
+             Defaults to -1 (no forced update).
+         launcher (str, optional): how jobs should be launched.
+             Can be one of "submitit" or "mp" for multiprocessing. The former
+             can launch jobs across multiple nodes, whilst the latter will only
+             launch jobs on a single machine. "submitit" requires the homonymous
+             library to be installed.
+             To find out more about submitit, visit
+             https://github.com/facebookincubator/submitit
+             Defaults to "submitit".
+         tcp_port (int, optional): the TCP port to be used. Defaults to 10003.
+         backend (str, optional): the torch.distributed backend to use for weight synchronization.
+             Must be one of ``"gloo"``, ``"mpi"``, ``"nccl"`` or ``"ucc"``. See the torch.distributed
+             documentation for more information. Defaults to ``"gloo"``.
+         visible_devices (list of Union[int, torch.device, str], optional): a
+             list of the same length as the number of nodes containing the
+             device used to pass data to main.
+         tensorpipe_options (dict, optional): a dictionary of keyword arguments
+             to pass to :class:`torch.distributed.rpc.TensorPipeRpcBackendOptions`.
+         weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase`
+             or its subclass, responsible for updating the policy weights on remote inference workers using RPC.
+             If not provided, an :class:`~torchrl.collectors.distributed.RPCWeightUpdater` will be used by default, which
+             handles weight synchronization via RPC.
+             Consider using a constructor if the updater needs to be serialized.
+         weight_sync_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for
+             SENDING weights to remote collector workers. Keys are model identifiers (e.g., "policy")
+             and values are WeightSyncScheme instances configured to send weights via RPC.
+             If not provided, an :class:`~torchrl.weight_update.RPCWeightSyncScheme` will be used by default.
+             This is for propagating weights from the main process to remote collectors.
+         weight_recv_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for
+             RECEIVING weights from a parent process or training loop. Keys are model identifiers (e.g., "policy")
+             and values are WeightSyncScheme instances configured to receive weights.
+             This is typically used when the RPCCollector is itself a worker in a larger distributed setup.
+             Defaults to ``None``.
+
+     """
+
+     _VERBOSE = VERBOSE  # for debugging
+
+     def __init__(
+         self,
+         create_env_fn,
+         policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
+         *,
+         policy_factory: Callable[[], Callable]
+         | list[Callable[[], Callable]]
+         | None = None,
+         frames_per_batch: int,
+         total_frames: int = -1,
+         device: torch.device | list[torch.device] = None,
+         storing_device: torch.device | list[torch.device] = None,
+         env_device: torch.device | list[torch.device] = None,
+         policy_device: torch.device | list[torch.device] = None,
+         max_frames_per_traj: int = -1,
+         init_random_frames: int = -1,
+         reset_at_each_iter: bool = False,
+         postproc: Callable | None = None,
+         split_trajs: bool = False,
+         exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,  # noqa
+         collector_class: type = Collector,
+         collector_kwargs: dict[str, Any] | None = None,
+         num_workers_per_collector: int = 1,
+         sync: bool = False,
+         slurm_kwargs: dict[str, Any] | None = None,
+         update_after_each_batch: bool = False,
+         max_weight_update_interval: int = -1,
+         launcher: str = "submitit",
+         tcp_port: str | None = None,
+         backend: str = "gloo",
+         visible_devices: list[torch.device] | None = None,
+         tensorpipe_options: dict[str, Any] | None = None,
+         weight_updater: WeightUpdaterBase
+         | Callable[[], WeightUpdaterBase]
+         | None = None,
+         weight_sync_schemes: dict[str, WeightSyncScheme] | None = None,
+         weight_recv_schemes: dict[str, WeightSyncScheme] | None = None,
+     ):
+
+         if self._VERBOSE:
+             torchrl_logger.setLevel("DEBUG")
+
+         if collector_class == "async":
+             collector_class = MultiAsyncCollector
+         elif collector_class == "sync":
+             collector_class = MultiSyncCollector
+         elif collector_class == "single":
+             collector_class = Collector
+         self.collector_class = collector_class
+         self.env_constructors = create_env_fn
+         self.policy = policy
+         if isinstance(policy, nn.Module):
+             policy_weights = TensorDict.from_module(policy)
+             policy_weights = policy_weights.data.lock_()
+         else:
+             if weight_updater is None and (
+                 policy_factory is None
+                 or (isinstance(policy_factory, Sequence) and not any(policy_factory))
+             ):
+                 warnings.warn(_NON_NN_POLICY_WEIGHTS)
+             policy_weights = TensorDict(lock=True)
+
+         if not isinstance(policy_factory, Sequence):
+             policy_factory = [policy_factory] * len(create_env_fn)
+         self.policy_factory = policy_factory
+         self.policy_weights = policy_weights
+         self.num_workers = len(create_env_fn)
+         self.frames_per_batch = frames_per_batch
+         self.requested_frames_per_batch = frames_per_batch
+
+         self.device = device
+         self.storing_device = storing_device
+         self.env_device = env_device
+         self.policy_device = policy_device
+
+         # make private to avoid changes from users during collection
+         self._sync = sync
+         self.update_after_each_batch = update_after_each_batch
+         self.max_weight_update_interval = max_weight_update_interval
+         if self.update_after_each_batch and self.max_weight_update_interval > -1:
+             raise RuntimeError(
+                 "Got conflicting update instructions: `update_after_each_batch` "
+                 "and `max_weight_update_interval` are incompatible."
+             )
+         self.launcher = launcher
+         self._batches_since_weight_update = [0 for _ in range(self.num_workers)]
+         if tcp_port is None:
+             self.tcp_port = os.environ.get("TCP_PORT", TCP_PORT)
+         else:
+             self.tcp_port = str(tcp_port)
+         self.visible_devices = visible_devices
+         if self._sync:
+             if self.frames_per_batch % self.num_workers != 0:
+                 raise RuntimeError(
+                     f"Cannot dispatch {self.frames_per_batch} frames across {self.num_workers} workers. "
+                     f"Consider using a number of frames per batch that is divisible by the number of workers."
+                 )
+             self._frames_per_batch_corrected = self.frames_per_batch // self.num_workers
+         else:
+             self._frames_per_batch_corrected = self.frames_per_batch
+
+         self.num_workers_per_collector = num_workers_per_collector
+         self.total_frames = total_frames
+         self.slurm_kwargs = copy(DEFAULT_SLURM_CONF)
+         if slurm_kwargs is not None:
+             self.slurm_kwargs.update(slurm_kwargs)
+
+         collector_kwargs = collector_kwargs if collector_kwargs is not None else {}
+         self.collector_kwargs = (
+             deepcopy(collector_kwargs)
+             if isinstance(collector_kwargs, (list, tuple))
+             else [copy(collector_kwargs) for _ in range(self.num_workers)]
+         )
+
+         # update collector kwargs
+         for i, collector_kwarg in enumerate(self.collector_kwargs):
+             collector_kwarg["max_frames_per_traj"] = max_frames_per_traj
+             collector_kwarg["init_random_frames"] = (
+                 init_random_frames // self.num_workers
+             )
+             if not self._sync and init_random_frames > 0:
+                 warnings.warn(
+                     "async distributed data collection with init_random_frames > 0 "
+                     "may have unforeseen consequences as we do not control that once "
+                     "non-random data is being collected all nodes are returning non-random data. "
+                     "If this is a feature that you feel should be fixed, please raise an issue on "
+                     "torchrl's repo."
+                 )
+             collector_kwarg["reset_at_each_iter"] = reset_at_each_iter
+             collector_kwarg["exploration_type"] = exploration_type
+             collector_kwarg["device"] = self.device[i]
+             collector_kwarg["storing_device"] = self.storing_device[i]
+             collector_kwarg["env_device"] = self.env_device[i]
+             collector_kwarg["policy_device"] = self.policy_device[i]
+
+         self.postproc = postproc
+         self.split_trajs = split_trajs
+         self.backend = backend
+
+         if tensorpipe_options is None:
+             self.tensorpipe_options = copy(DEFAULT_TENSORPIPE_OPTIONS)
+         else:
+             # dict.update returns None, so update the copied defaults in two steps
+             self.tensorpipe_options = copy(DEFAULT_TENSORPIPE_OPTIONS)
+             self.tensorpipe_options.update(tensorpipe_options)
+
+         # Set up weight synchronization - prefer new schemes over legacy updater
+         if weight_updater is None and weight_sync_schemes is None:
+             # Default to RPC weight sync scheme for RPC collectors
+             from torchrl.weight_update import RPCWeightSyncScheme
+
+             weight_sync_schemes = {"policy": RPCWeightSyncScheme()}
+
+         if weight_sync_schemes is not None:
+             # Use new weight synchronization system
+             self._weight_sync_schemes = weight_sync_schemes
+             self.weight_updater = None
+         else:
+             # Fall back to legacy weight updater system
+             if weight_updater is None:
+                 weight_updater = RPCWeightUpdater(
+                     collector_infos=self.collector_infos,
+                     collector_class=self.collector_class,
+                     collector_rrefs=self.collector_rrefs,
+                     policy_weights=self.policy_weights,
+                     num_workers=self.num_workers,
+                 )
+             self.weight_updater = weight_updater
+             self._weight_sync_schemes = None
+
+         self._init()
+
+         if weight_sync_schemes is not None:
+             # Set up weight senders now that remote collectors exist
+             for model_id, scheme in self._weight_sync_schemes.items():
+                 scheme.init_on_sender(
+                     model_id=model_id,
+                     num_workers=self.num_workers,
+                     context=self,
+                 )
+                 scheme.connect()
+
+         # Set up weight receivers if provided
+         if weight_recv_schemes is not None:
+             self.register_scheme_receiver(weight_recv_schemes)
+
+     @property
+     def device(self) -> list[torch.device]:
+         return self._device
+
+     @property
+     def storing_device(self) -> list[torch.device]:
+         return self._storing_device
+
+     @property
+     def env_device(self) -> list[torch.device]:
+         return self._env_device
+
+     @property
+     def policy_device(self) -> list[torch.device]:
+         return self._policy_device
+
+     @device.setter
+     def device(self, value):
+         if isinstance(value, (tuple, list)):
+             if len(value) != self.num_workers:
+                 raise RuntimeError(
+                     "The number of devices passed to the collector must match the number of workers."
+                 )
+             self._device = value
+         else:
+             self._device = [value] * self.num_workers
+
+     @storing_device.setter
+     def storing_device(self, value):
+         if isinstance(value, (tuple, list)):
+             if len(value) != self.num_workers:
+                 raise RuntimeError(
+                     "The number of devices passed to the collector must match the number of workers."
+                 )
+             self._storing_device = value
+         else:
+             self._storing_device = [value] * self.num_workers
+
+     @env_device.setter
+     def env_device(self, value):
+         if isinstance(value, (tuple, list)):
+             if len(value) != self.num_workers:
+                 raise RuntimeError(
+                     "The number of devices passed to the collector must match the number of workers."
+                 )
+             self._env_device = value
+         else:
+             self._env_device = [value] * self.num_workers
+
+     @policy_device.setter
+     def policy_device(self, value):
+         if isinstance(value, (tuple, list)):
+             if len(value) != self.num_workers:
+                 raise RuntimeError(
+                     "The number of devices passed to the collector must match the number of workers."
+                 )
+             self._policy_device = value
+         else:
+             self._policy_device = [value] * self.num_workers
+
+     def _init_master_rpc(
+         self,
+         world_size,
+     ):
+         """Init torch.distributed and RPC on the main node."""
+         # Initialize torch.distributed process group for efficient weight transfer
+         torchrl_logger.debug(
+             f"init distributed with rank=0, world_size={world_size}, backend={self.backend}"
+         )
+         torch.distributed.init_process_group(
+             backend=self.backend,
+             rank=0,
+             world_size=world_size,
+         )
+
+         # Initialize RPC for control/signaling
+         options = rpc.TensorPipeRpcBackendOptions(**self.tensorpipe_options)
+         if torch.cuda.is_available():
+             if self.visible_devices:
+                 for i in range(self.num_workers):
+                     rank = i + 1
+                     options.set_device_map(
+                         f"COLLECTOR_NODE_{rank}", {0: self.visible_devices[i]}
+                     )
+         torchrl_logger.debug("init rpc")
+         rpc.init_rpc(
+             "TRAINER_NODE",
+             rank=0,
+             backend=rpc.BackendType.TENSORPIPE,
+             rpc_backend_options=options,
+             world_size=world_size,
+         )
+
+     def _start_workers(
+         self,
+         world_size,
+         env_constructors,
+         collector_class,
+         num_workers_per_collector,
+         policy,
+         policy_factory,
+         frames_per_batch,
+         total_frames,
+         collector_kwargs,
+     ):
+         """Instantiate remote collectors."""
+         num_workers = world_size - 1
+         time_interval = 1.0
+         collector_infos = []
+         for i in range(num_workers):
+             counter = 0
+             while True:
+                 counter += 1
+                 time.sleep(time_interval)
+                 try:
+                     torchrl_logger.debug(f"trying to connect to collector node {i + 1}")
+                     collector_info = rpc.get_worker_info(f"COLLECTOR_NODE_{i + 1}")
+                     break
+                 except RuntimeError as err:
+                     if counter * time_interval > self.tensorpipe_options["rpc_timeout"]:
+                         raise RuntimeError("Could not connect to remote node") from err
+                     continue
+             collector_infos.append(collector_info)
+
+         collector_rrefs = []
+         for i in range(num_workers):
+             env_make = env_constructors[i]
+             if not isinstance(env_make, (EnvBase, EnvCreator)):
+                 env_make = CloudpickleWrapper(env_make)
+             torchrl_logger.debug("Making collector in remote node")
+             # When using weight sync schemes together with a policy_factory, the
+             # main-node `policy` should be used only as a weight source on the
+             # trainer, and NOT sent to remote collectors (which will build their
+             # own policies from the factory). This mirrors the behaviour of
+             # `DistributedDataCollector` with multi-process collectors.
+             policy_to_send = (
+                 None
+                 if (
+                     policy is not None
+                     and policy_factory[i] is not None
+                     and getattr(self, "_weight_sync_schemes", None) is not None
+                 )
+                 else policy
+             )
+
+             collector_rref = rpc.remote(
+                 collector_infos[i],
+                 collector_class,
+                 args=(
+                     [env_make] * num_workers_per_collector
+                     if collector_class is not Collector
+                     else env_make,
+                     policy_to_send,
+                 ),
+                 kwargs={
+                     "policy_factory": policy_factory[i],
+                     "frames_per_batch": frames_per_batch,
+                     "total_frames": -1,
+                     "split_trajs": False,
+                     "weight_recv_schemes": self._weight_sync_schemes,
+                     "worker_idx": i,
+                     **collector_kwargs[i],
+                 },
+             )
+             collector_rrefs.append(collector_rref)
+
+         # Set up receiver schemes on remote collectors (if using the new weight sync system)
+         # This enables cascading: RPC -> MultiSync -> Sync
+         if getattr(self, "_weight_sync_schemes", None) is not None:
+             for i in range(num_workers):
+                 torchrl_logger.debug(
+                     f"Setting up receiver schemes on remote collector {i}"
+                 )
+                 # Call register_scheme_receiver on the remote collector using rref.rpc_sync()
+                 # This properly dereferences the rref and calls the instance method
+                 collector_rrefs[i].rpc_sync().register_scheme_receiver(
+                     self._weight_sync_schemes
+                 )
+
+         futures = collections.deque(maxlen=self.num_workers)
+
+         if not self._sync:
+             for i in range(num_workers):
+                 torchrl_logger.debug("Asking for the first batch")
+                 # Use rref.rpc_async() to properly call the instance method
+                 future = collector_rrefs[i].rpc_async().next()
+                 futures.append((future, i))
+         self.futures = futures
+         self.collector_rrefs = collector_rrefs
+         self.collector_infos = collector_infos
+
+     def _init_worker_rpc(self, executor, i):
+         """Init an RPC node if necessary."""
+         visible_device = (
+             self.visible_devices[i] if self.visible_devices is not None else None
+         )
+         if self.launcher == "submitit":
+             if not _has_submitit:
+                 raise ImportError("submitit not found.") from SUBMITIT_ERR
+             job = executor.submit(
+                 _rpc_init_collection_node,
+                 i + 1,
+                 self.IPAddr,
+                 self.tcp_port,
+                 self.num_workers + 1,
+                 visible_device,
+                 self.tensorpipe_options,
+                 self.backend,
+                 self._VERBOSE,
+             )
+             torchrl_logger.debug(f"job id {job.job_id}")  # ID of your job
+             return job
+         elif self.launcher == "mp":
+             job = _ProcessNoWarn(
+                 target=_rpc_init_collection_node,
+                 args=(
+                     i + 1,
+                     self.IPAddr,
+                     self.tcp_port,
+                     self.num_workers + 1,
+                     visible_device,
+                     self.tensorpipe_options,
+                     self.backend,
+                     self._VERBOSE,
+                 ),
+             )
+             job.start()
+             return job
+         elif self.launcher == "submitit_delayed":
+             # the job is already launched
+             return None
+         else:
+             raise NotImplementedError(f"Unknown launcher {self.launcher}")
+
+     def _init(self):
+         self._shutdown = False
+         if self.launcher == "submitit":
+             executor = submitit.AutoExecutor(folder="log_test")
+             executor.update_parameters(**self.slurm_kwargs)
+         else:
+             executor = None
+
+         hostname = socket.gethostname()
+         if self.launcher != "mp":
+             IPAddr = socket.gethostbyname(hostname)
+         else:
+             IPAddr = "localhost"
+         self.IPAddr = IPAddr
+
+         os.environ["MASTER_ADDR"] = str(self.IPAddr)
+         os.environ["MASTER_PORT"] = str(self.tcp_port)
+
+         self.jobs = []
+         for i in range(self.num_workers):
+             torchrl_logger.debug(f"Submitting job {i}")
+             job = self._init_worker_rpc(
+                 executor,
+                 i,
+             )
+             self.jobs.append(job)
+
+         self._init_master_rpc(
+             self.num_workers + 1,
+         )
+         self._start_workers(
+             world_size=self.num_workers + 1,
+             env_constructors=self.env_constructors,
+             collector_class=self.collector_class,
+             num_workers_per_collector=self.num_workers_per_collector,
+             policy=self.policy,
+             policy_factory=self.policy_factory,
+             frames_per_batch=self._frames_per_batch_corrected,
+             total_frames=self.total_frames,
+             collector_kwargs=self.collector_kwargs,
+         )
+
+     def iterator(self):
+         self._collected_frames = 0
+         while self._collected_frames < self.total_frames:
+             if self._sync:
+                 data = self._next_sync_rpc()
+             else:
+                 data = self._next_async_rpc()
+
+             if self.split_trajs:
+                 data = split_trajectories(data)
+             if self.postproc is not None:
+                 data = self.postproc(data)
+             yield data
+
+             if self.max_weight_update_interval > -1 and not self._sync:
+                 for j in range(self.num_workers):
+                     if (
+                         self._batches_since_weight_update[j]
+                         > self.max_weight_update_interval
+                     ):
+                         torchrl_logger.debug(
+                             f"Updating policy of worker {j} with wait=False"
+                         )
+                         self.update_policy_weights_(worker_ids=[j], wait=False)
+             elif self.max_weight_update_interval > -1:
+                 # worker ids are 0-based, matching the async branch above
+                 ranks = [
+                     j
+                     for j in range(self.num_workers)
+                     if self._batches_since_weight_update[j]
+                     > self.max_weight_update_interval
+                 ]
+                 torchrl_logger.debug(
+                     f"Updating policy of workers {ranks} with wait=True"
+                 )
+                 self.update_policy_weights_(worker_ids=ranks, wait=True)
+
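
A consumption loop around this iterator might look as follows (editor's sketch, not part of the diff; `collector` is the object built in the earlier sketch, and the loss/optimizer step is a placeholder):

    for data in collector:        # drives iterator() under the hood
        # ... compute a loss on `data` and take an optimizer step (placeholder) ...
        collector.update_policy_weights_()  # push fresh weights to the remote nodes
    collector.shutdown()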
+     def _next_async_rpc(self):
+         torchrl_logger.debug("next async")
+         if not len(self.futures):
+             raise StopIteration(
+                 f"The queue is empty, the collector has run out of data after {self._collected_frames} collected frames."
+             )
+         while True:
+             future, i = self.futures.popleft()
+             if future.done():
+                 if self.update_after_each_batch:
+                     self.update_policy_weights_(worker_ids=(i,), wait=False)
+                 torchrl_logger.debug(f"future {i} is done")
+                 data = future.value()
+                 self._collected_frames += data.numel()
+                 if self._collected_frames < self.total_frames:
+                     # Use rref.rpc_async() to properly call the instance method
+                     future = self.collector_rrefs[i].rpc_async().next()
+                     self.futures.append((future, i))
+                 return data
+             self.futures.append((future, i))
+
+     def _next_sync_rpc(self):
+         torchrl_logger.debug("next sync: futures")
+         if self.update_after_each_batch:
+             self.update_policy_weights_()
+         for i in range(self.num_workers):
+             # Use rref.rpc_async() to properly call the instance method
+             future = self.collector_rrefs[i].rpc_async().next()
+             self.futures.append((future, i))
+         data = []
+         while len(self.futures):
+             future, i = self.futures.popleft()
+             # the order is NOT guaranteed: should we change that?
+             if future.done():
+                 data += [future.value()]
+                 torchrl_logger.debug(
+                     f"got data from {i} // data has len {len(data)} / {self.num_workers}"
+                 )
+             else:
+                 self.futures.append((future, i))
+         data = torch.cat(data)
+         traj_ids = data.get(("collector", "traj_ids"), None)
+         if traj_ids is not None:
+             for i in range(1, self.num_workers):
+                 traj_ids[i] += traj_ids[i - 1].max()
+             data.set_(("collector", "traj_ids"), traj_ids)
+         self._collected_frames += data.numel()
+         return data
+
+     def set_seed(self, seed: int, static_seed: bool = False) -> int:
+         for worker in self.collector_infos:
+             seed = rpc.rpc_sync(worker, self.collector_class.set_seed, args=(seed,))
+         return seed
+
+     def state_dict(self) -> OrderedDict:
+         raise NotImplementedError
+
+     def load_state_dict(self, state_dict: OrderedDict) -> None:
+         raise NotImplementedError
+
+     def shutdown(self, timeout: float | None = None) -> None:
+         if not hasattr(self, "_shutdown"):
+             warnings.warn("shutdown has no effect as `_init` has not been called yet.")
+             return
+         if self._shutdown:
+             return
+
+         torchrl_logger.debug("shutting down")
+         for future, i in self.futures:
+             # clear the futures
+             while future is not None and not future.done():
+                 torchrl_logger.debug(f"waiting for proc {i} to clear")
+                 future.wait()
+         for i in range(self.num_workers):
+             torchrl_logger.debug(f"shutting down {i}")
+             # Use rref.rpc_sync() to properly call the instance method
+             self.collector_rrefs[i].rpc_sync(timeout=int(IDLE_TIMEOUT)).shutdown()
+         torchrl_logger.debug("rpc shutdown")
+         rpc.shutdown(timeout=int(IDLE_TIMEOUT))
+
+         if self.launcher == "mp":
+             for job in self.jobs:
+                 job.join(int(IDLE_TIMEOUT))
+         elif self.launcher == "submitit":
+             for job in self.jobs:
+                 _ = job.result()
+         elif self.launcher == "submitit_delayed":
+             pass
+         else:
+             raise NotImplementedError(f"Unknown launcher {self.launcher}")
+
+         # Clean up weight sync schemes AFTER workers have exited
+         if getattr(self, "_weight_sync_schemes", None) is not None:
+             torchrl_logger.debug("shutting down weight sync schemes")
+             for scheme in self._weight_sync_schemes.values():
+                 try:
+                     scheme.shutdown()
+                 except Exception as e:
+                     torchrl_logger.warning(
+                         f"Error shutting down weight sync scheme: {e}"
+                     )
+             self._weight_sync_schemes = None
+
+         # Destroy the torch.distributed process group
+         if torch.distributed.is_initialized():
+             torch.distributed.destroy_process_group()
+
+         self._shutdown = True
+
+
+ class RPCWeightUpdater(WeightUpdaterBase):
+     """A remote weight updater for synchronizing policy weights across remote workers using RPC.
+
+     The `RPCWeightUpdater` class provides a mechanism for updating the weights of a policy
+     across remote inference workers using RPC. It is designed to work with the :class:`~torchrl.collectors.distributed.RPCDataCollector`
+     to ensure that each worker receives the latest policy weights.
+     This class is typically used in distributed data collection scenarios where remote workers
+     are managed via RPC and need to be kept in sync with the central policy weights.
+
+     Args:
+         collector_infos: Information about the collectors, used for RPC communication.
+         collector_class: The class of the collectors being used.
+         collector_rrefs: Remote references to the collectors.
+         policy_weights (TensorDictBase): The current weights of the policy that need to be distributed
+             to the workers.
+         num_workers (int): The number of remote workers that will receive the updated policy weights.
+
+     Methods:
+         update_weights: Updates the weights on specified or all remote workers using RPC.
+         all_worker_ids: Returns a list of all worker identifiers (not implemented in this class).
+         _sync_weights_with_worker: Synchronizes the server weights with a specific worker (not implemented).
+         _get_server_weights: Retrieves the latest weights from the server (not implemented).
+         _maybe_map_weights: Optionally maps server weights before distribution (not implemented).
+
+     .. note::
+         This class assumes that the server weights can be directly applied to the remote workers
+         without any additional processing. If your use case requires more complex weight mapping or
+         synchronization logic, consider extending `WeightUpdaterBase` with a custom implementation.
+
+     .. seealso:: :class:`~torchrl.collectors.WeightUpdaterBase` and
+         :class:`~torchrl.collectors.distributed.RPCDataCollector`.
+
+     """
+
+     _VERBOSE = VERBOSE  # for debugging
+
+     def __init__(
+         self,
+         collector_infos,
+         collector_class,
+         collector_rrefs,
+         policy_weights: TensorDictBase,
+         num_workers: int,
+     ):
+         super().__init__()
+         self.collector_infos = collector_infos
+         self.collector_class = collector_class
+         self.collector_rrefs = collector_rrefs
+         self.policy_weights = policy_weights
+         self.num_workers = num_workers
+
+     def _sync_weights_with_worker(
+         self, worker_id: int | torch.device, server_weights: TensorDictBase
+     ) -> TensorDictBase:
+         raise NotImplementedError
+
+     def _get_server_weights(self) -> TensorDictBase:
+         raise NotImplementedError
+
+     def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
+         raise NotImplementedError
+
+     def all_worker_ids(self) -> list[int] | list[torch.device]:
+         raise NotImplementedError
+
+     def push_weights(
+         self,
+         weights: TensorDictBase | None = None,
+         worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
+         **kwargs,
+     ):
+         workers = worker_ids
+         if isinstance(workers, int):
+             workers = [workers]
+         if workers is None:
+             workers = list(range(self.num_workers))
+         else:
+             workers = list(workers)
+         futures = []
+         weights = self.policy_weights if weights is None else weights
+         for i in workers:
+             torchrl_logger.debug(f"calling update on worker {i}")
+             # Use rref.rpc_async() to properly call the instance method
+             futures.append(
+                 self.collector_rrefs[i].rpc_async().update_policy_weights_(weights)
+             )
+         if kwargs.get("wait", True):
+             # `futures` is parallel to `workers`: iterate over the pairs rather
+             # than indexing `futures` with the (possibly sparse) worker ids
+             for i, future in zip(workers, futures):
+                 torchrl_logger.debug(f"waiting for worker {i}")
+                 future.wait()
+                 torchrl_logger.debug("got it!")
+
+
+ class RPCDataCollector(RPCCollector, metaclass=_LegacyCollectorMeta):
+     """Deprecated version of :class:`~torchrl.collectors.distributed.RPCCollector`."""
+
+     ...