torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
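
The single hunk that follows appears to be the new torchrl/collectors/_multi_base.py module (the only file above adding 1,805 lines). It defines the MultiCollector base class, whose sync keyword dispatches to MultiSyncCollector or MultiAsyncCollector, and it documents how a policy can be combined with a policy_factory when weight_sync_schemes are used. As a rough orientation before the source, here is a minimal, unverified sketch of that combination inferred solely from the docstring below; the toy environment and linear policy are placeholders, not a working CartPole agent.

# Illustrative sketch only, inferred from the MultiCollector docstring below;
# not taken from the package and not verified against the released wheel.
import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import MultiCollector
from torchrl.envs import GymEnv
from torchrl.weight_update import SharedMemWeightSyncScheme

def make_env():
    return GymEnv("CartPole-v1")

def make_policy():
    # Each worker calls this factory to build its own policy instance.
    return TensorDictModule(
        nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"]
    )

# With weight_sync_schemes, `policy` is used only to extract weights on the main
# process, while `policy_factory` is what actually reaches the workers.
collector = MultiCollector(
    create_env_fn=[make_env] * 4,
    policy=make_policy(),
    policy_factory=make_policy,
    weight_sync_schemes={"policy": SharedMemWeightSyncScheme()},
    frames_per_batch=[250] * 4,  # one value per worker; the batch totals 1000 frames
    total_frames=100_000,
    sync=True,  # dispatched by the metaclass to MultiSyncCollector
)
for data in collector:
    ...  # each `data` is a TensorDict of collected transitions
collector.shutdown()
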
@@ -0,0 +1,1805 @@
1
+ from __future__ import annotations
2
+
3
+ import _pickle
4
+ import abc
5
+
6
+ import contextlib
7
+ import sys
8
+ import warnings
9
+ from collections import OrderedDict
10
+ from collections.abc import Callable, Mapping, Sequence
11
+ from typing import Any
12
+
13
+ import numpy as np
14
+ import torch
15
+ from tensordict import TensorDict, TensorDictBase
16
+ from tensordict.nn import CudaGraphModule, TensorDictModule
17
+ from tensordict.utils import _zip_strict
18
+ from torch import multiprocessing as mp, nn
19
+ from torchrl import logger as torchrl_logger
20
+ from torchrl._utils import (
21
+ _check_for_faulty_process,
22
+ _get_mp_ctx,
23
+ _make_process_no_warn_cls,
24
+ _mp_sharing_strategy_for_spawn,
25
+ _set_mp_start_method_if_unset,
26
+ RL_WARNINGS,
27
+ )
28
+ from torchrl.collectors._base import BaseCollector
29
+ from torchrl.collectors._constants import (
30
+ _InterruptorManager,
31
+ _is_osx,
32
+ DEFAULT_EXPLORATION_TYPE,
33
+ ExplorationType,
34
+ INSTANTIATE_TIMEOUT,
35
+ )
36
+ from torchrl.collectors._runner import _main_async_collector
37
+ from torchrl.collectors._single import Collector
38
+ from torchrl.collectors.utils import _make_meta_policy_cm, _TrajectoryPool
39
+ from torchrl.collectors.weight_update import WeightUpdaterBase
40
+ from torchrl.data import ReplayBuffer
41
+ from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
42
+ from torchrl.envs import EnvBase, EnvCreator
43
+ from torchrl.envs.llm.transforms import PolicyVersion
44
+ from torchrl.weight_update import (
45
+ MultiProcessWeightSyncScheme,
46
+ SharedMemWeightSyncScheme,
47
+ WeightSyncScheme,
48
+ )
49
+ from torchrl.weight_update.utils import _resolve_model
50
+
51
+
52
+ class _MultiCollectorMeta(abc.ABCMeta):
53
+ """Metaclass for MultiCollector that dispatches based on sync parameter.
54
+
55
+ When MultiCollector is instantiated with sync=True or sync=False, the metaclass
56
+ intercepts the call and returns the appropriate subclass instance:
57
+ - sync=True: returns MultiSyncCollector
58
+ - sync=False: returns MultiAsyncCollector
59
+ """
60
+
61
+ def __call__(cls, *args, sync: bool | None = None, **kwargs):
62
+ # Only dispatch if we're instantiating MultiCollector directly (not a subclass)
63
+ # and sync is explicitly provided
64
+ if cls.__name__ == "MultiCollector" and sync is not None:
65
+ if sync:
66
+ from torchrl.collectors._multi_sync import MultiSyncCollector
67
+
68
+ return MultiSyncCollector(*args, **kwargs)
69
+ else:
70
+ from torchrl.collectors._multi_async import MultiAsyncCollector
71
+
72
+ return MultiAsyncCollector(*args, **kwargs)
73
+ return super().__call__(*args, **kwargs)
74
+
75
+
76
+ class MultiCollector(BaseCollector, metaclass=_MultiCollectorMeta):
77
+ """Runs a given number of DataCollectors on separate processes.
78
+
79
+ Args:
80
+ create_env_fn (List[Callable]): list of Callables, each returning an
81
+ instance of :class:`~torchrl.envs.EnvBase`.
82
+ policy (Callable): Policy to be executed in the environment.
83
+ Must accept :class:`tensordict.tensordict.TensorDictBase` object as input.
84
+ If ``None`` is provided (default), the policy used will be a
85
+ :class:`~torchrl.collectors.RandomPolicy` instance with the environment
86
+ ``action_spec``.
87
+ Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`.
88
+ This is the recommended usage of the collector.
89
+ Other callables are accepted too:
90
+ If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module`
91
+ instance), it will be wrapped in a `nn.Module` first.
92
+ Then, the collector will try to assess if these
93
+ modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
94
+
95
+ - If the policy forward signature matches any of ``forward(self, tensordict)``,
96
+ ``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
97
+ any typing with a single argument typed as a subclass of ``TensorDictBase``)
98
+ then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
99
+
100
+ - In all other cases, the collector will attempt to wrap it as follows:
101
+ ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
102
+
103
+ .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
104
+ pickled directly), the ``policy_factory`` should be used instead.
105
+
106
+ .. note:: When using ``weight_sync_schemes``, both ``policy`` and ``policy_factory`` can be provided together.
107
+ In this case, the ``policy`` is used ONLY for weight extraction (via ``TensorDict.from_module()``) to
108
+ set up weight synchronization, but it is NOT sent to workers and its weights are NOT depopulated.
109
+ The ``policy_factory`` is what actually gets passed to workers to create their local policy instances.
110
+ This is useful when the policy is hard to serialize but you have a copy on the main node for
111
+ weight synchronization purposes.
112
+
113
+ Keyword Args:
114
+ sync (bool, optional): if ``True``, the collector will run in sync mode (:class:`~torchrl.collectors.MultiSyncCollector`). If
115
+ `False`, the collector will run in async mode (:class:`~torchrl.collectors.MultiAsyncCollector`).
116
+ policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable
117
+ (or list of callables) that returns a policy instance.
118
+
119
+ When not using ``weight_sync_schemes``, this is mutually exclusive with the ``policy`` argument.
120
+
121
+ When using ``weight_sync_schemes``, both ``policy`` and ``policy_factory`` can be provided:
122
+ the ``policy`` is used for weight extraction only, while ``policy_factory`` creates policies on workers.
123
+
124
+ .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
125
+
126
+ .. warning:: `policy_factory` is currently not compatible with multiprocessed data
127
+ collectors.
128
+
129
+ num_workers (int, optional): number of workers to use. If `create_env_fn` is a list, this will be ignored.
130
+ Defaults to `None` (workers determined by the `create_env_fn` length).
131
+ frames_per_batch (int, Sequence[int]): A keyword-only argument representing the
132
+ total number of elements in a batch. If a sequence is provided, represents the number of elements in a
133
+ batch per worker. Total number of elements in a batch is then the sum over the sequence.
134
+ total_frames (int, optional): A keyword-only argument representing the
135
+ total number of frames returned by the collector
136
+ during its lifespan. If the ``total_frames`` is not divisible by
137
+ ``frames_per_batch``, an exception is raised.
138
+ Endless collectors can be created by passing ``total_frames=-1``.
139
+ Defaults to ``-1`` (never ending collector).
140
+ device (int, str or torch.device, optional): The generic device of the
141
+ collector. The ``device`` args fills any non-specified device: if
142
+ ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
143
+ ``env_device`` is not specified, its value will be set to ``device``.
144
+ Defaults to ``None`` (No default device).
145
+ Supports a list of devices if one wishes to indicate a different device
146
+ for each worker. The list must be as long as the number of workers.
147
+ storing_device (int, str or torch.device, optional): The device on which
148
+ the output :class:`~tensordict.TensorDict` will be stored.
149
+ If ``device`` is passed and ``storing_device`` is ``None``, it will
150
+ default to the value indicated by ``device``.
151
+ For long trajectories, it may be necessary to store the data on a different
152
+ device than the one where the policy and env are executed.
153
+ Defaults to ``None`` (the output tensordict isn't on a specific device,
154
+ leaf tensors sit on the device where they were created).
155
+ Supports a list of devices if one wishes to indicate a different device
156
+ for each worker. The list must be as long as the number of workers.
157
+ env_device (int, str or torch.device, optional): The device on which
158
+ the environment should be cast (or executed if that functionality is
159
+ supported). If not specified and the env has a non-``None`` device,
160
+ ``env_device`` will default to that value. If ``device`` is passed
161
+ and ``env_device=None``, it will default to ``device``. If the value
162
+ specified for ``env_device`` differs from ``policy_device``
163
+ and one of them is not ``None``, the data will be cast to ``env_device``
164
+ before being passed to the env (i.e., passing different devices to
165
+ policy and env is supported). Defaults to ``None``.
166
+ Supports a list of devices if one wishes to indicate a different device
167
+ for each worker. The list must be as long as the number of workers.
168
+ policy_device (int, str or torch.device, optional): The device on which
169
+ the policy should be cast.
170
+ If ``device`` is passed and ``policy_device=None``, it will default
171
+ to ``device``. If the value specified for ``policy_device``
172
+ differs from ``env_device`` and one of them is not ``None``,
173
+ the data will be cast to ``policy_device`` before being passed to
174
+ the policy (i.e., passing different devices to policy and env is
175
+ supported). Defaults to ``None``.
176
+ Supports a list of devices if one wishes to indicate a different device
177
+ for each worker. The list must be as long as the number of workers.
178
+ create_env_kwargs (dict, optional): A dictionary with the
179
+ keyword arguments used to create an environment. If a list is
180
+ provided, each of its elements will be assigned to a sub-collector.
181
+ collector_class (Python class or constructor): a collector class to be remotely instantiated. Can be
182
+ :class:`~torchrl.collectors.Collector`,
183
+ :class:`~torchrl.collectors.MultiSyncCollector`,
184
+ :class:`~torchrl.collectors.MultiAsyncCollector`
185
+ or a derived class of these.
186
+ Defaults to :class:`~torchrl.collectors.Collector`.
187
+ max_frames_per_traj (int, optional): Maximum steps per trajectory.
188
+ Note that a trajectory can span across multiple batches (unless
189
+ ``reset_at_each_iter`` is set to ``True``, see below).
190
+ Once a trajectory reaches ``max_frames_per_traj`` steps, the environment is reset.
191
+ If the environment wraps multiple environments together, the number
192
+ of steps is tracked for each environment independently. Negative
193
+ values are allowed, in which case this argument is ignored.
194
+ Defaults to ``None`` (i.e. no maximum number of steps).
195
+ init_random_frames (int, optional): Number of frames for which the
196
+ policy is ignored before it is called. This feature is mainly
197
+ intended to be used in offline/model-based settings, where a
198
+ batch of random trajectories can be used to initialize training.
199
+ If provided, it will be rounded up to the closest multiple of frames_per_batch.
200
+ Defaults to ``None`` (i.e. no random frames).
201
+ reset_at_each_iter (bool, optional): Whether environments should be reset
202
+ at the beginning of a batch collection.
203
+ Defaults to ``False``.
204
+ postproc (Callable, optional): A post-processing transform, such as
205
+ a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
206
+ instance.
207
+ Defaults to ``None``.
208
+ split_trajs (bool, optional): Boolean indicating whether the resulting
209
+ TensorDict should be split according to the trajectories.
210
+ See :func:`~torchrl.collectors.utils.split_trajectories` for more
211
+ information.
212
+ Defaults to ``False``.
213
+ exploration_type (ExplorationType, optional): interaction mode to be used when
214
+ collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
215
+ ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
216
+ or ``torchrl.envs.utils.ExplorationType.MEAN``.
217
+ reset_when_done (bool, optional): if ``True`` (default), an environment
218
+ that returns a ``True`` value in its ``"done"`` or ``"truncated"``
219
+ entry will be reset at the corresponding indices.
220
+ update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weights_()`
221
+ will be called before (sync) or after (async) each data collection.
222
+ Defaults to ``False``.
223
+ preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
224
+ that will be allowed to finish collecting their rollout before the rest are forced to end early.
225
+ num_threads (int, optional): number of threads for this process.
226
+ Defaults to the number of workers.
227
+ num_sub_threads (int, optional): number of threads of the subprocesses.
228
+ Should be equal to one plus the number of processes launched within
229
+ each subprocess (or one if a single process is launched).
230
+ Defaults to 1 for safety: if none is indicated, launching multiple
231
+ workers may overload the CPU and harm performance.
232
+ cat_results (str, int or None): (:class:`~torchrl.collectors.MultiSyncCollector` exclusively).
233
+ If ``"stack"``, the data collected from the workers will be stacked along the
234
+ first dimension. This is the preferred behavior as it is the most compatible
235
+ with the rest of the library.
236
+ If ``0``, results will be concatenated along the first dimension
237
+ of the outputs, which can be the batched dimension if the environments are
238
+ batched or the time dimension if not.
239
+ A ``cat_results`` value of ``-1`` will always concatenate results along the
240
+ time dimension. This should be preferred over the default. Intermediate values
241
+ are also accepted.
242
+ Defaults to ``"stack"``.
243
+
244
+ .. note:: Since v0.5, this argument has defaulted to ``"stack"`` for better
245
+ interoperability with the rest of the library.
246
+
247
+ set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding
248
+ ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of
249
+ a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
250
+ Truncated keys can be set through ``env.add_truncated_keys``.
251
+ Defaults to ``False``.
252
+ use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data.
253
+ This isn't compatible with environments with dynamic specs. Defaults to ``True``
254
+ for envs without dynamic specs, ``False`` for others.
255
+ replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
256
+ but populate the buffer instead. Defaults to ``None``.
257
+ extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not
258
+ with single steps. Defaults to `True` for multiprocessed data collectors.
259
+ local_init_rb (bool, optional): if ``False``, the collector will use fake data to initialize
260
+ the replay buffer in the main process (legacy behavior). If ``True``, the storage-level
261
+ coordination will handle initialization with real data from worker processes.
262
+ Defaults to ``None``, which maintains backward compatibility but shows a deprecation warning.
263
+ This parameter is deprecated and will be removed in v0.12.
264
+ trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be
265
+ assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules
266
+ and ``False`` otherwise.
267
+ compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled
268
+ using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it
269
+ will be used to compile the policy.
270
+ cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
271
+ in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
272
+ If a dictionary of kwargs is passed, it will be used to wrap the policy.
273
+ no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed.
274
+ For environments running directly on CUDA (`IsaacLab <https://github.com/isaac-sim/IsaacLab/>`_
275
+ or `ManiSkills <https://github.com/haosulab/ManiSkill/>`_), CUDA synchronization may cause unexpected
276
+ crashes.
277
+ Defaults to ``False``.
278
+ weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase`
279
+ or its subclass, responsible for updating the policy weights on remote inference workers.
280
+ If not provided, a :class:`~torchrl.collectors.MultiProcessedWeightUpdater` will be used by default,
281
+ which handles weight synchronization across multiple processes.
282
+ Consider using a constructor if the updater needs to be serialized.
283
+ weight_sync_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for
284
+ SENDING weights to worker sub-collectors. Keys are model identifiers (e.g., "policy")
285
+ and values are WeightSyncScheme instances configured to send weights to child processes.
286
+ If not provided, a :class:`~torchrl.collectors.MultiProcessWeightSyncScheme` will be used by default.
287
+ This is for propagating weights DOWN the hierarchy (parent -> children).
288
+ weight_recv_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for
289
+ RECEIVING weights from parent collectors. Keys are model identifiers (e.g., "policy")
290
+ and values are WeightSyncScheme instances configured to receive weights.
291
+ This enables cascading in hierarchies like: RPCDataCollector -> MultiSyncCollector -> Collector.
292
+ Received weights are automatically propagated to sub-collectors if matching model_ids exist.
293
+ Defaults to ``None``.
294
+ track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
295
+ This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
296
+ Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
297
+ the policy version.
298
+ Defaults to `False`.
299
+ worker_idx (int, optional): the index of the worker.
300
+
301
+ Examples:
302
+ >>> from torchrl.collectors import MultiCollector
303
+ >>> from torchrl.envs import GymEnv
304
+ >>>
305
+ >>> def make_env():
306
+ ... return GymEnv("CartPole-v1")
307
+ >>>
308
+ >>> # Synchronous collection (for on-policy algorithms like PPO)
309
+ >>> sync_collector = MultiCollector(
310
+ ... create_env_fn=[make_env] * 4, # 4 parallel workers
311
+ ... policy=my_policy,
312
+ ... frames_per_batch=1000,
313
+ ... total_frames=100000,
314
+ ... sync=True, # All workers complete before batch is delivered
315
+ ... )
316
+ >>>
317
+ >>> # Asynchronous collection (for off-policy algorithms like SAC)
318
+ >>> async_collector = MultiCollector(
319
+ ... create_env_fn=[make_env] * 4,
320
+ ... policy=my_policy,
321
+ ... frames_per_batch=1000,
322
+ ... total_frames=100000,
323
+ ... sync=False, # First-come-first-serve delivery
324
+ ... )
325
+ >>>
326
+ >>> # Iterate over collected data
327
+ >>> for data in sync_collector:
328
+ ... # data is a TensorDict with collected transitions
329
+ ... pass
330
+ >>> sync_collector.shutdown()
331
+
332
+ """
333
+
334
+ def __init__(
335
+ self,
336
+ create_env_fn: Sequence[Callable[[], EnvBase]],
337
+ policy: None
338
+ | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None,
339
+ *,
340
+ num_workers: int | None = None,
341
+ policy_factory: Callable[[], Callable]
342
+ | list[Callable[[], Callable]]
343
+ | None = None,
344
+ frames_per_batch: int | Sequence[int],
345
+ total_frames: int | None = -1,
346
+ device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
347
+ storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
348
+ env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
349
+ policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
350
+ create_env_kwargs: Sequence[dict] | None = None,
351
+ collector_class: type | Callable[[], BaseCollector] | None = None,
352
+ max_frames_per_traj: int | None = None,
353
+ init_random_frames: int | None = None,
354
+ reset_at_each_iter: bool = False,
355
+ postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
356
+ split_trajs: bool | None = None,
357
+ exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
358
+ reset_when_done: bool = True,
359
+ update_at_each_batch: bool = False,
360
+ preemptive_threshold: float | None = None,
361
+ num_threads: int | None = None,
362
+ num_sub_threads: int = 1,
363
+ cat_results: str | int | None = None,
364
+ set_truncated: bool = False,
365
+ use_buffers: bool | None = None,
366
+ replay_buffer: ReplayBuffer | None = None,
367
+ extend_buffer: bool = True,
368
+ local_init_rb: bool | None = None,
369
+ trust_policy: bool | None = None,
370
+ compile_policy: bool | dict[str, Any] | None = None,
371
+ cudagraph_policy: bool | dict[str, Any] | None = None,
372
+ no_cuda_sync: bool = False,
373
+ weight_updater: WeightUpdaterBase
374
+ | Callable[[], WeightUpdaterBase]
375
+ | None = None,
376
+ weight_sync_schemes: dict[str, WeightSyncScheme] | None = None,
377
+ weight_recv_schemes: dict[str, WeightSyncScheme] | None = None,
378
+ track_policy_version: bool = False,
379
+ worker_idx: int | None = None,
380
+ ):
381
+ self.closed = True
382
+ self.worker_idx = worker_idx
383
+
384
+ # Set up workers and environment functions
385
+ create_env_fn, total_frames_per_batch = self._setup_workers_and_env_fns(
386
+ create_env_fn, num_workers, frames_per_batch
387
+ )
388
+
389
+ # Set up basic configuration
390
+ self.set_truncated = set_truncated
391
+ self.num_sub_threads = num_sub_threads
392
+ self.num_threads = num_threads
393
+ self.create_env_fn = create_env_fn
394
+ self._read_compile_kwargs(compile_policy, cudagraph_policy)
395
+
396
+ # Set up environment kwargs
397
+ self.create_env_kwargs = self._setup_env_kwargs(create_env_kwargs)
398
+
399
+ # Set up devices
400
+ storing_devices, policy_devices, env_devices = self._get_devices(
401
+ storing_device=storing_device,
402
+ env_device=env_device,
403
+ policy_device=policy_device,
404
+ device=device,
405
+ )
406
+ self.storing_device = storing_devices
407
+ self.policy_device = policy_devices
408
+ self.env_device = env_devices
409
+ self.collector_class = collector_class
410
+ del storing_device, env_device, policy_device, device
411
+ self.no_cuda_sync = no_cuda_sync
412
+
413
+ # Set up replay buffer
414
+ self._use_buffers = use_buffers
415
+ self.replay_buffer = replay_buffer
416
+ self._setup_multi_replay_buffer(local_init_rb, replay_buffer, extend_buffer)
417
+
418
+ # Set up policy and weights
419
+ if trust_policy is None:
420
+ trust_policy = policy is not None and isinstance(policy, CudaGraphModule)
421
+ self.trust_policy = trust_policy
422
+
423
+ policy_factory = self._setup_policy_factory(policy_factory)
424
+
425
+ # Set up weight synchronization
426
+ if weight_sync_schemes is None and weight_updater is None:
427
+ weight_sync_schemes = {}
428
+ elif weight_sync_schemes is not None and weight_updater is not None:
429
+ raise TypeError(
430
+ "Cannot specify both weight_sync_schemes and weight_updater."
431
+ )
432
+ if (
433
+ weight_sync_schemes is not None
434
+ and not weight_sync_schemes
435
+ and weight_updater is None
436
+ and (isinstance(policy, nn.Module) or any(policy_factory))
437
+ ):
438
+ # Set up a default local shared-memory sync scheme for the policy.
439
+ # This is used to propagate weights from the orchestrator policy
440
+ # (possibly combined with a policy_factory) down to worker policies.
441
+ weight_sync_schemes["policy"] = SharedMemWeightSyncScheme()
442
+
443
+ self._setup_multi_weight_sync(weight_updater, weight_sync_schemes)
444
+
445
+ # Store policy and policy_factory - temporarily set to make them visible to the receiver
446
+ self.policy = policy
447
+ self.policy_factory = policy_factory
448
+
449
+ # Set up weight receivers if provided
450
+ if weight_recv_schemes is not None:
451
+ self.register_scheme_receiver(weight_recv_schemes)
452
+
453
+ self._setup_multi_policy_and_weights(
454
+ self.policy, self.policy_factory, weight_updater, weight_sync_schemes
455
+ )
456
+
457
+ # Set up policy version tracking
458
+ self._setup_multi_policy_version_tracking(track_policy_version)
459
+
460
+ # # Set up fallback policy for weight extraction
461
+ # self._setup_fallback_policy(policy, policy_factory, weight_sync_schemes)
462
+
463
+ # Set up total frames and other parameters
464
+ self._setup_multi_total_frames(
465
+ total_frames, total_frames_per_batch, frames_per_batch
466
+ )
467
+ self.reset_at_each_iter = reset_at_each_iter
468
+ self.postprocs = postproc
469
+ self.max_frames_per_traj = (
470
+ int(max_frames_per_traj) if max_frames_per_traj is not None else 0
471
+ )
472
+
473
+ # Set up split trajectories
474
+ self.requested_frames_per_batch = total_frames_per_batch
475
+ self.reset_when_done = reset_when_done
476
+ self._setup_split_trajs(split_trajs, reset_when_done)
477
+
478
+ # Set up other parameters
479
+ self.init_random_frames = (
480
+ int(init_random_frames) if init_random_frames is not None else 0
481
+ )
482
+ self.update_at_each_batch = update_at_each_batch
483
+ self.exploration_type = exploration_type
484
+ self.frames_per_worker = np.inf
485
+
486
+ # Set up preemptive threshold
487
+ self._setup_preemptive_threshold(preemptive_threshold)
488
+
489
+ # Run worker processes
490
+ try:
491
+ self._run_processes()
492
+ except Exception as e:
493
+ self.shutdown(raise_on_error=False)
494
+ raise e
495
+
496
+ # Set up frame tracking and other options
497
+ self._exclude_private_keys = True
498
+ self._frames = 0
499
+ self._iter = -1
500
+
501
+ # Validate cat_results
502
+ self._validate_cat_results(cat_results)
503
+
504
+ def _setup_workers_and_env_fns(
505
+ self,
506
+ create_env_fn: Sequence[Callable] | Callable,
507
+ num_workers: int | None,
508
+ frames_per_batch: int | Sequence[int],
509
+ ) -> tuple[list[Callable], int]:
510
+ """Set up workers and environment functions."""
511
+ if isinstance(create_env_fn, Sequence):
512
+ self.num_workers = len(create_env_fn)
513
+ else:
514
+ self.num_workers = num_workers
515
+ create_env_fn = [create_env_fn] * self.num_workers
516
+
517
+ if (
518
+ isinstance(frames_per_batch, Sequence)
519
+ and len(frames_per_batch) != self.num_workers
520
+ ):
521
+ raise ValueError(
522
+ "If `frames_per_batch` is provided as a sequence, it should contain exactly one value per worker."
523
+ f"Got {len(frames_per_batch)} values for {self.num_workers} workers."
524
+ )
525
+
526
+ self._frames_per_batch = frames_per_batch
527
+ total_frames_per_batch = (
528
+ sum(frames_per_batch)
529
+ if isinstance(frames_per_batch, Sequence)
530
+ else frames_per_batch
531
+ )
532
+
533
+ return create_env_fn, total_frames_per_batch
534
+
535
+ def _setup_env_kwargs(
536
+ self, create_env_kwargs: Sequence[dict] | dict | None
537
+ ) -> list[dict]:
538
+ """Set up environment kwargs for each worker."""
539
+ if isinstance(create_env_kwargs, Mapping):
540
+ create_env_kwargs = [create_env_kwargs] * self.num_workers
541
+ elif create_env_kwargs is None:
542
+ create_env_kwargs = [{}] * self.num_workers
543
+ elif isinstance(create_env_kwargs, (tuple, list)):
544
+ create_env_kwargs = list(create_env_kwargs)
545
+ if len(create_env_kwargs) != self.num_workers:
546
+ raise ValueError(
547
+ f"len(create_env_kwargs) must be equal to num_workers, got {len(create_env_kwargs)=} and {self.num_workers=}"
548
+ )
549
+ return create_env_kwargs
550
+
551
+ def _setup_multi_replay_buffer(
552
+ self,
553
+ local_init_rb: bool | None,
554
+ replay_buffer: ReplayBuffer | None,
555
+ extend_buffer: bool,
556
+ ) -> None:
557
+ """Set up replay buffer for multi-process collector."""
558
+ # Handle local_init_rb deprecation
559
+ if local_init_rb is None:
560
+ local_init_rb = False
561
+ if replay_buffer is not None and not local_init_rb:
562
+ warnings.warn(
563
+ "local_init_rb=False is deprecated and will be removed in v0.12. "
564
+ "The new storage-level initialization provides better performance.",
565
+ FutureWarning,
566
+ )
567
+ self.local_init_rb = local_init_rb
568
+
569
+ self._check_replay_buffer_init()
570
+
571
+ self.extend_buffer = extend_buffer
572
+
573
+ if (
574
+ replay_buffer is not None
575
+ and hasattr(replay_buffer, "shared")
576
+ and not replay_buffer.shared
577
+ ):
578
+ torchrl_logger.warning("Replay buffer is not shared. Sharing it.")
579
+ replay_buffer.share()
580
+
581
+ def _setup_policy_factory(
582
+ self, policy_factory: Callable | list[Callable] | None
583
+ ) -> list[Callable | None]:
584
+ """Set up policy factory for each worker."""
585
+ if not isinstance(policy_factory, Sequence):
586
+ policy_factory = [policy_factory] * self.num_workers
587
+ return policy_factory
588
+
589
+ def _setup_multi_policy_and_weights(
590
+ self,
591
+ policy: TensorDictModule | Callable | None,
592
+ policy_factory: list[Callable | None],
593
+ weight_updater: WeightUpdaterBase | Callable | None,
594
+ weight_sync_schemes: dict[str, WeightSyncScheme] | None,
595
+ ) -> None:
596
+ """Set up policy for multi-process collector.
597
+
598
+ With weight sync schemes: validates and stores policy without weight extraction.
599
+ With weight updater: extracts weights and creates stateful policies.
600
+
601
+ When both policy and policy_factory are provided (with weight_sync_schemes):
602
+ - The policy is used ONLY for weight extraction via get_model()
603
+ - The policy is NOT depopulated of its weights
604
+ - The policy is NOT sent to workers
605
+ - The policy_factory is used to create policies on workers
606
+ """
607
+ if any(policy_factory) and policy is not None:
608
+ if weight_sync_schemes is None:
609
+ raise TypeError(
610
+ "policy_factory and policy are mutually exclusive when not using weight_sync_schemes. "
611
+ "When using weight_sync_schemes, policy can be provided alongside policy_factory "
612
+ "for weight extraction purposes only (the policy will not be sent to workers)."
613
+ )
614
+ # Store policy as fallback for weight extraction only
615
+ # The policy keeps its weights and is NOT sent to workers
616
+ self._fallback_policy = policy
617
+
618
+ if weight_sync_schemes is not None:
619
+ weight_sync_policy = weight_sync_schemes.get("policy")
620
+ if weight_sync_policy is None:
621
+ return
622
+ # # If we only have a policy_factory (no policy instance), the scheme must
623
+ # # be pre-initialized on the sender, since there is no policy on the
624
+ # # collector to extract weights from.
625
+ # if any(p is not None for p in policy_factory) and policy is None:
626
+ # if not weight_sync_policy.initialized_on_sender:
627
+ # raise RuntimeError(
628
+ # "the weight sync scheme must be initialized on sender ahead of time "
629
+ # "when passing a policy_factory without a policy instance on the collector. "
630
+ # f"Got {policy_factory=}"
631
+ # )
632
+ # # When a policy instance is provided alongside a policy_factory, the scheme
633
+ # # can rely on the collector context (and its policy) to extract weights.
634
+ # # Weight sync scheme initialization then happens in _run_processes where
635
+ # # pipes and workers are available.
636
+ else:
637
+ # Using legacy weight updater - extract weights and create stateful policies
638
+ self._setup_multi_policy_and_weights_legacy(
639
+ policy, policy_factory, weight_updater, weight_sync_schemes
640
+ )
641
+
642
+ def _setup_multi_policy_and_weights_legacy(
643
+ self,
644
+ policy: TensorDictModule | Callable | None,
645
+ policy_factory: list[Callable | None],
646
+ weight_updater: WeightUpdaterBase | Callable | None,
647
+ weight_sync_schemes: dict[str, WeightSyncScheme] | None,
648
+ ) -> None:
649
+ """Set up policy and extract weights for each device.
650
+
651
+ Creates stateful policies with weights extracted and placed in shared memory.
652
+ Used with weight updater for in-place weight replacement.
653
+ """
654
+ self._policy_weights_dict = {}
655
+ self._fallback_policy = None # Policy to use for weight extraction fallback
656
+
657
+ if not any(policy_factory):
658
+ for policy_device, env_maker, env_maker_kwargs in _zip_strict(
659
+ self.policy_device, self.create_env_fn, self.create_env_kwargs
660
+ ):
661
+ policy_new_device, get_weights_fn = self._get_policy_and_device(
662
+ policy=policy,
663
+ policy_device=policy_device,
664
+ env_maker=env_maker,
665
+ env_maker_kwargs=env_maker_kwargs,
666
+ )
667
+ if type(policy_new_device) is not type(policy):
668
+ policy = policy_new_device
669
+ weights = (
670
+ TensorDict.from_module(policy_new_device)
671
+ if isinstance(policy_new_device, nn.Module)
672
+ else TensorDict()
673
+ )
674
+ # For multi-process collectors, ensure weights are in shared memory
675
+ if policy_device and policy_device.type == "cpu":
676
+ weights = weights.share_memory_()
677
+ self._policy_weights_dict[policy_device] = weights
678
+ # Store the first policy instance for fallback weight extraction
679
+ if self._fallback_policy is None:
680
+ self._fallback_policy = policy_new_device
681
+ self._get_weights_fn = get_weights_fn
682
+ if weight_updater is None:
683
+ # For multiprocessed collectors, use MultiProcessWeightSyncScheme by default
684
+ if weight_sync_schemes is None:
685
+ weight_sync_schemes = {"policy": MultiProcessWeightSyncScheme()}
686
+ self._weight_sync_schemes = weight_sync_schemes
687
+ elif weight_updater is None:
688
+ warnings.warn(
689
+ "weight_updater is None, but policy_factory is provided. This means that the server will "
690
+ "not know how to send the weights to the workers. If the workers can handle their weight synchronization "
691
+ "on their own (via some specialized worker type / constructor) this may well work, but make sure "
692
+ "your weight synchronization strategy is properly set. To suppress this warning, you can use "
693
+ "RemoteModuleWeightUpdater() which enforces explicit weight passing when calling update_policy_weights_(weights). "
694
+ "This will work whenever your inference and training policies are nn.Module instances with similar structures."
695
+ )
696
+
697
+ def _setup_multi_weight_sync(
698
+ self,
699
+ weight_updater: WeightUpdaterBase | Callable | None,
700
+ weight_sync_schemes: dict[str, WeightSyncScheme] | None,
701
+ ) -> None:
702
+ """Set up weight synchronization for multi-process collector."""
703
+ if weight_sync_schemes is not None:
704
+ # Use weight sync schemes for weight distribution
705
+ self._weight_sync_schemes = weight_sync_schemes
706
+ # Senders will be created in _run_processes
707
+ self.weight_updater = None
708
+ else:
709
+ # Use weight updater for weight distribution
710
+ self.weight_updater = weight_updater
711
+ self._weight_sync_schemes = None
712
+
713
+ def _setup_multi_policy_version_tracking(
714
+ self, track_policy_version: bool | PolicyVersion
715
+ ) -> None:
716
+ """Set up policy version tracking for multi-process collector."""
717
+ self.policy_version_tracker = track_policy_version
718
+ if PolicyVersion is not None:
719
+ if isinstance(track_policy_version, bool) and track_policy_version:
720
+ self.policy_version_tracker = PolicyVersion()
721
+ elif hasattr(track_policy_version, "increment_version"):
722
+ self.policy_version_tracker = track_policy_version
723
+ else:
724
+ self.policy_version_tracker = None
725
+ else:
726
+ if track_policy_version:
727
+ raise ImportError(
728
+ "PolicyVersion is not available. Please install the LLM dependencies or set track_policy_version=False."
729
+ )
730
+ self.policy_version_tracker = None
731
+
732
+ # TODO: Remove this
733
+ def _setup_fallback_policy(
734
+ self,
735
+ policy: TensorDictModule | Callable | None,
736
+ policy_factory: list[Callable | None],
737
+ weight_sync_schemes: dict[str, WeightSyncScheme] | None,
738
+ ) -> None:
739
+ """Set up fallback policy for weight extraction when using policy_factory."""
740
+ # _fallback_policy is already set in _setup_multi_policy_and_weights if a policy was provided
741
+ # If policy_factory was used, create a policy instance to use as fallback
742
+ if policy is None and any(policy_factory) and weight_sync_schemes is not None:
743
+ if not hasattr(self, "_fallback_policy") or self._fallback_policy is None:
744
+ first_factory = (
745
+ policy_factory[0]
746
+ if isinstance(policy_factory, list)
747
+ else policy_factory
748
+ )
749
+ if first_factory is not None:
750
+ # Create a policy instance for weight extraction
751
+ # This will be a reference to a policy with the same structure
752
+ # For shared memory, modifications to any policy will be visible here
753
+ self._fallback_policy = first_factory()
754
+
755
+ def _setup_multi_total_frames(
756
+ self,
757
+ total_frames: int,
758
+ total_frames_per_batch: int,
759
+ frames_per_batch: int | Sequence[int],
760
+ ) -> None:
761
+ """Validate and set total frames for multi-process collector."""
762
+ if total_frames is None or total_frames < 0:
763
+ total_frames = float("inf")
764
+ else:
765
+ remainder = total_frames % total_frames_per_batch
766
+ if remainder != 0 and RL_WARNINGS:
767
+ warnings.warn(
768
+ f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({total_frames_per_batch}). "
769
+ f"This means {total_frames_per_batch - remainder} additional frames will be collected. "
770
+ "To silence this message, set the environment variable RL_WARNINGS to False."
771
+ )
772
+ self.total_frames = (
773
+ int(total_frames) if total_frames != float("inf") else total_frames
774
+ )
775
+
776
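To make the warning above concrete, the arithmetic with hypothetical numbers:

    total_frames, total_frames_per_batch = 1000, 300   # hypothetical values
    remainder = total_frames % total_frames_per_batch  # 100 -> triggers the warning
    extra = total_frames_per_batch - remainder         # 200 additional frames will be collected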
+ def _setup_split_trajs(
777
+ self, split_trajs: bool | None, reset_when_done: bool
778
+ ) -> None:
779
+ """Set up split trajectories option."""
780
+ if split_trajs is None:
781
+ split_trajs = False
782
+ elif not reset_when_done and split_trajs:
783
+ raise RuntimeError(
784
+ "Cannot split trajectories when reset_when_done is False."
785
+ )
786
+ self.split_trajs = split_trajs
787
+
788
+ def _setup_preemptive_threshold(self, preemptive_threshold: float | None) -> None:
789
+ """Set up preemptive threshold for early stopping."""
790
+ if preemptive_threshold is not None:
791
+ if _is_osx:
792
+ raise NotImplementedError(
793
+ "Cannot use preemption on OSX due to Queue.qsize() not being implemented on this platform."
794
+ )
795
+ self.preemptive_threshold = np.clip(preemptive_threshold, 0.0, 1.0)
796
+ manager = _InterruptorManager()
797
+ manager.start()
798
+ self.interruptor = manager._Interruptor()
799
+ else:
800
+ self.preemptive_threshold = 1.0
801
+ self.interruptor = None
802
+
803
+ def _should_use_random_frames(self) -> bool:
804
+ """Determine if random frames should be used instead of the policy.
805
+
806
+ When a replay buffer is provided, uses `replay_buffer.write_count` as the
807
+ global step counter to support `.start()` mode where `_frames` isn't updated
808
+ until after collection. Otherwise, uses the internal `_frames` counter.
809
+
810
+ Returns:
811
+ bool: True if random frames should be used, False otherwise.
812
+ """
813
+ if self.init_random_frames is None or self.init_random_frames <= 0:
814
+ return False
815
+ # Use replay_buffer.write_count when available for accurate counting in .start() mode
816
+ if self.replay_buffer is not None:
817
+ return self.replay_buffer.write_count < self.init_random_frames
818
+ return self._frames < self.init_random_frames
819
+
820
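A small illustration of the gating logic above, with assumed values; when a replay buffer is attached, `write_count` drives the decision, which stays accurate in `.start()` mode:

    init_random_frames = 1_000
    write_count = 250                                 # stand-in for replay_buffer.write_count
    use_random = write_count < init_random_frames     # True: keep sampling random actions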
+ def _validate_cat_results(self, cat_results: str | int | None) -> None:
821
+ """Validate cat_results parameter."""
822
+ if cat_results is not None and (
823
+ not isinstance(cat_results, (int, str))
824
+ or (isinstance(cat_results, str) and cat_results != "stack")
825
+ ):
826
+ raise ValueError(
827
+ "cat_results must be a string ('stack') "
828
+ f"or an integer representing the cat dimension. Got {cat_results}."
829
+ )
830
+ # Lazy import to avoid circular dependency
831
+ from torchrl.collectors._multi_sync import MultiSyncCollector
832
+
833
+ if not isinstance(self, MultiSyncCollector) and cat_results not in (
834
+ "stack",
835
+ None,
836
+ ):
837
+ raise ValueError(
838
+ "cat_results can only be used with ``MultiSyncCollector``."
839
+ )
840
+ self.cat_results = cat_results
841
+
842
+ def _check_replay_buffer_init(self):
843
+ if self.replay_buffer is None:
844
+ return
845
+ is_init = hasattr(self.replay_buffer, "_storage") and getattr(
846
+ self.replay_buffer._storage, "initialized", True
847
+ )
848
+ if not is_init:
849
+ if self.local_init_rb:
850
+ # New behavior: storage handles all coordination itself
851
+ # Nothing to do here - the storage will coordinate during first write
852
+ self.replay_buffer.share()
853
+ return
854
+
855
+ # Legacy behavior: fake tensordict initialization
856
+ if isinstance(self.create_env_fn[0], EnvCreator):
857
+ fake_td = self.create_env_fn[0].meta_data.tensordict
858
+ elif isinstance(self.create_env_fn[0], EnvBase):
859
+ fake_td = self.create_env_fn[0].fake_tensordict()
860
+ else:
861
+ fake_td = self.create_env_fn[0](
862
+ **self.create_env_kwargs[0]
863
+ ).fake_tensordict()
864
+ fake_td["collector", "traj_ids"] = torch.zeros(
865
+ fake_td.shape, dtype=torch.long
866
+ )
867
+ # Use extend so that time-related transforms do not fail
868
+ self.replay_buffer.extend(fake_td.unsqueeze(-1))
869
+ self.replay_buffer.empty()
870
+
871
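A hedged sketch of the legacy priming step above, assuming an `env` and a replay buffer `rb` already exist: one fake time step is written so the storage allocates its buffers, then the placeholder data is dropped.

    import torch

    fake = env.fake_tensordict()
    fake["collector", "traj_ids"] = torch.zeros(fake.shape, dtype=torch.long)
    rb.extend(fake.unsqueeze(-1))   # extend along a time dimension so time-based transforms still work
    rb.empty()                      # discard the placeholder while keeping the allocated storage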
+ @classmethod
872
+ def _total_workers_from_env(cls, env_creators):
873
+ if isinstance(env_creators, (tuple, list)):
874
+ return sum(
875
+ cls._total_workers_from_env(env_creator) for env_creator in env_creators
876
+ )
877
+ from torchrl.envs import ParallelEnv
878
+
879
+ if isinstance(env_creators, ParallelEnv):
880
+ return env_creators.num_workers
881
+ return 1
882
+
883
+ def _get_devices(
884
+ self,
885
+ *,
886
+ storing_device: torch.device,
887
+ policy_device: torch.device,
888
+ env_device: torch.device,
889
+ device: torch.device,
890
+ ):
891
+ # convert all devices to lists
892
+ if not isinstance(storing_device, (list, tuple)):
893
+ storing_device = [
894
+ storing_device,
895
+ ] * self.num_workers
896
+ if not isinstance(policy_device, (list, tuple)):
897
+ policy_device = [
898
+ policy_device,
899
+ ] * self.num_workers
900
+ if not isinstance(env_device, (list, tuple)):
901
+ env_device = [
902
+ env_device,
903
+ ] * self.num_workers
904
+ if not isinstance(device, (list, tuple)):
905
+ device = [
906
+ device,
907
+ ] * self.num_workers
908
+ if not (
909
+ len(device)
910
+ == len(storing_device)
911
+ == len(policy_device)
912
+ == len(env_device)
913
+ == self.num_workers
914
+ ):
915
+ raise RuntimeError(
916
+ f"THe length of the devices does not match the number of workers: {self.num_workers}."
917
+ )
918
+ storing_device, policy_device, env_device = zip(
919
+ *[
920
+ Collector._get_devices(
921
+ storing_device=storing_device,
922
+ policy_device=policy_device,
923
+ env_device=env_device,
924
+ device=device,
925
+ )
926
+ for (storing_device, policy_device, env_device, device) in zip(
927
+ storing_device, policy_device, env_device, device
928
+ )
929
+ ]
930
+ )
931
+ return storing_device, policy_device, env_device
932
+
933
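The device handling above first broadcasts scalar arguments to per-worker lists; an illustration with hypothetical values:

    num_workers = 3
    storing_device = "cpu"
    if not isinstance(storing_device, (list, tuple)):
        storing_device = [storing_device] * num_workers   # ["cpu", "cpu", "cpu"], one entry per worker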
+ def frames_per_batch_worker(self, *, worker_idx: int | None = None) -> int:
934
+ raise NotImplementedError
935
+
936
+ @property
937
+ def _queue_len(self) -> int:
938
+ raise NotImplementedError
939
+
940
+ def _recv_and_check(
941
+ self,
942
+ pipe,
943
+ *,
944
+ timeout: float | None = None,
945
+ check_interval: float = 1.0,
946
+ worker_idx: int | None = None,
947
+ ):
948
+ """Receive from a pipe while periodically checking worker health.
949
+
950
+ This method prevents the main process from hanging indefinitely if a worker
951
+ dies while we're waiting for a response. It polls the pipe with a timeout
952
+ and checks if all worker processes are still alive between polls.
953
+
954
+ The overhead is minimal: if data is already available, `poll()` returns
955
+ immediately and no health check is performed. Health checks only run
956
+ when actually waiting for a slow response.
957
+
958
+ Args:
959
+ pipe: The pipe to receive from.
960
+ timeout: Maximum total time to wait for a message (seconds).
961
+ If None (default), wait indefinitely but still check worker health
962
+ periodically.
963
+ check_interval: How often to check worker health (seconds). Default 1.0.
964
+ worker_idx: Optional worker index for error messages.
965
+
966
+ Returns:
967
+ The received message.
968
+
969
+ Raises:
970
+ RuntimeError: If a worker process dies while waiting.
971
+ TimeoutError: If no message is received within the timeout (only if
972
+ timeout is not None).
973
+ """
974
+ # Fast path: check if data is already available (no overhead)
975
+ if pipe.poll(0):
976
+ return pipe.recv()
977
+
978
+ # Slow path: wait with periodic health checks
979
+ elapsed = 0.0
980
+ while timeout is None or elapsed < timeout:
981
+ if pipe.poll(check_interval):
982
+ return pipe.recv()
983
+ elapsed += check_interval
984
+ # Check if any worker has died
985
+ _check_for_faulty_process(self.procs)
986
+ torchrl_logger.debug(
987
+ f"MultiCollector._recv_and_check: Still waiting after {elapsed:.1f}s"
988
+ + (f" for worker {worker_idx}" if worker_idx is not None else "")
989
+ )
990
+
991
+ # Final check before timeout
992
+ _check_for_faulty_process(self.procs)
993
+ worker_info = f" from worker {worker_idx}" if worker_idx is not None else ""
994
+ raise TimeoutError(
995
+ f"Timed out after {timeout}s waiting for message{worker_info}. "
996
+ f"All workers are still alive - this may indicate a deadlock or very slow operation."
997
+ )
998
+
999
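A minimal, self-contained sketch (assumed helper names) of the poll-then-health-check pattern `_recv_and_check` implements, using a raw multiprocessing connection:

    def recv_with_watchdog(conn, procs, timeout=10.0, check_interval=1.0):
        """Receive from `conn`, verifying that the worker processes are alive between polls."""
        elapsed = 0.0
        while elapsed < timeout:
            if conn.poll(check_interval):      # returns as soon as a message is available
                return conn.recv()
            elapsed += check_interval
            for p in procs:                    # fail fast instead of hanging on a dead worker
                if not p.is_alive():
                    raise RuntimeError(f"worker {p.pid} died while waiting for a message")
        raise TimeoutError(f"no message received after {timeout}s")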
+ def _run_processes(self) -> None:
1000
+ if self.num_threads is None:
1001
+ total_workers = self._total_workers_from_env(self.create_env_fn)
1002
+ self.num_threads = max(
1003
+ 1, torch.get_num_threads() - total_workers
1004
+ ) # keep at least one thread for this process
1005
+
1006
+ # Set up for worker processes
1007
+ torch.set_num_threads(self.num_threads)
1008
+ ctx = _get_mp_ctx()
1009
+ # Best-effort global init (only if unset) to keep other mp users consistent.
1010
+ _set_mp_start_method_if_unset(ctx.get_start_method())
1011
+ if sys.platform == "linux" and ctx.get_start_method() == "spawn":
1012
+ # On older PyTorch versions (< 2.8), pickling Process objects for "spawn"
1013
+ # can pass file descriptors for shared storages, causing spawn-time failures.
1014
+ # The strategy function returns "file_system" for old PyTorch, None otherwise.
1015
+ strategy = _mp_sharing_strategy_for_spawn()
1016
+ if strategy is not None:
1017
+ mp.set_sharing_strategy(strategy)
1018
+ queue_out = ctx.Queue(self._queue_len) # sends data from proc to main
1019
+ self.procs = []
1020
+ self._traj_pool = _TrajectoryPool(ctx=ctx, lock=True)
1021
+
1022
+ # Create all pipes upfront (needed for weight sync scheme initialization)
1023
+ # Store as list of (parent, child) tuples for use in worker creation
1024
+ pipe_pairs = [ctx.Pipe() for _ in range(self.num_workers)]
1025
+ # Extract parent pipes for external use (e.g., polling, receiving messages)
1026
+ self.pipes = [pipe_parent for pipe_parent, _ in pipe_pairs]
1027
+
1028
+ _ProcessNoWarnCtx = _make_process_no_warn_cls(ctx)
1029
+ # Initialize all weight sync schemes now that pipes are available
1030
+ # Both SharedMemWeightSyncScheme (uses queues) and MultiProcessWeightSyncScheme (uses pipes)
1031
+ # can be initialized here since all required resources exist
1032
+ if self._weight_sync_schemes:
1033
+ for model_id, scheme in self._weight_sync_schemes.items():
1034
+ if not scheme.initialized_on_sender:
1035
+ torchrl_logger.debug(
1036
+ f"Init weight sync scheme {type(scheme).__name__} for {model_id=}."
1037
+ )
1038
+ scheme.init_on_sender(model_id=model_id, context=self, ctx=ctx)
1039
+
1040
+ # Create a policy on the right device
1041
+ policy_factory = self.policy_factory
1042
+ has_policy_factory = any(policy_factory)
1043
+ if has_policy_factory:
1044
+ policy_factory = [
1045
+ CloudpickleWrapper(_policy_factory)
1046
+ for _policy_factory in policy_factory
1047
+ ]
1048
+
1049
+ for i, (env_fun, env_fun_kwargs) in enumerate(
1050
+ zip(self.create_env_fn, self.create_env_kwargs)
1051
+ ):
1052
+ pipe_parent, pipe_child = pipe_pairs[i] # use pre-created pipes
1053
+ if env_fun.__class__.__name__ != "EnvCreator" and not isinstance(
1054
+ env_fun, EnvBase
1055
+ ): # to avoid circular imports
1056
+ env_fun = CloudpickleWrapper(env_fun)
1057
+
1058
+ policy_device = self.policy_device[i]
1059
+ storing_device = self.storing_device[i]
1060
+ env_device = self.env_device[i]
1061
+
1062
+ # Prepare policy for worker based on weight synchronization method.
1063
+ # IMPORTANT: when a policy_factory is provided, the policy instance
1064
+ # is used ONLY on the main process (for weight extraction etc.) and
1065
+ # is NOT sent to workers.
1066
+ policy = self.policy
1067
+
1068
+ if self._weight_sync_schemes:
1069
+ # With weight sync schemes, send stateless policies.
1070
+ # Schemes handle weight distribution on worker side.
1071
+ if has_policy_factory:
1072
+ # Factory will create policy in worker; don't send policy.
1073
+ policy_to_send = None
1074
+ cm = contextlib.nullcontext()
1075
+ elif policy is not None:
1076
+ # Send a stateless policy down to workers: schemes apply weights.
1077
+ policy_to_send = policy
1078
+ cm = _make_meta_policy_cm(
1079
+ policy, mp_start_method=ctx.get_start_method()
1080
+ )
1081
+ else:
1082
+ policy_to_send = None
1083
+ cm = contextlib.nullcontext()
1084
+ elif hasattr(self, "_policy_weights_dict"):
1085
+ # LEGACY:
1086
+ # With weight updater, use in-place weight replacement.
1087
+ # Take the weights and locally dispatch them to the policy before sending.
1088
+ # This ensures a given set of shared weights for a device are shared
1089
+ # for all policies that rely on that device.
1090
+ policy_weights = self._policy_weights_dict.get(policy_device)
1091
+ if has_policy_factory:
1092
+ # Even in legacy mode, when a policy_factory is present, do not
1093
+ # send the stateful policy down to workers.
1094
+ policy_to_send = None
1095
+ cm = contextlib.nullcontext()
1096
+ else:
1097
+ policy_to_send = policy
1098
+ if policy is not None and policy_weights is not None:
1099
+ cm = policy_weights.to_module(policy)
1100
+ else:
1101
+ cm = contextlib.nullcontext()
1102
+ else:
1103
+ # Parameter-less policy.
1104
+ cm = contextlib.nullcontext()
1105
+ # When a policy_factory exists, never send the policy instance.
1106
+ policy_to_send = None if has_policy_factory else policy
1107
+
1108
+ with cm:
1109
+ kwargs = {
1110
+ "policy_factory": policy_factory[i],
1111
+ "pipe_child": pipe_child,
1112
+ "queue_out": queue_out,
1113
+ "create_env_fn": env_fun,
1114
+ "create_env_kwargs": env_fun_kwargs,
1115
+ "policy": policy_to_send,
1116
+ "max_frames_per_traj": self.max_frames_per_traj,
1117
+ "frames_per_batch": self.frames_per_batch_worker(worker_idx=i),
1118
+ "reset_at_each_iter": self.reset_at_each_iter,
1119
+ "policy_device": policy_device,
1120
+ "storing_device": storing_device,
1121
+ "env_device": env_device,
1122
+ "exploration_type": self.exploration_type,
1123
+ "reset_when_done": self.reset_when_done,
1124
+ "idx": i,
1125
+ "interruptor": self.interruptor,
1126
+ "set_truncated": self.set_truncated,
1127
+ "use_buffers": self._use_buffers,
1128
+ "replay_buffer": self.replay_buffer,
1129
+ "extend_buffer": self.extend_buffer,
1130
+ "traj_pool": self._traj_pool,
1131
+ "trust_policy": self.trust_policy,
1132
+ "compile_policy": self.compiled_policy_kwargs
1133
+ if self.compiled_policy
1134
+ else False,
1135
+ "cudagraph_policy": self.cudagraphed_policy_kwargs
1136
+ if self.cudagraphed_policy
1137
+ else False,
1138
+ "no_cuda_sync": self.no_cuda_sync,
1139
+ "collector_class": self.collector_class,
1140
+ "postproc": self.postprocs
1141
+ if self.replay_buffer is not None
1142
+ else None,
1143
+ "weight_sync_schemes": self._weight_sync_schemes,
1144
+ "worker_idx": i, # Worker index for queue-based weight distribution
1145
+ "init_random_frames": self.init_random_frames,
1146
+ "profile_config": self._profile_config,
1147
+ }
1148
+ proc = _ProcessNoWarnCtx(
1149
+ target=_main_async_collector,
1150
+ num_threads=self.num_sub_threads,
1151
+ _start_method=ctx.get_start_method(),
1152
+ kwargs=kwargs,
1153
+ )
1154
+ # proc.daemon can't be set: the worker may launch processes of its own, which daemonic processes cannot do
1155
+ try:
1156
+ proc.start()
1157
+ except TypeError as err:
1158
+ if "cannot pickle" in str(err):
1159
+ raise RuntimeError(
1160
+ "A non-serializable object was passed to the collector workers."
1161
+ ) from err
1162
+ except RuntimeError as err:
1163
+ if "Cowardly refusing to serialize non-leaf tensor" in str(err):
1164
+ raise RuntimeError(
1165
+ "At least one of the tensors in the policy, replay buffer, environment constructor or postprocessor requires gradients. "
1166
+ "This is not supported in multiprocessed data collectors.\n- For ReplayBuffer transforms, use a `transform_factory` instead with `delayed_init=True`.\n"
1167
+ "- Make sure your environment constructor does not reference tensors already instantiated on the main process.\n"
1168
+ "- Since no gradient can be propagated through the Collector pipes, the backward graph is never needed. Consider using detached tensors instead."
1169
+ ) from err
1170
+ elif "_share_fd_: only available on CPU" in str(
1171
+ err
1172
+ ) or "_share_filename_: only available on CPU" in str(err):
1173
+ # This is a common failure mode on older PyTorch versions when using the
1174
+ # "spawn" multiprocessing start method: the process object contains a
1175
+ # CUDA/MPS tensor (or a module/buffer on a non-CPU device), which must be
1176
+ # pickled when spawning workers.
1177
+ #
1178
+ # See: https://github.com/pytorch/pytorch/issues/87688#issuecomment-1968901877
1179
+ start_method = None
1180
+ try:
1181
+ start_method = mp.get_start_method(allow_none=True)
1182
+ except Exception:
1183
+ # Best effort: some environments may disallow querying here.
1184
+ start_method = None
1185
+ raise RuntimeError(
1186
+ "Failed to start a collector worker process because a non-CPU tensor "
1187
+ "was captured in the worker process arguments and had to be serialized "
1188
+ "(pickled) at process start.\n\n"
1189
+ f"Detected multiprocessing start method: {start_method!r}.\n\n"
1190
+ "Workarounds:\n"
1191
+ "- Keep any tensors/modules referenced by your collector constructor "
1192
+ "(policy, replay buffer, postprocs, env factory captures, etc.) on CPU "
1193
+ "when using a spawning start method (common on macOS/Windows).\n"
1194
+ "- Or set the multiprocessing start method to 'fork' *before* creating "
1195
+ "the collector (Unix only). Example:\n\n"
1196
+ " import torch.multiprocessing as mp\n"
1197
+ " if __name__ == '__main__':\n"
1198
+ " mp.set_start_method('fork', force=True)\n\n"
1199
+ "Upstream context: https://github.com/pytorch/pytorch/issues/87688#issuecomment-1968901877"
1200
+ ) from err
1201
+ else:
1202
+ raise err
1203
+ except ValueError as err:
1204
+ if "bad value(s) in fds_to_keep" in str(err):
1205
+ # This error occurs on old Python versions (e.g., 3.9) with old PyTorch (e.g., 2.3)
1206
+ # when using the spawn multiprocessing start method. The spawn implementation tries to
1207
+ # preserve file descriptors across exec, but some descriptors may be invalid/closed.
1208
+ # This is a compatibility issue with old Python multiprocessing implementations.
1209
+ python_version = (
1210
+ f"{sys.version_info.major}.{sys.version_info.minor}"
1211
+ )
1212
+ raise RuntimeError(
1213
+ f"Failed to start collector worker process due to file descriptor issues "
1214
+ f"with spawn multiprocessing on Python {python_version}.\n\n"
1215
+ f"This is a known compatibility issue with old Python/PyTorch stacks. "
1216
+ f"Consider upgrading to Python >= 3.10 and PyTorch >= 2.5, or use the 'fork' "
1217
+ f"multiprocessing start method on Unix systems.\n\n"
1218
+ f"Workarounds:\n"
1219
+ f"- Upgrade Python to >= 3.10 and PyTorch to >= 2.5\n"
1220
+ f"- On Unix systems, force fork start method:\n"
1221
+ f" import torch.multiprocessing as mp\n"
1222
+ f" if __name__ == '__main__':\n"
1223
+ f" mp.set_start_method('fork', force=True)\n\n"
1224
+ f"Upstream Python issue: https://github.com/python/cpython/issues/87706"
1225
+ ) from err
1226
+ except _pickle.PicklingError as err:
1227
+ if "<lambda>" in str(err):
1228
+ raise RuntimeError(
1229
+ """Can't open a process with doubly cloud-pickled lambda function.
1230
+ This error is likely due to an attempt to use a ParallelEnv in a
1231
+ multiprocessed data collector. To do this, consider wrapping your
1232
+ lambda function in an `torchrl.envs.EnvCreator` wrapper as follows:
1233
+ `env = ParallelEnv(N, EnvCreator(my_lambda_function))`.
1234
+ This will not only ensure that your lambda function is cloud-pickled once, but
1235
+ also that the state dict is synchronised across processes if needed."""
1236
+ ) from err
1237
+ pipe_child.close()
1238
+ self.procs.append(proc)
1239
+
1240
+ # Synchronize initial weights with workers AFTER starting processes but BEFORE waiting for "instantiated"
1241
+ # This must happen after proc.start() but before workers send "instantiated" to avoid deadlock:
1242
+ # Workers will call receiver.collect() during init and may block waiting for data
1243
+ if self._weight_sync_schemes:
1244
+ # start with policy
1245
+ policy_scheme = self._weight_sync_schemes.get("policy")
1246
+ if policy_scheme is not None:
1247
+ policy_scheme.connect()
1248
+ for key, scheme in self._weight_sync_schemes.items():
1249
+ if key == "policy":
1250
+ continue
1251
+ scheme.connect()
1252
+
1253
+ # Wait for workers to be ready
1254
+ for i, pipe_parent in enumerate(self.pipes):
1255
+ pipe_parent.poll(timeout=INSTANTIATE_TIMEOUT)
1256
+ try:
1257
+ msg = pipe_parent.recv()
1258
+ except EOFError as e:
1259
+ raise RuntimeError(
1260
+ f"Worker {i} failed to initialize and closed the connection before sending status. "
1261
+ f"This typically indicates that the worker process crashed during initialization. "
1262
+ f"Check the worker process logs for the actual error."
1263
+ ) from e
1264
+ if msg != "instantiated":
1265
+ # Check if it's an error dict from worker
1266
+ if isinstance(msg, dict) and msg.get("error"):
1267
+ # Reconstruct the exception from the worker
1268
+ exc_type_name = msg["exception_type"]
1269
+ exc_msg = msg["exception_msg"]
1270
+ traceback_str = msg["traceback"]
1271
+
1272
+ # Try to get the actual exception class
1273
+ exc_class = None
1274
+ exc_module = msg["exception_module"]
1275
+
1276
+ if exc_module == "builtins":
1277
+ # Get from builtins
1278
+ import builtins
1279
+
1280
+ exc_class = getattr(builtins, exc_type_name, None)
1281
+ else:
1282
+ # Try to import from the module
1283
+ try:
1284
+ import importlib
1285
+
1286
+ mod = importlib.import_module(exc_module)
1287
+ exc_class = getattr(mod, exc_type_name, None)
1288
+ except Exception:
1289
+ pass
1290
+
1291
+ # Re-raise with original exception type if possible
1292
+ if exc_class is not None:
1293
+ raise exc_class(
1294
+ f"{exc_msg}\n\nWorker traceback:\n{traceback_str}"
1295
+ )
1296
+ else:
1297
+ # Fall back to RuntimeError if we can't get the original type
1298
+ raise RuntimeError(
1299
+ f"Worker {i} raised {exc_type_name}: {exc_msg}\n\nWorker traceback:\n{traceback_str}"
1300
+ )
1301
+ else:
1302
+ # Legacy string error message
1303
+ raise RuntimeError(msg)
1304
+
1305
+ self.queue_out = queue_out
1306
+ self.closed = False
1307
+
1308
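The error paths above recommend forcing the "fork" start method on Unix before the collector is created; the workaround they describe looks like this:

    import torch.multiprocessing as mp

    if __name__ == "__main__":
        mp.set_start_method("fork", force=True)   # Unix only; must run before the collector is built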
+ _running_free = False
1309
+
1310
+ def start(self):
1311
+ """Starts the collector(s) for asynchronous data collection.
1312
+
1313
+ The collected data is stored in the provided replay buffer. This method initiates the background collection of
1314
+ data across multiple processes, allowing for decoupling of data collection and training.
1315
+
1316
+ Raises:
1317
+ RuntimeError: If no replay buffer is defined during the collector's initialization.
1318
+
1319
+ Example:
1320
+ >>> from torchrl.modules import RandomPolicy
+ >>>
+ >>> import time
1321
+ >>> from functools import partial
1322
+ >>>
1323
+ >>> import tqdm
1324
+ >>>
1325
+ >>> from torchrl.collectors import MultiAsyncCollector
1326
+ >>> from torchrl.data import LazyTensorStorage, ReplayBuffer
1327
+ >>> from torchrl.envs import GymEnv, set_gym_backend
1328
+ >>> import ale_py
1329
+ >>>
1330
+ >>> # Set the gym backend to gymnasium
1331
+ >>> set_gym_backend("gymnasium").set()
1332
+ >>>
1333
+ >>> if __name__ == "__main__":
1334
+ ... # Create a random policy for the Pong environment
1335
+ ... env_fn = partial(GymEnv, "ALE/Pong-v5")
1336
+ ... policy = RandomPolicy(env_fn().action_spec)
1337
+ ...
1338
+ ... # Initialize a shared replay buffer
1339
+ ... rb = ReplayBuffer(storage=LazyTensorStorage(10000), shared=True)
1340
+ ...
1341
+ ... # Create a multi-async data collector with 16 environments
1342
+ ... num_envs = 16
1343
+ ... collector = MultiAsyncCollector(
1344
+ ... [env_fn] * num_envs,
1345
+ ... policy=policy,
1346
+ ... replay_buffer=rb,
1347
+ ... frames_per_batch=num_envs * 16,
1348
+ ... total_frames=-1,
1349
+ ... )
1350
+ ...
1351
+ ... # Progress bar to track the number of collected frames
1352
+ ... pbar = tqdm.tqdm(total=100_000)
1353
+ ...
1354
+ ... # Start the collector asynchronously
1355
+ ... collector.start()
1356
+ ...
1357
+ ... # Track the write count of the replay buffer
1358
+ ... prec_wc = 0
1359
+ ... while True:
1360
+ ... wc = rb.write_count
1361
+ ... c = wc - prec_wc
1362
+ ... prec_wc = wc
1363
+ ...
1364
+ ... # Update the progress bar
1365
+ ... pbar.update(c)
1366
+ ... pbar.set_description(f"Write Count: {rb.write_count}")
1367
+ ...
1368
+ ... # Check the write count every 0.5 seconds
1369
+ ... time.sleep(0.5)
1370
+ ...
1371
+ ... # Stop when the desired number of frames is reached
1372
+ ... if rb.write_count >= 100_000:
1373
+ ... break
1374
+ ...
1375
+ ... # Shut down the collector
1376
+ ... collector.async_shutdown()
1377
+ """
1378
+ if self.replay_buffer is None:
1379
+ raise RuntimeError("Replay buffer must be defined for execution.")
1380
+ self._running_free = True
1381
+ torchrl_logger.debug(
1382
+ f"MultiCollector.start(): Sending run_free to {len(self.pipes)} workers..."
1383
+ )
1384
+ for i, pipe in enumerate(self.pipes):
1385
+ pipe.send((None, "run_free"))
1386
+ torchrl_logger.debug(f"MultiCollector.start(): Sent run_free to worker {i}")
1387
+
1388
+ @contextlib.contextmanager
1389
+ def pause(self):
1390
+ """Context manager that pauses the collector if it is running free."""
1391
+ if self._running_free:
1392
+ for pipe in self.pipes:
1393
+ pipe.send((None, "pause"))
1394
+ # Make sure all workers are paused
1395
+ for i in range(len(self.pipes)):
1396
+ # Use timeout with health check to avoid hanging if a worker dies
1397
+ timeout = 30.0
1398
+ check_interval = 1.0
1399
+ elapsed = 0.0
1400
+ while elapsed < timeout:
1401
+ try:
1402
+ idx, msg = self.queue_out.get(timeout=check_interval)
1403
+ break
1404
+ except Exception:
1405
+ elapsed += check_interval
1406
+ _check_for_faulty_process(self.procs)
1407
+ else:
1408
+ _check_for_faulty_process(self.procs)
1409
+ raise TimeoutError(
1410
+ f"Timed out waiting for worker {i} to pause after {timeout}s"
1411
+ )
1412
+ if msg != "paused":
1413
+ raise ValueError(f"Expected paused, but got {msg=}.")
1414
+ torchrl_logger.debug(f"Worker {idx} is paused.")
1415
+ self._running_free = False
1416
+ yield None
1417
+ for pipe in self.pipes:
1418
+ pipe.send((None, "restart"))
1419
+ self._running_free = True
1420
+ else:
1421
+ raise RuntimeError("Collector cannot be paused.")
1422
+
1423
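A hypothetical usage of the `pause()` context manager once `start()` has been called: collection halts while weights are refreshed, then resumes on exit.

    collector.start()
    with collector.pause():
        collector.update_policy_weights_()   # no worker is writing to the buffer while paused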
+ def enable_profile(self, **kwargs) -> None:
1424
+ """Enable profiling for collector worker rollouts.
1425
+
1426
+ For multi-process collectors, this sends the profile configuration
1427
+ to the specified workers. Must be called before iteration starts.
1428
+
1429
+ See :meth:`BaseCollector.enable_profile` for full documentation.
1430
+ """
1431
+ # First, call parent to validate and set _profile_config
1432
+ super().enable_profile(**kwargs)
1433
+
1434
+ # Send profile config to workers that should be profiled
1435
+ if self._profile_config is not None:
1436
+ for idx in self._profile_config.workers:
1437
+ if idx < self.num_workers:
1438
+ self.pipes[idx].send((self._profile_config, "enable_profile"))
1439
+
1440
+ # Wait for confirmation from workers
1441
+ for idx in self._profile_config.workers:
1442
+ if idx < self.num_workers:
1443
+ if self.pipes[idx].poll(INSTANTIATE_TIMEOUT):
1444
+ _, msg = self.pipes[idx].recv()
1445
+ if msg != "profile_enabled":
1446
+ raise RuntimeError(
1447
+ f"Worker {idx}: Expected 'profile_enabled' message, got {msg}"
1448
+ )
1449
+ else:
1450
+ raise TimeoutError(
1451
+ f"Worker {idx}: Timed out waiting for profile confirmation."
1452
+ )
1453
+
1454
+ def __del__(self):
1455
+ try:
1456
+ self.shutdown()
1457
+ except Exception:
1458
+ # an AttributeError will typically be raised if the collector is deleted when the program ends.
1459
+ # In the future, insignificant changes to the close method may change the error type.
1460
+ # We explicitly assume that any error raised during closure in
1461
+ # __del__ will not affect the program.
1462
+ pass
1463
+
1464
+ def shutdown(
1465
+ self,
1466
+ timeout: float | None = None,
1467
+ close_env: bool = True,
1468
+ raise_on_error: bool = True,
1469
+ ) -> None:
1470
+ """Shuts down all processes. This operation is irreversible.
1471
+
1472
+ Args:
1473
+ timeout (float, optional): The timeout for closing pipes between workers.
1474
+ close_env (bool, optional): Whether to close the environment. Defaults to `True`.
1475
+ raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`.
1476
+ """
1477
+ if not close_env:
1478
+ raise RuntimeError(
1479
+ f"Cannot shutdown {type(self).__name__} collector without environment being closed."
1480
+ )
1481
+ try:
1482
+ self._shutdown_main(timeout)
1483
+ except Exception as e:
1484
+ if raise_on_error:
1485
+ raise e
1486
+ else:
1487
+ pass
1488
+
1489
+ def _shutdown_main(self, timeout: float | None = None) -> None:
1490
+ if timeout is None:
1491
+ timeout = 10
1492
+ try:
1493
+ if self.closed:
1494
+ return
1495
+ _check_for_faulty_process(self.procs)
1496
+ all_closed = [False] * self.num_workers
1497
+ rep = 0
1498
+ for idx in range(self.num_workers):
1499
+ if all_closed[idx]:
1500
+ continue
1501
+ if not self.procs[idx].is_alive():
1502
+ continue
1503
+ self.pipes[idx].send((None, "close"))
1504
+
1505
+ while not all(all_closed) and rep < 1000:
1506
+ rep += 1
1507
+ for idx in range(self.num_workers):
1508
+ if all_closed[idx]:
1509
+ continue
1510
+ if not self.procs[idx].is_alive():
1511
+ all_closed[idx] = True
1512
+ continue
1513
+ try:
1514
+ if self.pipes[idx].poll(timeout / 1000 / self.num_workers):
1515
+ msg = self.pipes[idx].recv()
1516
+ if msg != "closed":
1517
+ raise RuntimeError(f"got {msg} but expected 'close'")
1518
+ all_closed[idx] = True
1519
+ else:
1520
+ continue
1521
+ except BrokenPipeError:
1522
+ all_closed[idx] = True
1523
+ continue
1524
+ self.closed = True
1525
+
1526
+ self.queue_out.close()
1527
+ for pipe in self.pipes:
1528
+ pipe.close()
1529
+ for proc in self.procs:
1530
+ proc.join(1.0)
1531
+ finally:
1532
+ import torchrl
1533
+
1534
+ num_threads = min(
1535
+ torchrl._THREAD_POOL_INIT,
1536
+ torch.get_num_threads()
1537
+ + self._total_workers_from_env(self.create_env_fn),
1538
+ )
1539
+ torch.set_num_threads(num_threads)
1540
+
1541
+ for proc in self.procs:
1542
+ if proc.is_alive():
1543
+ proc.terminate()
1544
+
1545
+ def async_shutdown(self, timeout: float | None = None):
1546
+ return self.shutdown(timeout=timeout)
1547
+
1548
+ def set_seed(self, seed: int, static_seed: bool = False) -> int:
1549
+ """Sets the seeds of the environments stored in the DataCollector.
1550
+
1551
+ Args:
1552
+ seed: integer representing the seed to be used for the environment.
1553
+ static_seed (bool, optional): if ``True``, the seed is not incremented.
1554
+ Defaults to False
1555
+
1556
+ Returns:
1557
+ Output seed. This is useful when more than one environment is
1558
+ contained in the DataCollector, as the seed will be incremented for
1559
+ each of these. The resulting seed is the seed of the last
1560
+ environment.
1561
+
1562
+ Examples:
1563
+ >>> from torchrl.envs import ParallelEnv
1564
+ >>> from torchrl.envs.libs.gym import GymEnv
1565
+ >>> from tensordict.nn import TensorDictModule
1566
+ >>> from torch import nn
1567
+ >>> env_fn = lambda: GymEnv("Pendulum-v1")
1568
+ >>> env_fn_parallel = lambda: ParallelEnv(6, env_fn)
1569
+ >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
1570
+ >>> collector = Collector(env_fn_parallel, policy, frames_per_batch=100, total_frames=300)
1571
+ >>> out_seed = collector.set_seed(1) # out_seed = 6
1572
+
1573
+ """
1574
+ _check_for_faulty_process(self.procs)
1575
+ for idx in range(self.num_workers):
1576
+ self.pipes[idx].send(((seed, static_seed), "seed"))
1577
+ new_seed, msg = self._recv_and_check(self.pipes[idx], worker_idx=idx)
1578
+ if msg != "seeded":
1579
+ raise RuntimeError(f"Expected msg='seeded', got {msg}")
1580
+ seed = new_seed
1581
+ self.reset()
1582
+ return seed
1583
+
1584
+ def reset(self, reset_idx: Sequence[bool] | None = None) -> None:
1585
+ """Resets the environments to a new initial state.
1586
+
1587
+ Args:
1588
+ reset_idx: Optional. Sequence indicating which environments have
1589
+ to be reset. If None, all environments are reset.
1590
+
1591
+ """
1592
+ _check_for_faulty_process(self.procs)
1593
+
1594
+ if reset_idx is None:
1595
+ reset_idx = [True for _ in range(self.num_workers)]
1596
+ for idx in range(self.num_workers):
1597
+ if reset_idx[idx]:
1598
+ self.pipes[idx].send((None, "reset"))
1599
+ for idx in range(self.num_workers):
1600
+ if reset_idx[idx]:
1601
+ j, msg = self._recv_and_check(self.pipes[idx], worker_idx=idx)
1602
+ if msg != "reset":
1603
+ raise RuntimeError(f"Expected msg='reset', got {msg}")
1604
+
1605
+ def state_dict(self) -> OrderedDict:
1606
+ """Returns the state_dict of the data collector.
1607
+
1608
+ Each field represents a worker containing its own state_dict.
1609
+
1610
+ """
1611
+ for idx in range(self.num_workers):
1612
+ self.pipes[idx].send((None, "state_dict"))
1613
+ state_dict = OrderedDict()
1614
+ for idx in range(self.num_workers):
1615
+ _state_dict, msg = self._recv_and_check(self.pipes[idx], worker_idx=idx)
1616
+ if msg != "state_dict":
1617
+ raise RuntimeError(f"Expected msg='state_dict', got {msg}")
1618
+ state_dict[f"worker{idx}"] = _state_dict
1619
+ state_dict.update({"frames": self._frames, "iter": self._iter})
1620
+
1621
+ return state_dict
1622
+
1623
+ def load_state_dict(self, state_dict: OrderedDict) -> None:
1624
+ """Loads the state_dict on the workers.
1625
+
1626
+ Args:
1627
+ state_dict (OrderedDict): state_dict of the form
1628
+ ``{"worker0": state_dict0, "worker1": state_dict1}``.
1629
+
1630
+ """
1631
+ for idx in range(self.num_workers):
1632
+ self.pipes[idx].send((state_dict[f"worker{idx}"], "load_state_dict"))
1633
+ for idx in range(self.num_workers):
1634
+ _, msg = self._recv_and_check(self.pipes[idx], worker_idx=idx)
1635
+ if msg != "loaded":
1636
+ raise RuntimeError(f"Expected msg='loaded', got {msg}")
1637
+ self._frames = state_dict["frames"]
1638
+ self._iter = state_dict["iter"]
1639
+
1640
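A hedged sketch of a save/restore round trip with the two methods above (file name and second collector are illustrative):

    import torch

    sd = collector.state_dict()                 # {"worker0": ..., "worker1": ..., "frames": ..., "iter": ...}
    torch.save(sd, "collector_state.pt")
    new_collector.load_state_dict(torch.load("collector_state.pt"))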
+ def increment_version(self):
1641
+ """Increment the policy version."""
1642
+ if self.policy_version_tracker is not None:
1643
+ if not hasattr(self.policy_version_tracker, "increment_version"):
1644
+ raise RuntimeError(
1645
+ "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector."
1646
+ )
1647
+ self.policy_version_tracker.increment_version()
1648
+
1649
+ @property
1650
+ def policy_version(self) -> str | int | None:
1651
+ """The current policy version."""
1652
+ if not hasattr(self.policy_version_tracker, "version"):
1653
+ return None
1654
+ return self.policy_version_tracker.version
1655
+
1656
+ def get_policy_version(self) -> str | int | None:
1657
+ """Get the current policy version.
1658
+
1659
+ This method exists to support remote calls in Ray actors, since properties
1660
+ cannot be accessed directly through Ray's RPC mechanism.
1661
+
1662
+ Returns:
1663
+ The current version number (int) or UUID (str), or None if version tracking is disabled.
1664
+ """
1665
+ return self.policy_version
1666
+
1667
+ def getattr_policy(self, attr):
1668
+ """Get an attribute from the policy of the first worker.
1669
+
1670
+ Args:
1671
+ attr (str): The attribute name to retrieve from the policy.
1672
+
1673
+ Returns:
1674
+ The attribute value from the policy of the first worker.
1675
+
1676
+ Raises:
1677
+ AttributeError: If the attribute doesn't exist on the policy.
1678
+ """
1679
+ _check_for_faulty_process(self.procs)
1680
+
1681
+ # Send command to first worker (index 0)
1682
+ self.pipes[0].send((attr, "getattr_policy"))
1683
+ result, msg = self._recv_and_check(self.pipes[0], worker_idx=0)
1684
+ if msg != "getattr_policy":
1685
+ raise RuntimeError(f"Expected msg='getattr_policy', got {msg}")
1686
+
1687
+ # If the worker returned an AttributeError, re-raise it
1688
+ if isinstance(result, AttributeError):
1689
+ raise result
1690
+
1691
+ return result
1692
+
1693
+ def getattr_env(self, attr):
1694
+ """Get an attribute from the environment of the first worker.
1695
+
1696
+ Args:
1697
+ attr (str): The attribute name to retrieve from the environment.
1698
+
1699
+ Returns:
1700
+ The attribute value from the environment of the first worker.
1701
+
1702
+ Raises:
1703
+ AttributeError: If the attribute doesn't exist on the environment.
1704
+ """
1705
+ _check_for_faulty_process(self.procs)
1706
+
1707
+ # Send command to first worker (index 0)
1708
+ self.pipes[0].send((attr, "getattr_env"))
1709
+ result, msg = self._recv_and_check(self.pipes[0], worker_idx=0)
1710
+ if msg != "getattr_env":
1711
+ raise RuntimeError(f"Expected msg='getattr_env', got {msg}")
1712
+
1713
+ # If the worker returned an AttributeError, re-raise it
1714
+ if isinstance(result, AttributeError):
1715
+ raise result
1716
+
1717
+ return result
1718
+
1719
+ def getattr_rb(self, attr):
1720
+ """Get an attribute from the replay buffer."""
1721
+ return getattr(self.replay_buffer, attr)
1722
+
1723
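Hypothetical calls to the attribute proxies above; the first two are forwarded to worker 0's environment and policy, the last reads the shared buffer directly:

    spec = collector.getattr_env("action_spec")
    training = collector.getattr_policy("training")
    writes = collector.getattr_rb("write_count")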
+ def get_model(self, model_id: str):
1724
+ """Get model instance by ID (for weight sync schemes).
1725
+
1726
+ Args:
1727
+ model_id: Model identifier (e.g., "policy", "value_net")
1728
+
1729
+ Returns:
1730
+ The model instance
1731
+
1732
+ Raises:
1733
+ ValueError: If model_id is not recognized
1734
+ """
1735
+ if model_id == "policy":
1736
+ # Return the fallback policy instance
1737
+ if (fallback_policy := getattr(self, "_fallback_policy", None)) is not None:
1738
+ return fallback_policy
1739
+ elif hasattr(self, "policy") and self.policy is not None:
1740
+ return self.policy
1741
+ else:
1742
+ raise ValueError(f"No policy found for model_id '{model_id}'")
1743
+ else:
1744
+ # Try to resolve via attribute access
1745
+ return _resolve_model(self, model_id)
1746
+
1747
+ def get_cached_weights(self, model_id: str):
1748
+ """Get cached shared memory weights if available (for weight sync schemes).
1749
+
1750
+ Args:
1751
+ model_id: Model identifier
1752
+
1753
+ Returns:
1754
+ Cached TensorDict weights or None if not available
1755
+ """
1756
+ if model_id == "policy" and hasattr(self, "_policy_weights_dict"):
1757
+ # Get the policy device (first device if list)
1758
+ policy_device = self.policy_device
1759
+ if isinstance(policy_device, (list, tuple)):
1760
+ policy_device = policy_device[0] if len(policy_device) > 0 else None
1761
+
1762
+ # Return cached weights for this device
1763
+ return self._policy_weights_dict.get(policy_device)
1764
+ return None
1765
+
1766
+ def _weight_update_impl(
1767
+ self,
1768
+ policy_or_weights: TensorDictBase | nn.Module | dict | None = None,
1769
+ *,
1770
+ worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
1771
+ model_id: str | None = None,
1772
+ weights_dict: dict[str, Any] | None = None,
1773
+ **kwargs,
1774
+ ) -> None:
1775
+ """Update weights on workers.
1776
+
1777
+ Weight sync schemes now use background threads on the receiver side.
1778
+ The scheme's send() method:
1779
+ 1. Puts weights in the queue (or updates shared memory)
1780
+ 2. Sends a "receive" instruction to the worker's background thread
1781
+ 3. Waits for acknowledgment (if sync=True)
1782
+
1783
+ No pipe signaling is needed - the scheme handles everything internally.
1784
+ """
1785
+ # Call parent implementation which calls scheme.send()
1786
+ # The scheme handles instruction delivery and acknowledgments
1787
+ super()._weight_update_impl(
1788
+ policy_or_weights=policy_or_weights,
1789
+ worker_ids=worker_ids,
1790
+ model_id=model_id,
1791
+ weights_dict=weights_dict,
1792
+ **kwargs,
1793
+ )
1794
+
1795
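A sketch of the sender-side call that ends up in `_weight_update_impl` above, assuming a `training_policy` module exists on the main process:

    from tensordict import TensorDict

    new_weights = TensorDict.from_module(training_policy)   # snapshot of the trainer's parameters
    collector.update_policy_weights_(new_weights)           # schemes/updaters push them to the workers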
+ # for RPC
1796
+ def receive_weights(self, policy_or_weights: TensorDictBase | None = None):
1797
+ return super().receive_weights(policy_or_weights)
1798
+
1799
+ # for RPC
1800
+ def _receive_weights_scheme(self):
1801
+ return super()._receive_weights_scheme()
1802
+
1803
+
1804
+ # Backward-compatible alias (deprecated, use MultiCollector instead)
1805
+ MultiCollector = MultiCollector