torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/envs/batched_envs.py
@@ -0,0 +1,3093 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from __future__ import annotations
7
+
8
+ import functools
9
+ import gc
10
+ import os
11
+ import time
12
+ import warnings
13
+ import weakref
14
+ from collections import OrderedDict
15
+ from collections.abc import Callable, Mapping, Sequence
16
+ from copy import deepcopy
17
+ from functools import wraps
18
+ from multiprocessing import connection
19
+ from multiprocessing.connection import wait as connection_wait
20
+ from multiprocessing.synchronize import Lock as MpLock
21
+ from typing import Any
22
+ from warnings import warn
23
+
24
+ import torch
25
+ from tensordict import (
26
+ is_tensor_collection,
27
+ LazyStackedTensorDict,
28
+ TensorDict,
29
+ TensorDictBase,
30
+ unravel_key,
31
+ )
32
+ from tensordict.base import _is_leaf_nontensor
33
+ from tensordict.utils import _zip_strict
34
+ from torch import multiprocessing as mp
35
+
36
+ from torchrl._utils import (
37
+ _check_for_faulty_process,
38
+ _get_default_mp_start_method,
39
+ _make_ordinal_device,
40
+ logger as torchrl_logger,
41
+ rl_warnings,
42
+ timeit,
43
+ VERBOSE,
44
+ )
45
+ from torchrl.data.tensor_specs import Composite, NonTensor
46
+ from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING
47
+ from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, EnvMetaData
48
+
49
+ from torchrl.envs.env_creator import EnvCreator, get_env_metadata
50
+
51
+ from torchrl.envs.utils import (
52
+ _aggregate_end_of_traj,
53
+ _sort_keys,
54
+ _update_during_reset,
55
+ clear_mpi_env_vars,
56
+ )
57
+
58
+ _CONSOLIDATE_ERR_CAPTURE = (
59
+ "TensorDict.consolidate failed. You can deactivate the tensordict consolidation via the "
60
+ "`consolidate` keyword argument of the ParallelEnv constructor."
61
+ )
62
+
63
+
64
+ def _check_start(fun):
65
+ def decorated_fun(self: BatchedEnvBase, *args, **kwargs):
66
+ if self.is_closed:
67
+ self._create_td()
68
+ self._start_workers()
69
+ else:
70
+ if isinstance(self, ParallelEnv):
71
+ _check_for_faulty_process(self._workers)
72
+ return fun(self, *args, **kwargs)
73
+
74
+ return decorated_fun
75
+
76
+
77
+ class _dispatch_caller_parallel:
78
+ def __init__(self, attr, parallel_env):
79
+ self.attr = attr
80
+ self.parallel_env = parallel_env
81
+
82
+ def __call__(self, *args, **kwargs):
83
+ # remove self from args
84
+ args = [_arg if _arg is not self.parallel_env else "_self" for _arg in args]
85
+ for channel in self.parallel_env.parent_channels:
86
+ channel.send((self.attr, (args, kwargs)))
87
+
88
+ results = []
89
+ for channel in self.parallel_env.parent_channels:
90
+ msg, result = channel.recv()
91
+ results.append(result)
92
+
93
+ return results
94
+
95
+ def __iter__(self):
96
+ # if the object returned is not a callable
97
+ return iter(self.__call__())
98
+
99
+
100
+ class _dispatch_caller_serial:
101
+ def __init__(self, list_callable: list[Callable, Any]):
102
+ self.list_callable = list_callable
103
+
104
+ def __call__(self, *args, **kwargs):
105
+ return [_callable(*args, **kwargs) for _callable in self.list_callable]
106
+
107
+
108
+ def lazy_property(prop: property):
109
+ """Converts a property in a lazy property, that will call _set_properties when queried the first time."""
110
+ return property(fget=lazy(prop.fget), fset=prop.fset)
111
+
112
+
113
+ def lazy(fun):
114
+ """Converts a fun in a lazy fun, that will call _set_properties when queried the first time."""
115
+
116
+ @wraps(fun)
117
+ def new_fun(self, *args, **kwargs):
118
+ if not self._properties_set:
119
+ self._set_properties()
120
+ return fun(self, *args, **kwargs)
121
+
122
+ return new_fun
123
+
124
+
125
+ def _is_unpicklable_lambda(fn: Callable) -> bool:
126
+ """Check if a callable is a lambda function that needs cloudpickle wrapping.
127
+
128
+ Lambda functions cannot be pickled with standard pickle, so they need to be
129
+ wrapped with EnvCreator (which uses CloudpickleWrapper) for multiprocessing.
130
+ functools.partial objects are picklable, so they don't need wrapping.
131
+ """
132
+ if isinstance(fn, functools.partial):
133
+ return False
134
+ return callable(fn) and getattr(fn, "__name__", None) == "<lambda>"
135
+
136
+
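A minimal illustration of the distinction this helper draws (editorial sketch, not part of the packaged file; it assumes a Gym backend is installed): a lambda is detected here and wrapped in an EnvCreator by the metaclass below, whereas a functools.partial is picklable as-is and needs no wrapping.

    >>> from functools import partial
    >>> from torchrl.envs import GymEnv, ParallelEnv
    >>> env_a = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))   # lambda: auto-wrapped, may emit a warning
    >>> env_b = ParallelEnv(2, partial(GymEnv, "Pendulum-v1"))  # partial: pickled directly, no wrapping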
137
+ class _PEnvMeta(_EnvPostInit):
138
+ def __call__(cls, *args, **kwargs):
139
+ serial_for_single = kwargs.pop("serial_for_single", False)
140
+ if serial_for_single:
141
+ num_workers = kwargs.get("num_workers")
142
+ # Remove start method from kwargs
143
+ kwargs.pop("mp_start_method", None)
144
+ if num_workers is None:
145
+ num_workers = args[0]
146
+ if num_workers == 1:
147
+ # We still use a serial to keep the shape unchanged
148
+ return SerialEnv(*args, **kwargs)
149
+
150
+ # Wrap lambda functions with EnvCreator so they can be pickled for
151
+ # multiprocessing with the spawn start method. Lambda functions cannot
152
+ # be serialized with standard pickle, but EnvCreator uses cloudpickle.
153
+ auto_wrap_envs = kwargs.pop("auto_wrap_envs", True)
154
+
155
+ def _warn_lambda():
156
+ if rl_warnings():
157
+ warnings.warn(
158
+ "A lambda function was passed to ParallelEnv and will be wrapped "
159
+ "in an EnvCreator. This causes the environment to be instantiated "
160
+ "in the main process to extract metadata. Consider using "
161
+ "functools.partial instead, which is natively serializable and "
162
+ "avoids this overhead. To suppress this warning, set the "
163
+ "RL_WARNINGS=0 environment variable.",
164
+ category=UserWarning,
165
+ stacklevel=4,
166
+ )
167
+
168
+ def _wrap_lambdas(create_env_fn):
169
+ if callable(create_env_fn) and _is_unpicklable_lambda(create_env_fn):
170
+ _warn_lambda()
171
+ return EnvCreator(create_env_fn)
172
+ if isinstance(create_env_fn, Sequence):
173
+ # Reuse EnvCreator for identical function objects to preserve
174
+ # _single_task detection (e.g., when [lambda_fn] * 3 is passed)
175
+ wrapped = {}
176
+ result = []
177
+ warned = False
178
+ for fn in create_env_fn:
179
+ if _is_unpicklable_lambda(fn):
180
+ fn_id = id(fn)
181
+ if fn_id not in wrapped:
182
+ if not warned:
183
+ _warn_lambda()
184
+ warned = True
185
+ wrapped[fn_id] = EnvCreator(fn)
186
+ result.append(wrapped[fn_id])
187
+ else:
188
+ result.append(fn)
189
+ return result
190
+ return create_env_fn
191
+
192
+ if auto_wrap_envs:
193
+ if "create_env_fn" in kwargs:
194
+ kwargs["create_env_fn"] = _wrap_lambdas(kwargs["create_env_fn"])
195
+ elif len(args) >= 2:
196
+ args = (args[0], _wrap_lambdas(args[1])) + args[2:]
197
+
198
+ return super().__call__(*args, **kwargs)
199
+
200
+
201
+ class BatchedEnvBase(EnvBase):
202
+ """Batched environments allow the user to query an arbitrary method / attribute of the environment running remotely.
203
+
204
+ Those queries will return a list of length equal to the number of workers containing the
205
+ values resulting from those queries.
206
+
207
+ Example:
208
+ >>> env = ParallelEnv(3, my_env_fun)
209
+ >>> custom_attribute_list = env.custom_attribute
210
+ >>> custom_method_list = env.custom_method(*args)
211
+
212
+ Args:
213
+ num_workers: number of workers (i.e. env instances) to be deployed simultaneously;
214
+ create_env_fn (callable or list of callables): function (or list of functions) to be used for the environment
215
+ creation.
216
+ If a single task is used, a callable should be used and not a list of identical callables:
217
+ if a list of callables is provided, the environment will be executed as if multiple, diverse tasks were
218
+ needed, which comes with a slight compute overhead;
219
+
220
+ Keyword Args:
221
+ create_env_kwargs (dict or list of dicts, optional): kwargs to be used with the environments being created;
222
+ share_individual_td (bool, optional): if ``True``, a different tensordict is created for every process/worker and a lazy
223
+ stack is returned.
224
+ Defaults to ``None`` (``False`` if single task);
225
+ shared_memory (bool): whether the returned tensordict will be placed in shared memory;
226
+ memmap (bool): whether the returned tensordict will be placed in memory map.
227
+ policy_proof (callable, optional): if provided, it'll be used to get the list of
228
+ tensors to return through the :obj:`step()` and :obj:`reset()` methods, such as :obj:`"hidden"` etc.
229
+ device (str, int, torch.device): The device of the batched environment can be passed.
230
+ If not provided, it is inferred from the sub-environments. In this case, it is assumed that
231
+ the devices of all environments match. If it is provided, it can differ
232
+ from the sub-environment device(s). In that case, the data will be
233
+ automatically cast to the appropriate device during collection.
234
+ This can be used to speed up collection in case casting to device
235
+ introduces an overhead (e.g., numpy-based environments, etc.): by using
236
+ a ``"cuda"`` device for the batched environment but a ``"cpu"``
237
+ device for the nested environments, one can keep the overhead to a
238
+ minimum.
239
+ num_threads (int, optional): number of threads for this process.
240
+ Should be equal to one plus the number of processes launched within
241
+ each subprocess (or one if a single process is launched).
242
+ Defaults to the number of workers + 1.
243
+ This parameter has no effect for the :class:`~SerialEnv` class.
244
+ num_sub_threads (int, optional): number of threads of the subprocesses.
245
+ Defaults to 1 for safety: if none is indicated, launching multiple
246
+ workers may overload the CPU and harm performance.
247
+ This parameter has no effect for the :class:`~SerialEnv` class.
248
+ serial_for_single (bool, optional): if ``True``, creating a parallel environment
249
+ with a single worker will return a :class:`~SerialEnv` instead.
250
+ This option has no effect with :class:`~SerialEnv`. Defaults to ``False``.
251
+ non_blocking (bool, optional): if ``True``, device moves will be done using the
252
+ ``non_blocking=True`` option. Defaults to ``True``.
253
+ mp_start_method (str, optional): the multiprocessing start method.
254
+ Uses the default start method if not indicated ('spawn' by default in
255
+ TorchRL unless set differently before the first import).
256
+ To be used only with :class:`~torchrl.envs.ParallelEnv` subclasses.
257
+ use_buffers (bool, optional): whether communication between workers should
258
+ occur via circular preallocated memory buffers. Defaults to ``True`` unless
259
+ one of the environments has dynamic specs.
260
+
261
+ .. note:: Learn more about dynamic specs and environments :ref:`here <dynamic_envs>`.
262
+ daemon (bool, optional): whether the processes should be daemonized.
263
+ This is only applicable to parallel environments such as :class:`~torchrl.envs.ParallelEnv`.
264
+ Defaults to ``False``.
265
+ auto_wrap_envs (bool, optional): if ``True`` (default), lambda functions passed as
266
+ ``create_env_fn`` will be automatically wrapped in an :class:`~torchrl.envs.EnvCreator`
267
+ to enable pickling for multiprocessing with the ``spawn`` start method.
268
+ This wrapping causes the environment to be instantiated once in the main process
269
+ (to extract metadata) before workers are started.
270
+ If this is undesirable, set ``auto_wrap_envs=False``. Otherwise, ensure your callable is
271
+ serializable (e.g., use :func:`functools.partial` instead of lambdas).
272
+ This parameter only affects :class:`~torchrl.envs.ParallelEnv`.
273
+ Defaults to ``True``.
274
+
275
+ .. note::
276
+ For :class:`~torchrl.envs.ParallelEnv`, it is recommended to use :func:`functools.partial`
277
+ instead of lambda functions when possible, as ``partial`` objects are natively serializable
278
+ and avoid the overhead of :class:`~torchrl.envs.EnvCreator` wrapping.
279
+
280
+ .. note::
281
+ One can pass keyword arguments to each sub-environments using the following
282
+ technique: every keyword argument in :meth:`reset` will be passed to each
283
+ environment except for the ``list_of_kwargs`` argument which, if present,
284
+ should contain a list of the same length as the number of workers with the
285
+ worker-specific keyword arguments stored in a dictionary.
286
+ If a partial reset is queried, the element of ``list_of_kwargs`` corresponding
287
+ to sub-environments that are not reset will be ignored.
288
+
289
+ Examples:
290
+ >>> from torchrl.envs import GymEnv, ParallelEnv, SerialEnv, EnvCreator
291
+ >>> make_env = EnvCreator(lambda: GymEnv("Pendulum-v1")) # EnvCreator ensures that the env is sharable. Optional in most cases.
292
+ >>> env = SerialEnv(2, make_env) # Makes 2 identical copies of the Pendulum env, runs them on the same process serially
293
+ >>> env = ParallelEnv(2, make_env) # Makes 2 identical copies of the Pendulum env, runs them on dedicated processes
294
+ >>> from torchrl.envs import DMControlEnv
295
+ >>> env = ParallelEnv(2, [
296
+ ... lambda: DMControlEnv("humanoid", "stand"),
297
+ ... lambda: DMControlEnv("humanoid", "walk")]) # Creates two independent copies of Humanoid, one that stands, one that walks
298
+ >>> rollout = env.rollout(10) # executes 10 random steps in the environment
299
+ >>> rollout[0] # data for Humanoid stand
300
+ TensorDict(
301
+ fields={
302
+ action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
303
+ com_velocity: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
304
+ done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
305
+ extremities: Tensor(shape=torch.Size([10, 12]), device=cpu, dtype=torch.float64, is_shared=False),
306
+ head_height: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
307
+ joint_angles: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
308
+ next: TensorDict(
309
+ fields={
310
+ com_velocity: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
311
+ done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
312
+ extremities: Tensor(shape=torch.Size([10, 12]), device=cpu, dtype=torch.float64, is_shared=False),
313
+ head_height: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
314
+ joint_angles: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
315
+ reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
316
+ terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
317
+ torso_vertical: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
318
+ truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
319
+ velocity: Tensor(shape=torch.Size([10, 27]), device=cpu, dtype=torch.float64, is_shared=False)},
320
+ batch_size=torch.Size([10]),
321
+ device=cpu,
322
+ is_shared=False),
323
+ terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
324
+ torso_vertical: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
325
+ truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
326
+ velocity: Tensor(shape=torch.Size([10, 27]), device=cpu, dtype=torch.float64, is_shared=False)},
327
+ batch_size=torch.Size([10]),
328
+ device=cpu,
329
+ is_shared=False)
330
+ >>> rollout[1] # data for Humanoid walk
331
+ TensorDict(
332
+ fields={
333
+ action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
334
+ com_velocity: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
335
+ done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
336
+ extremities: Tensor(shape=torch.Size([10, 12]), device=cpu, dtype=torch.float64, is_shared=False),
337
+ head_height: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
338
+ joint_angles: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
339
+ next: TensorDict(
340
+ fields={
341
+ com_velocity: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
342
+ done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
343
+ extremities: Tensor(shape=torch.Size([10, 12]), device=cpu, dtype=torch.float64, is_shared=False),
344
+ head_height: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
345
+ joint_angles: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
346
+ reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float64, is_shared=False),
347
+ terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
348
+ torso_vertical: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
349
+ truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
350
+ velocity: Tensor(shape=torch.Size([10, 27]), device=cpu, dtype=torch.float64, is_shared=False)},
351
+ batch_size=torch.Size([10]),
352
+ device=cpu,
353
+ is_shared=False),
354
+ terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
355
+ torso_vertical: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float64, is_shared=False),
356
+ truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
357
+ velocity: Tensor(shape=torch.Size([10, 27]), device=cpu, dtype=torch.float64, is_shared=False)},
358
+ batch_size=torch.Size([10]),
359
+ device=cpu,
360
+ is_shared=False)
361
+ >>> # serial_for_single to avoid creating parallel envs if not necessary
362
+ >>> env = ParallelEnv(1, make_env, serial_for_single=True)
363
+ >>> assert isinstance(env, SerialEnv) # serial_for_single allows you to avoid creating parallel envs when not necessary
364
+ """
365
+
366
+ _verbose: bool = VERBOSE
367
+ _excluded_wrapped_keys = [
368
+ "is_closed",
369
+ "parent_channels",
370
+ "batch_size",
371
+ "_dummy_env_str",
372
+ ]
373
+
374
+ def __init__(
375
+ self,
376
+ num_workers: int,
377
+ create_env_fn: Callable[[], EnvBase] | Sequence[Callable[[], EnvBase]],
378
+ *,
379
+ create_env_kwargs: dict | Sequence[dict] = None,
380
+ pin_memory: bool = False,
381
+ share_individual_td: bool | None = None,
382
+ shared_memory: bool = True,
383
+ memmap: bool = False,
384
+ policy_proof: Callable | None = None,
385
+ device: DEVICE_TYPING | None = None,
386
+ allow_step_when_done: bool = False,
387
+ num_threads: int | None = None,
388
+ num_sub_threads: int = 1,
389
+ serial_for_single: bool = False,
390
+ non_blocking: bool = False,
391
+ mp_start_method: str | None = None,
392
+ use_buffers: bool | None = None,
393
+ consolidate: bool = True,
394
+ daemon: bool = False,
395
+ ):
396
+ super().__init__(device=device)
397
+ self.serial_for_single = serial_for_single
398
+ self.is_closed = True
399
+ self.num_sub_threads = num_sub_threads
400
+ self.num_threads = num_threads
401
+ self._cache_in_keys = None
402
+ self._use_buffers = use_buffers
403
+ self.consolidate = consolidate
404
+ self.daemon = daemon
405
+
406
+ self._single_task = callable(create_env_fn) or (len(set(create_env_fn)) == 1)
407
+ if callable(create_env_fn):
408
+ create_env_fn = [create_env_fn for _ in range(num_workers)]
409
+ elif len(create_env_fn) != num_workers:
410
+ raise RuntimeError(
411
+ f"len(create_env_fn) and num_workers mismatch, "
412
+ f"got {len(create_env_fn)} and {num_workers}."
413
+ )
414
+
415
+ create_env_kwargs = {} if create_env_kwargs is None else create_env_kwargs
416
+ if isinstance(create_env_kwargs, Mapping):
417
+ create_env_kwargs = [
418
+ deepcopy(create_env_kwargs) for _ in range(num_workers)
419
+ ]
420
+ elif len(create_env_kwargs) != num_workers:
421
+ raise RuntimeError(
422
+ f"len(create_env_kwargs) and num_workers mismatch, "
423
+ f"got {len(create_env_kwargs)} and {num_workers}."
424
+ )
425
+
426
+ self.policy_proof = policy_proof
427
+ self.num_workers = num_workers
428
+ self.create_env_fn = create_env_fn
429
+ self.create_env_kwargs = create_env_kwargs
430
+ self.pin_memory = pin_memory
431
+ if pin_memory:
432
+ raise ValueError("pin_memory for batched envs is deprecated")
433
+
434
+ # if share_individual_td is None, we will assess later if the output can be stacked
435
+ self.share_individual_td = share_individual_td
436
+ # self._batch_locked = batch_locked
437
+ self._share_memory = shared_memory
438
+ self._memmap = memmap
439
+ self.allow_step_when_done = allow_step_when_done
440
+ if allow_step_when_done:
441
+ raise ValueError("allow_step_when_done is deprecated")
442
+ if self._share_memory and self._memmap:
443
+ raise RuntimeError(
444
+ "memmap and shared memory are mutually exclusive features."
445
+ )
446
+ self._batch_size = None
447
+ self._device = (
448
+ _make_ordinal_device(torch.device(device)) if device is not None else device
449
+ )
450
+ self._dummy_env_str = None
451
+ self._seeds = None
452
+ self.__dict__["_input_spec"] = None
453
+ self.__dict__["_output_spec"] = None
454
+ # self._prepare_dummy_env(create_env_fn, create_env_kwargs)
455
+ self._properties_set = False
456
+ self._get_metadata(create_env_fn, create_env_kwargs)
457
+ self._non_blocking = non_blocking
458
+ if mp_start_method is not None and not isinstance(self, ParallelEnv):
459
+ raise TypeError(
460
+ f"Cannot use mp_start_method={mp_start_method} with envs of type {type(self)}."
461
+ )
462
+ self._mp_start_method = mp_start_method
463
+
464
+ is_spec_locked = EnvBase.is_spec_locked
465
+
466
+ def configure_parallel(
467
+ self,
468
+ *,
469
+ use_buffers: bool | None = None,
470
+ shared_memory: bool | None = None,
471
+ memmap: bool | None = None,
472
+ mp_start_method: str | None = None,
473
+ num_threads: int | None = None,
474
+ num_sub_threads: int | None = None,
475
+ non_blocking: bool | None = None,
476
+ daemon: bool | None = None,
477
+ ) -> BatchedEnvBase:
478
+ """Configure parallel execution parameters before the environment starts.
479
+
480
+ This method allows configuring parameters for parallel environment
481
+ execution. It must be called before the environment is started
482
+ (i.e., before accessing specs or calling reset/step).
483
+
484
+ Args:
485
+ use_buffers (bool, optional): whether communication between workers should
486
+ occur via circular preallocated memory buffers.
487
+ shared_memory (bool, optional): whether the returned tensordict will be
488
+ placed in shared memory.
489
+ memmap (bool, optional): whether the returned tensordict will be placed
490
+ in memory map.
491
+ mp_start_method (str, optional): the multiprocessing start method.
492
+ num_threads (int, optional): number of threads for this process.
493
+ num_sub_threads (int, optional): number of threads of the subprocesses.
494
+ non_blocking (bool, optional): if ``True``, device moves will be done using
495
+ the ``non_blocking=True`` option.
496
+ daemon (bool, optional): whether the processes should be daemonized.
497
+
498
+ Returns:
499
+ self: Returns self for method chaining.
500
+
501
+ Raises:
502
+ RuntimeError: If called after the environment has already started.
503
+
504
+ Example:
505
+ >>> env = ParallelEnv(4, lambda: GymEnv("Pendulum-v1"))
506
+ >>> env.configure_parallel(use_buffers=True, num_threads=2)
507
+ >>> env.reset() # Environment starts here
508
+
509
+ """
510
+ if not self.is_closed:
511
+ raise RuntimeError(
512
+ "configure_parallel() cannot be called after the environment has started. "
513
+ "Call configure_parallel() before accessing specs or calling reset/step."
514
+ )
515
+ if use_buffers is not None:
516
+ self._use_buffers = use_buffers
517
+ if shared_memory is not None:
518
+ self._share_memory = shared_memory
519
+ if memmap is not None:
520
+ self._memmap = memmap
521
+ if mp_start_method is not None:
522
+ self._mp_start_method = mp_start_method
523
+ if num_threads is not None:
524
+ self.num_threads = num_threads
525
+ if num_sub_threads is not None:
526
+ self.num_sub_threads = num_sub_threads
527
+ if non_blocking is not None:
528
+ self._non_blocking = non_blocking
529
+ if daemon is not None:
530
+ self.daemon = daemon
531
+ return self
532
+
533
+ def select_and_clone(self, name, tensor, selected_keys=None):
534
+ if selected_keys is None:
535
+ selected_keys = self._selected_step_keys
536
+ if name in selected_keys:
537
+ if self.device is not None and tensor.device != self.device:
538
+ return tensor.to(self.device, non_blocking=self.non_blocking)
539
+ return tensor.clone()
540
+
541
+ @property
542
+ def non_blocking(self):
543
+ nb = self._non_blocking
544
+ if nb is None:
545
+ nb = True
546
+ self._non_blocking = nb
547
+ return nb
548
+
549
+ @property
550
+ def _sync_m2w(self) -> Callable:
551
+ sync_func = self.__dict__.get("_sync_m2w_value")
552
+ if sync_func is None:
553
+ sync_m2w, sync_w2m = self._find_sync_values()
554
+ self.__dict__["_sync_m2w_value"] = sync_m2w
555
+ self.__dict__["_sync_w2m_value"] = sync_w2m
556
+ return sync_m2w
557
+ return sync_func
558
+
559
+ @property
560
+ def _sync_w2m(self) -> Callable:
561
+ sync_func = self.__dict__.get("_sync_w2m_value")
562
+ if sync_func is None:
563
+ sync_m2w, sync_w2m = self._find_sync_values()
564
+ self.__dict__["_sync_m2w_value"] = sync_m2w
565
+ self.__dict__["_sync_w2m_value"] = sync_w2m
566
+ return sync_w2m
567
+ return sync_func
568
+
569
+ def _find_sync_values(self):
570
+ """Returns the m2w and w2m sync values, in that order."""
571
+ if not self._use_buffers:
572
+ return _do_nothing, _do_nothing
573
+ # Simplest case: everything is on the same device
574
+ worker_device = self.shared_tensordict_parent.device
575
+ self_device = self.device
576
+ if not self.non_blocking or (
577
+ worker_device == self_device or self_device is None
578
+ ):
579
+ # even if they're both None, there is no device-to-device movement
580
+ return _do_nothing, _do_nothing
581
+
582
+ if worker_device is None:
583
+ worker_not_main = False
584
+
585
+ def find_all_worker_devices(item):
586
+ nonlocal worker_not_main
587
+ if hasattr(item, "device"):
588
+ worker_not_main = worker_not_main or (item.device != self_device)
589
+
590
+ for td in self.shared_tensordicts:
591
+ td.apply(find_all_worker_devices, filter_empty=True)
592
+ if worker_not_main:
593
+ if torch.cuda.is_available():
594
+ worker_device = (
595
+ torch.device("cuda")
596
+ if self_device.type != "cuda"
597
+ else torch.device("cpu")
598
+ )
599
+ elif torch.backends.mps.is_available():
600
+ worker_device = (
601
+ torch.device("mps")
602
+ if self_device.type != "mps"
603
+ else torch.device("cpu")
604
+ )
605
+ else:
606
+ raise RuntimeError("Did not find a valid worker device")
607
+ else:
608
+ worker_device = self_device
609
+
610
+ if (
611
+ worker_device is not None
612
+ and worker_device.type == "cuda"
613
+ and self_device is not None
614
+ and self_device.type == "cpu"
615
+ ):
616
+ return _do_nothing, _cuda_sync(worker_device)
617
+ if (
618
+ worker_device is not None
619
+ and worker_device.type == "mps"
620
+ and self_device is not None
621
+ and self_device.type == "cpu"
622
+ ):
623
+ return _mps_sync(worker_device), _mps_sync(worker_device)
624
+ if (
625
+ worker_device is not None
626
+ and worker_device.type == "cpu"
627
+ and self_device is not None
628
+ and self_device.type == "cuda"
629
+ ):
630
+ return _cuda_sync(self_device), _do_nothing
631
+ if (
632
+ worker_device is not None
633
+ and worker_device.type == "cpu"
634
+ and self_device is not None
635
+ and self_device.type == "mps"
636
+ ):
637
+ return _mps_sync(self_device), _mps_sync(self_device)
638
+ return _do_nothing, _do_nothing
639
+
640
+ def __getstate__(self):
641
+ out = self.__dict__.copy()
642
+ out["_sync_m2w_value"] = None
643
+ out["_sync_w2m_value"] = None
644
+ return out
645
+
646
+ @property
647
+ def _has_dynamic_specs(self):
648
+ return not self._use_buffers
649
+
650
+ def _get_metadata(
651
+ self, create_env_fn: list[Callable], create_env_kwargs: list[dict]
652
+ ):
653
+ if self._single_task:
654
+ # if EnvCreator, the metadata are already there
655
+ meta_data: EnvMetaData = get_env_metadata(
656
+ create_env_fn[0], create_env_kwargs[0]
657
+ )
658
+ self.meta_data = meta_data.expand(
659
+ *(self.num_workers, *meta_data.batch_size)
660
+ )
661
+ if self._use_buffers is not False:
662
+ _use_buffers = not self.meta_data.has_dynamic_specs
663
+ if self._use_buffers and not _use_buffers:
664
+ warn(
665
+ "A value of use_buffers=True was passed but this is incompatible "
666
+ "with the list of environments provided. Turning use_buffers to False."
667
+ )
668
+ self._use_buffers = _use_buffers
669
+ if self.share_individual_td is None:
670
+ self.share_individual_td = False
671
+ else:
672
+ n_tasks = len(create_env_fn)
673
+ self.meta_data: list[EnvMetaData] = []
674
+ for i in range(n_tasks):
675
+ self.meta_data.append(
676
+ get_env_metadata(create_env_fn[i], create_env_kwargs[i]).clone()
677
+ )
678
+ if self.share_individual_td is not True:
679
+ share_individual_td = not _stackable(
680
+ *[meta_data.tensordict for meta_data in self.meta_data]
681
+ )
682
+ if share_individual_td and self.share_individual_td is False:
683
+ raise ValueError(
684
+ "share_individual_td=False was provided but share_individual_td must "
685
+ "be True to accommodate non-stackable tensors."
686
+ )
687
+ self.share_individual_td = share_individual_td
688
+ _use_buffers = all(
689
+ not metadata.has_dynamic_specs for metadata in self.meta_data
690
+ )
691
+ if self._use_buffers and not _use_buffers:
692
+ warn(
693
+ "A value of use_buffers=True was passed but this is incompatible "
694
+ "with the list of environments provided. Turning use_buffers to False."
695
+ )
696
+ self._use_buffers = _use_buffers
697
+
698
+ self._set_properties()
699
+
700
+ def update_kwargs(self, kwargs: dict | list[dict]) -> None:
701
+ """Updates the kwargs of each environment given a dictionary or a list of dictionaries.
702
+
703
+ Args:
704
+ kwargs (dict or list of dict): new kwargs to use with the environments
705
+
706
+ """
707
+ if isinstance(kwargs, dict):
708
+ for _kwargs in self.create_env_kwargs:
709
+ _kwargs.update(kwargs)
710
+ else:
711
+ if len(kwargs) != self.num_workers:
712
+ raise RuntimeError(
713
+ f"len(kwargs) and num_workers mismatch, got {len(kwargs)} and {self.num_workers}."
714
+ )
715
+ for _kwargs, _new_kwargs in _zip_strict(self.create_env_kwargs, kwargs):
716
+ _kwargs.update(_new_kwargs)
717
+
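A short sketch of the two call patterns this method accepts (editorial illustration, not part of the packaged file; ``frame_skip`` is just an example kwarg and is assumed to be understood by the sub-environment constructor):

    >>> from functools import partial
    >>> from torchrl.envs import GymEnv, ParallelEnv
    >>> env = ParallelEnv(2, partial(GymEnv, "Pendulum-v1"))
    >>> env.update_kwargs({"frame_skip": 2})                       # same kwargs for every worker
    >>> env.update_kwargs([{"frame_skip": 2}, {"frame_skip": 4}])  # one dict per worker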
718
+ def _get_in_keys_to_exclude(self, tensordict):
719
+ if self._cache_in_keys is None:
720
+ self._cache_in_keys = list(
721
+ set(self.input_spec.keys(True)).intersection(
722
+ tensordict.keys(True, True)
723
+ )
724
+ )
725
+ return self._cache_in_keys
726
+
727
+ def _set_properties(self):
728
+
729
+ cls = type(self)
730
+
731
+ def _check_for_empty_spec(specs: Composite):
732
+ for subspec in (
733
+ "full_state_spec",
734
+ "full_action_spec",
735
+ "full_done_spec",
736
+ "full_reward_spec",
737
+ "full_observation_spec",
738
+ ):
739
+ for key, spec in reversed(
740
+ list(specs.get(subspec, default=Composite()).items(True))
741
+ ):
742
+ if isinstance(spec, Composite) and spec.is_empty():
743
+ raise RuntimeError(
744
+ f"The environment passed to {cls.__name__} has empty specs in {key}. Consider using "
745
+ f"torchrl.envs.transforms.RemoveEmptySpecs to remove the empty specs."
746
+ )
747
+ return specs
748
+
749
+ meta_data = self.meta_data
750
+ self._properties_set = True
751
+ if self._single_task:
752
+ self._batch_size = meta_data.batch_size
753
+ device = meta_data.device
754
+ if self._device is None:
755
+ self._device = device
756
+
757
+ input_spec = _check_for_empty_spec(meta_data.specs["input_spec"].to(device))
758
+ output_spec = _check_for_empty_spec(
759
+ meta_data.specs["output_spec"].to(device)
760
+ )
761
+
762
+ self.action_spec = input_spec["full_action_spec"]
763
+ self.state_spec = input_spec["full_state_spec"]
764
+ self.observation_spec = output_spec["full_observation_spec"]
765
+ self.reward_spec = output_spec["full_reward_spec"]
766
+ self.done_spec = output_spec["full_done_spec"]
767
+
768
+ self._dummy_env_str = meta_data.env_str
769
+ self._env_tensordict = meta_data.tensordict
770
+ if device is None: # In other cases, the device will be mapped later
771
+ self._env_tensordict.clear_device_()
772
+ device_map = meta_data.device_map
773
+
774
+ def map_device(key, value, device_map=device_map):
775
+ return value.to(device_map[key])
776
+
777
+ self._env_tensordict.named_apply(
778
+ map_device, nested_keys=True, filter_empty=True
779
+ )
780
+ # if self._batch_locked is None:
781
+ # self._batch_locked = meta_data.batch_locked
782
+ else:
783
+ self._batch_size = torch.Size([self.num_workers, *meta_data[0].batch_size])
784
+ devices = set()
785
+ for _meta_data in meta_data:
786
+ device = _meta_data.device
787
+ devices.add(device)
788
+ if self._device is None:
789
+ if len(devices) > 1:
790
+ raise ValueError(
791
+ f"The device wasn't passed to {type(self)}, but more than one device was found in the sub-environments. "
792
+ f"Please indicate a device to be used for collection."
793
+ )
794
+ device = list(devices)[0]
795
+ self._device = device
796
+
797
+ input_spec = []
798
+ for md in meta_data:
799
+ input_spec.append(_check_for_empty_spec(md.specs["input_spec"]))
800
+ input_spec = torch.stack(input_spec, 0)
801
+ output_spec = []
802
+ for md in meta_data:
803
+ output_spec.append(_check_for_empty_spec(md.specs["output_spec"]))
804
+ output_spec = torch.stack(output_spec, 0)
805
+
806
+ self.action_spec = input_spec["full_action_spec"]
807
+ self.state_spec = input_spec["full_state_spec"]
808
+
809
+ self.observation_spec = output_spec["full_observation_spec"]
810
+ self.reward_spec = output_spec["full_reward_spec"]
811
+ self.done_spec = output_spec["full_done_spec"]
812
+
813
+ self._dummy_env_str = str(meta_data[0])
814
+ if self.share_individual_td:
815
+ self._env_tensordict = LazyStackedTensorDict.lazy_stack(
816
+ [meta_data.tensordict for meta_data in meta_data], 0
817
+ )
818
+ else:
819
+ self._env_tensordict = torch.stack(
820
+ [meta_data.tensordict for meta_data in meta_data], 0
821
+ )
822
+ # if self._batch_locked is None:
823
+ # self._batch_locked = meta_data[0].batch_locked
824
+ self.has_lazy_inputs = contains_lazy_spec(self.input_spec)
825
+
826
+ def state_dict(self) -> OrderedDict:
827
+ raise NotImplementedError
828
+
829
+ def load_state_dict(self, state_dict: OrderedDict) -> None:
830
+ raise NotImplementedError
831
+
832
+ batch_size = lazy_property(EnvBase.batch_size)
833
+ device = lazy_property(EnvBase.device)
834
+ input_spec = lazy_property(EnvBase.input_spec)
835
+ output_spec = lazy_property(EnvBase.output_spec)
836
+
837
+ def _create_td(self) -> None:
838
+ """Creates self.shared_tensordict_parent, a TensorDict used to store the most recent observations."""
839
+ if not self._use_buffers:
840
+ return
841
+ shared_tensordict_parent = self._env_tensordict.clone()
842
+ if self._env_tensordict.shape[0] != self.num_workers:
843
+ raise RuntimeError(
844
+ "batched environment base tensordict has the wrong shape"
845
+ )
846
+
847
+ # Non-tensor keys
848
+ non_tensor_keys = []
849
+ for spec in (
850
+ self.full_action_spec,
851
+ self.full_state_spec,
852
+ self.full_observation_spec,
853
+ self.full_reward_spec,
854
+ self.full_done_spec,
855
+ ):
856
+ for key, _spec in spec.items(True, True):
857
+ if isinstance(_spec, NonTensor):
858
+ non_tensor_keys.append(key)
859
+ self._non_tensor_keys = non_tensor_keys
860
+
861
+ if self._single_task:
862
+ self._env_input_keys = sorted(
863
+ list(self.input_spec["full_action_spec"].keys(True, True))
864
+ + list(self.state_spec.keys(True, True)),
865
+ key=_sort_keys,
866
+ )
867
+ self._env_output_keys = []
868
+ self._env_obs_keys = []
869
+ for key in self.output_spec["full_observation_spec"].keys(True, True):
870
+ self._env_output_keys.append(key)
871
+ self._env_obs_keys.append(key)
872
+ self._env_output_keys += self.reward_keys + self.done_keys
873
+ else:
874
+ # this is only possible if _single_task=False
875
+ env_input_keys = set()
876
+ for meta_data in self.meta_data:
877
+ if meta_data.specs["input_spec", "full_state_spec"] is not None:
878
+ env_input_keys = env_input_keys.union(
879
+ meta_data.specs["input_spec", "full_state_spec"].keys(
880
+ True, True
881
+ )
882
+ )
883
+ env_input_keys = env_input_keys.union(
884
+ meta_data.specs["input_spec", "full_action_spec"].keys(True, True)
885
+ )
886
+ env_output_keys = set()
887
+ env_obs_keys = set()
888
+ for meta_data in self.meta_data:
889
+ keys = meta_data.specs["output_spec"]["full_observation_spec"].keys(
890
+ True, True
891
+ )
892
+ keys = list(keys)
893
+ env_obs_keys = env_obs_keys.union(keys)
894
+
895
+ env_output_keys = env_output_keys.union(keys)
896
+ env_output_keys = env_output_keys.union(self.reward_keys + self.done_keys)
897
+ self._env_obs_keys = sorted(env_obs_keys, key=_sort_keys)
898
+ self._env_input_keys = sorted(env_input_keys, key=_sort_keys)
899
+ self._env_output_keys = sorted(env_output_keys, key=_sort_keys)
900
+
901
+ self._env_obs_keys = [
902
+ key for key in self._env_obs_keys if key not in self._non_tensor_keys
903
+ ]
904
+ self._env_input_keys = [
905
+ key for key in self._env_input_keys if key not in self._non_tensor_keys
906
+ ]
907
+ self._env_output_keys = [
908
+ key for key in self._env_output_keys if key not in self._non_tensor_keys
909
+ ]
910
+
911
+ reset_keys = self.reset_keys
912
+ self._selected_keys = (
913
+ set(self._env_output_keys)
914
+ .union(self._env_input_keys)
915
+ .union(self._env_obs_keys)
916
+ .union(set(self.done_keys))
917
+ )
918
+ self._selected_keys = self._selected_keys.union(reset_keys)
919
+
920
+ # input keys
921
+ self._selected_input_keys = {unravel_key(key) for key in self._env_input_keys}
922
+ # output keys after reset
923
+ self._selected_reset_keys = {
924
+ unravel_key(key) for key in self._env_obs_keys + self.done_keys + reset_keys
925
+ }
926
+ # output keys after reset, filtered
927
+ self._selected_reset_keys_filt = {
928
+ unravel_key(key) for key in self._env_obs_keys + self.done_keys
929
+ }
930
+ # output keys after step
931
+ self._selected_step_keys = {unravel_key(key) for key in self._env_output_keys}
932
+
933
+ if not self.share_individual_td:
934
+ shared_tensordict_parent = shared_tensordict_parent.filter_non_tensor_data()
935
+ shared_tensordict_parent = shared_tensordict_parent.select(
936
+ *self._selected_keys,
937
+ *(unravel_key(("next", key)) for key in self._env_output_keys),
938
+ strict=False,
939
+ )
940
+ self.shared_tensordict_parent = shared_tensordict_parent
941
+ else:
942
+ # Multi-task: we share tensordicts that *may* have different keys
943
+ shared_tensordict_parent = [
944
+ tensordict.select(
945
+ *self._selected_keys,
946
+ *(unravel_key(("next", key)) for key in self._env_output_keys),
947
+ strict=False,
948
+ ).filter_non_tensor_data()
949
+ for tensordict in shared_tensordict_parent
950
+ ]
951
+ shared_tensordict_parent = LazyStackedTensorDict.lazy_stack(
952
+ shared_tensordict_parent,
953
+ 0,
954
+ )
955
+ self.shared_tensordict_parent = shared_tensordict_parent
956
+
957
+ if self.share_individual_td:
958
+ if not isinstance(self.shared_tensordict_parent, LazyStackedTensorDict):
959
+ self.shared_tensordicts = [
960
+ td.clone() for td in self.shared_tensordict_parent.unbind(0)
961
+ ]
962
+ self.shared_tensordict_parent = LazyStackedTensorDict.lazy_stack(
963
+ self.shared_tensordicts, 0
964
+ )
965
+ else:
966
+ # Multi-task: we share tensordicts that *may* have different keys
967
+ # LazyStacked already stores this so we don't need to do anything
968
+ self.shared_tensordicts = self.shared_tensordict_parent
969
+ if self._share_memory:
970
+ self.shared_tensordict_parent.share_memory_()
971
+ elif self._memmap:
972
+ self.shared_tensordict_parent.memmap_()
973
+ else:
974
+ if self._share_memory:
975
+ self.shared_tensordict_parent.share_memory_()
976
+ if not self.shared_tensordict_parent.is_shared():
977
+ raise RuntimeError("share_memory_() failed")
978
+ elif self._memmap:
979
+ self.shared_tensordict_parent.memmap_()
980
+ if not self.shared_tensordict_parent.is_memmap():
981
+ raise RuntimeError("memmap_() failed")
982
+ self.shared_tensordicts = self.shared_tensordict_parent.unbind(0)
983
+ for td in self.shared_tensordicts:
984
+ td.lock_()
985
+
986
+ # we cache all the keys of the shared parent td for future use. This is
987
+ # safe since the td is locked.
988
+ self._cache_shared_keys = set(self.shared_tensordict_parent.keys(True, True))
989
+
990
+ self._shared_tensordict_parent_next = self.shared_tensordict_parent.get("next")
991
+ self._shared_tensordict_parent_root = self.shared_tensordict_parent.exclude(
992
+ "next", *self.reset_keys
993
+ )
994
+
995
+ def _start_workers(self) -> None:
996
+ """Starts the various envs."""
997
+ raise NotImplementedError
998
+
999
+ def __repr__(self) -> str:
1000
+ if self._dummy_env_str is None:
1001
+ self._dummy_env_str = self._set_properties()
1002
+ return (
1003
+ f"{self.__class__.__name__}("
1004
+ f"\n\tenv={self._dummy_env_str}, "
1005
+ f"\n\tbatch_size={self.batch_size})"
1006
+ )
1007
+
1008
+ def close(self, *, raise_if_closed: bool = True) -> None:
1009
+ if self.is_closed:
1010
+ if raise_if_closed:
1011
+ raise RuntimeError("trying to close a closed environment")
1012
+ else:
1013
+ return
1014
+ if self._verbose:
1015
+ torchrl_logger.info(f"closing {self.__class__.__name__}")
1016
+
1017
+ self.__dict__["_input_spec"] = None
1018
+ self.__dict__["_output_spec"] = None
1019
+ self._properties_set = False
1020
+
1021
+ self._shutdown_workers()
1022
+ self.is_closed = True
1023
+ import torchrl
1024
+
1025
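+ # Give back to this process the threads that were handed over to the workers at
+ # start-up, without exceeding the initial thread-pool size.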
+ num_threads = min(
1026
+ torchrl._THREAD_POOL_INIT, torch.get_num_threads() + self.num_workers
1027
+ )
1028
+ torch.set_num_threads(num_threads)
1029
+
1030
+ def _shutdown_workers(self) -> None:
1031
+ raise NotImplementedError
1032
+
1033
+ def _set_seed(self, seed: int | None) -> None:
1034
+ """This method is not used in batched envs."""
1035
+
1036
+ @lazy
1037
+ def start(self) -> None:
1038
+ if not self.is_closed:
1039
+ raise RuntimeError("trying to start a environment that is not closed.")
1040
+ self._create_td()
1041
+ self._start_workers()
1042
+
1043
+ def to(self, device: DEVICE_TYPING):
1044
+ self._non_blocking = None
1045
+ device = _make_ordinal_device(torch.device(device))
1046
+ if device == self.device:
1047
+ return self
1048
+ self._device = device
1049
+ self.__dict__["_sync_m2w_value"] = None
1050
+ self.__dict__["_sync_w2m_value"] = None
1051
+ if self.__dict__["_input_spec"] is not None:
1052
+ self.__dict__["_input_spec"] = self.__dict__["_input_spec"].to(device)
1053
+ if self.__dict__["_output_spec"] is not None:
1054
+ self.__dict__["_output_spec"] = self.__dict__["_output_spec"].to(device)
1055
+ return self
1056
+
1057
+ def _reset_proc_data(self, tensordict, tensordict_reset):
1058
+ # since we call `reset` directly, all the postproc has been completed
1059
+ if tensordict is not None:
1060
+ if isinstance(tensordict_reset, LazyStackedTensorDict) and not isinstance(
1061
+ tensordict, LazyStackedTensorDict
1062
+ ):
1063
+ tensordict = LazyStackedTensorDict(*tensordict.unbind(0))
1064
+ return _update_during_reset(tensordict_reset, tensordict, self.reset_keys)
1065
+ return tensordict_reset
1066
+
1067
+ def add_truncated_keys(self):
1068
+ raise RuntimeError(
1069
+ "Cannot add truncated keys to a batched environment. Please add these entries to "
1070
+ "the nested environments by calling sub_env.add_truncated_keys()"
1071
+ )
1072
+
1073
+
1074
+ class SerialEnv(BatchedEnvBase):
1075
+ """Creates a series of environments in the same process."""
1076
+
1077
+ __doc__ += BatchedEnvBase.__doc__
1078
+
1079
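+ # Minimal usage sketch (not part of the original source; assumes a Gym backend is
+ # installed so that ``GymEnv("Pendulum-v1")`` can be constructed):
+ #
+ #     >>> from torchrl.envs import GymEnv, SerialEnv
+ #     >>> env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))
+ #     >>> td = env.reset()          # resets both sub-envs in the current process
+ #     >>> rollout = env.rollout(3)  # batch_size is (2, 3): (num_workers, steps)
+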
+ _share_memory = False
1080
+
1081
+ def _start_workers(self) -> None:
1082
+ _num_workers = self.num_workers
1083
+
1084
+ self._envs = []
1085
+ weakref_set = set()
1086
+ for idx in range(_num_workers):
1087
+ env = self.create_env_fn[idx](**self.create_env_kwargs[idx])
1088
+ # We want to avoid having the same env multiple times
1089
+ # so we try to deepcopy it if needed. If we can't, we make
1090
+ # the user aware that this isn't a very good idea
1091
+ wr = weakref.ref(env)
1092
+ if wr in weakref_set:
1093
+ try:
1094
+ env = deepcopy(env)
1095
+ except Exception:
1096
+ warn(
1097
+ "Deepcopying the env failed within SerialEnv "
1098
+ "but more than one copy of the same env was found. "
1099
+ "This is a dangerous situation if your env keeps track "
1100
+ "of some variables (e.g., state) in-place. "
1101
+ "We'll use the same copy of the environment be beaware that "
1102
+ "this may have important, unwanted issues for stateful "
1103
+ "environments!"
1104
+ )
1105
+ weakref_set.add(wr)
1106
+ self._envs.append(env.set_spec_lock_())
1107
+ self.is_closed = False
1108
+ self.set_spec_lock_()
1109
+
1110
+ @_check_start
1111
+ def state_dict(self) -> OrderedDict:
1112
+ state_dict = OrderedDict()
1113
+ for idx, env in enumerate(self._envs):
1114
+ state_dict[f"worker{idx}"] = env.state_dict()
1115
+
1116
+ return state_dict
1117
+
1118
+ @_check_start
1119
+ def load_state_dict(self, state_dict: OrderedDict) -> None:
1120
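+ # A flat state_dict (without per-worker "worker{idx}" entries) is broadcast: the
+ # same state is replicated and loaded into every sub-environment.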
+ if "worker0" not in state_dict:
1121
+ state_dict = OrderedDict(
1122
+ **{f"worker{idx}": state_dict for idx in range(self.num_workers)}
1123
+ )
1124
+ for idx, env in enumerate(self._envs):
1125
+ env.load_state_dict(state_dict[f"worker{idx}"])
1126
+
1127
+ def _shutdown_workers(self) -> None:
1128
+ if not self.is_closed:
1129
+ for env in self._envs:
1130
+ env.close()
1131
+ del self._envs
1132
+
1133
+ @_check_start
1134
+ def set_seed(
1135
+ self, seed: int | None = None, static_seed: bool = False
1136
+ ) -> int | None:
1137
+ for env in self._envs:
1138
+ new_seed = env.set_seed(seed, static_seed=static_seed)
1139
+ seed = new_seed
1140
+ return seed
1141
+
1142
+ @_check_start
1143
+ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
1144
+ list_of_kwargs = kwargs.pop("list_of_kwargs", [kwargs] * self.num_workers)
1145
+ if kwargs is not list_of_kwargs[0] and kwargs:
1146
+ # a list of kwargs was provided together with extra kwargs: merge the extras into each element
1147
+ for elt in list_of_kwargs:
1148
+ elt.update(kwargs)
1149
+ if tensordict is not None:
1150
+ if "_reset" in tensordict.keys():
1151
+ needs_resetting = tensordict["_reset"]
1152
+ else:
1153
+ needs_resetting = _aggregate_end_of_traj(
1154
+ tensordict, reset_keys=self.reset_keys
1155
+ )
1156
+ if needs_resetting.ndim > 2:
1157
+ needs_resetting = needs_resetting.flatten(1, needs_resetting.ndim - 1)
1158
+ if needs_resetting.ndim > 1:
1159
+ needs_resetting = needs_resetting.any(-1)
1160
+ elif not needs_resetting.ndim:
1161
+ needs_resetting = needs_resetting.expand((self.num_workers,))
1162
+ tensordict = tensordict.unbind(0)
1163
+ else:
1164
+ needs_resetting = torch.ones(
1165
+ (self.num_workers,), device=self.device, dtype=torch.bool
1166
+ )
1167
+
1168
+ out_tds = None
1169
+ if not self._use_buffers or self._non_tensor_keys:
1170
+ out_tds = [None] * self.num_workers
1171
+
1172
+ tds = []
1173
+ for i, _env in enumerate(self._envs):
1174
+ if not needs_resetting[i]:
1175
+ if out_tds is not None and tensordict is not None:
1176
+ ftd = _env.observation_spec.zero()
1177
+ if self.device is None:
1178
+ ftd.clear_device_()
1179
+ else:
1180
+ ftd = ftd.to(self.device)
1181
+ out_tds[i] = ftd
1182
+ continue
1183
+ if tensordict is not None:
1184
+ tensordict_ = tensordict[i]
1185
+ if tensordict_.is_empty():
1186
+ tensordict_ = None
1187
+ else:
1188
+ env_device = _env.device
1189
+ if env_device != self.device:
1190
+ if env_device is not None:
1191
+ tensordict_ = tensordict_.to(
1192
+ env_device, non_blocking=self.non_blocking
1193
+ )
1194
+ else:
1195
+ tensordict_ = tensordict_.clear_device_()
1196
+ else:
1197
+ tensordict_ = tensordict_.clone(False)
1198
+ else:
1199
+ tensordict_ = None
1200
+ tds.append((i, tensordict_))
1201
+
1202
+ self._sync_m2w()
1203
+ for i, tensordict_ in tds:
1204
+ _env = self._envs[i]
1205
+ _td = _env.reset(tensordict=tensordict_, **list_of_kwargs[i])
1206
+ if self._use_buffers:
1207
+ try:
1208
+ self.shared_tensordicts[i].update_(
1209
+ _td,
1210
+ keys_to_update=list(self._selected_reset_keys_filt),
1211
+ non_blocking=self.non_blocking,
1212
+ )
1213
+ except RuntimeError as err:
1214
+ if "no_grad mode" in str(err):
1215
+ raise RuntimeError(
1216
+ "Cannot update a view of a tensordict when gradients are required. "
1217
+ "To collect gradient across sub-environments, please set the "
1218
+ "share_individual_td argument to True."
1219
+ )
1220
+ raise
1221
+ if out_tds is not None:
1222
+ out_tds[i] = _td
1223
+
1224
+ device = self.device
1225
+ if not self._use_buffers:
1226
+ result = LazyStackedTensorDict.maybe_dense_stack(out_tds)
1227
+ if result.device != device:
1228
+ if device is None:
1229
+ result = result.clear_device_()
1230
+ else:
1231
+ result = result.to(device, non_blocking=self.non_blocking)
1232
+ self._sync_w2m()
1233
+ return result
1234
+
1235
+ selected_output_keys = self._selected_reset_keys_filt
1236
+
1237
+ # select + clone would create two tensordicts; we create a single one instead
1238
+ out = self.shared_tensordict_parent.named_apply(
1239
+ lambda *args: self.select_and_clone(
1240
+ *args, selected_keys=selected_output_keys
1241
+ ),
1242
+ nested_keys=True,
1243
+ filter_empty=True,
1244
+ )
1245
+ if out_tds is not None:
1246
+ out.update(
1247
+ LazyStackedTensorDict(*out_tds), keys_to_update=self._non_tensor_keys
1248
+ )
1249
+
1250
+ if out.device != device:
1251
+ if device is None:
1252
+ out = out.clear_device_()
1253
+ else:
1254
+ out = out.to(device, non_blocking=self.non_blocking)
1255
+ self._sync_w2m()
1256
+ return out
1257
+
1258
+ @_check_start
1259
+ def _step(
1260
+ self,
1261
+ tensordict: TensorDict,
1262
+ ) -> TensorDict:
1263
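+ # "_step" is an optional boolean mask with one entry per sub-environment; when
+ # present, only the selected workers are stepped and the remaining entries are
+ # filled back in from the previous data below.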
+ partial_steps = tensordict.get("_step")
1264
+ tensordict_save = tensordict
1265
+ if partial_steps is not None and partial_steps.all():
1266
+ partial_steps = None
1267
+ if partial_steps is not None:
1268
+ partial_steps = partial_steps.view(tensordict.shape)
1269
+ tensordict = tensordict[partial_steps]
1270
+ workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
1271
+ tensordict_in = tensordict
1272
+ else:
1273
+ workers_range = range(self.num_workers)
1274
+ tensordict_in = tensordict.copy()
1275
+ # if self._use_buffers:
1276
+ # shared_tensordict_parent = self.shared_tensordict_parent
1277
+
1278
+ data_in = []
1279
+ for i, td_ in zip(workers_range, tensordict_in):
1280
+ # shared_tensordicts are locked, and we need to select the keys since we update in-place.
1281
+ # There may be unexpected keys, such as "_reset", that we can safely ignore here.
1282
+ env_device = self._envs[i].device
1283
+ if env_device != self.device:
1284
+ if env_device is not None:
1285
+ data_in.append(td_.to(env_device, non_blocking=self.non_blocking))
1286
+ else:
1287
+ data_in.append(td_.clear_device_())
1288
+ else:
1289
+ data_in.append(td_)
1290
+
1291
+ self._sync_m2w()
1292
+ out_tds = None
1293
+ if not self._use_buffers or self._non_tensor_keys:
1294
+ out_tds = []
1295
+
1296
+ if self._use_buffers:
1297
+ next_td = self.shared_tensordict_parent.get("next")
1298
+ for i, _data_in in zip(workers_range, data_in):
1299
+ out_td = self._envs[i]._step(_data_in)
1300
+ next_td[i].update_(
1301
+ out_td,
1302
+ # _env_output_keys exclude non-tensor data
1303
+ keys_to_update=list(self._env_output_keys),
1304
+ non_blocking=self.non_blocking,
1305
+ )
1306
+ if out_tds is not None:
1307
+ # we store the non-tensor data here
1308
+ out_tds.append(out_td)
1309
+
1310
+ # We must pass a clone of the tensordict, as the values of this tensordict
1311
+ # will be modified in-place at further steps
1312
+ device = self.device
1313
+
1314
+ selected_keys = self._selected_step_keys
1315
+
1316
+ if partial_steps is not None:
1317
+ next_td = TensorDict.lazy_stack([next_td[i] for i in workers_range])
1318
+ out = next_td.named_apply(
1319
+ lambda *args: self.select_and_clone(*args, selected_keys),
1320
+ nested_keys=True,
1321
+ filter_empty=True,
1322
+ )
1323
+ if out_tds is not None:
1324
+ out.update(
1325
+ LazyStackedTensorDict(*out_tds),
1326
+ keys_to_update=self._non_tensor_keys,
1327
+ )
1328
+
1329
+ if out.device != device:
1330
+ if device is None:
1331
+ out = out.clear_device_()
1332
+ elif out.device != device:
1333
+ out = out.to(device, non_blocking=self.non_blocking)
1334
+ self._sync_w2m()
1335
+ else:
1336
+ for i, _data_in in zip(workers_range, data_in):
1337
+ out_td = self._envs[i]._step(_data_in)
1338
+ out_tds.append(out_td)
1339
+ out = LazyStackedTensorDict.maybe_dense_stack(out_tds)
1340
+
1341
+ if partial_steps is not None and not partial_steps.all():
1342
+ result = out.new_zeros(tensordict_save.shape)
1343
+ # Copy the observation data from the previous step as a placeholder
1344
+
1345
+ def select_and_clone(x, y):
1346
+ if y is not None:
1347
+ if x.device != y.device:
1348
+ x = x.to(y.device)
1349
+ else:
1350
+ x = x.clone()
1351
+ return x
1352
+
1353
+ prev = tensordict_save._fast_apply(
1354
+ select_and_clone,
1355
+ result,
1356
+ filter_empty=True,
1357
+ device=result.device,
1358
+ batch_size=result.batch_size,
1359
+ is_leaf=_is_leaf_nontensor,
1360
+ default=None,
1361
+ )
1362
+
1363
+ result.update(prev)
1364
+ if partial_steps.any():
1365
+ result[partial_steps] = out
1366
+ return result
1367
+
1368
+ return out
1369
+
1370
+ def __getattr__(self, attr: str) -> Any:
1371
+ if attr in self.__dir__():
1372
+ return super().__getattr__(
1373
+ attr
1374
+ ) # make sure that appropriate exceptions are raised
1375
+ elif attr.startswith("__"):
1376
+ raise AttributeError(
1377
+ "dispatching built-in private methods is "
1378
+ f"not permitted with type {type(self)}. "
1379
+ f"Got attribute {attr}."
1380
+ )
1381
+ else:
1382
+ if attr in self._excluded_wrapped_keys:
1383
+ raise AttributeError(f"Getting {attr} resulted in an exception")
1384
+ try:
1385
+ # determine if attr is a callable
1386
+ list_attr = [getattr(env, attr) for env in self._envs]
1387
+ callable_attr = callable(list_attr[0])
1388
+ if callable_attr:
1389
+ if self.is_closed:
1390
+ raise RuntimeError(
1391
+ "Trying to access attributes of closed/non started "
1392
+ "environments. Check that the batched environment "
1393
+ "has been started (e.g. by calling env.reset)"
1394
+ )
1395
+ return _dispatch_caller_serial(list_attr)
1396
+ else:
1397
+ return list_attr
1398
+ except AttributeError:
1399
+ raise AttributeError(
1400
+ f"attribute {attr} not found in " f"{self._dummy_env_str}"
1401
+ )
1402
+
1403
+ def to(self, device: DEVICE_TYPING):
1404
+ device = _make_ordinal_device(torch.device(device))
1405
+ if device == self.device:
1406
+ return self
1407
+ super().to(device)
1408
+ if not self.is_closed:
1409
+ self._envs = [env.to(device) for env in self._envs]
1410
+ return self
1411
+
1412
+
1413
+ class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta):
1414
+ """Creates one environment per process.
1415
+
1416
+ TensorDicts are passed via shared memory or memory map.
1417
+
1418
+ """
1419
+
1420
+ __doc__ += BatchedEnvBase.__doc__
1421
+ __doc__ += """
1422
+
1423
+ .. note:: ParallelEnv will time out after one of the workers has been idle for a set amount of time.
1424
+ This can be controlled via the BATCHED_PIPE_TIMEOUT environment variable, which in turn modifies
1425
+ the torchrl._utils.BATCHED_PIPE_TIMEOUT integer. The default timeout value is 10000 seconds.
1426
+
1427
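+ A minimal sketch of overriding the timeout (assuming the variable is read when
+ ``torchrl`` is first imported, so it must be set beforehand; the value is only
+ illustrative):
+
+ >>> import os
+ >>> os.environ["BATCHED_PIPE_TIMEOUT"] = "60"  # 60-second timeout
+ >>> import torchrl
+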
+ .. warning::
1428
+ TorchRL's ParallelEnv is quite stringent when it comes to env specs, since
1429
+ these are used to build shared memory buffers for inter-process communication.
1430
+ As such, we encourage users to first run a check of the env specs with
1431
+ :func:`~torchrl.envs.utils.check_env_specs`:
1432
+
1433
+ >>> from torchrl.envs import check_env_specs
1434
+ >>> env = make_env()
1435
+ >>> check_env_specs(env) # if this passes without error you're good to go!
1436
+ >>> penv = ParallelEnv(2, make_env)
1437
+
1438
+ In particular, gym-like envs with info-dict readers may be difficult to
1439
+ share across processes if the spec is not properly set, which is hard to
1440
+ do automatically. Check :meth:`~torchrl.envs.GymLikeEnv.set_info_dict_reader`
1441
+ for more information. Here is a short example:
1442
+
1443
+ >>> from torchrl.envs import GymEnv, set_gym_backend, check_env_specs, TransformedEnv, TensorDictPrimer
1444
+ >>> import torch
1445
+ >>> env = GymEnv("HalfCheetah-v4")
1446
+ >>> env.rollout(10) # no info registered, this env passes check_env_specs
1447
+ TensorDict(
1448
+ fields={
1449
+ action: Tensor(shape=torch.Size([10, 6]), device=cpu, dtype=torch.float32, is_shared=False),
1450
+ done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
1451
+ next: TensorDict(
1452
+ fields={
1453
+ done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
1454
+ observation: Tensor(shape=torch.Size([10, 17]), device=cpu, dtype=torch.float64, is_shared=False),
1455
+ reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
1456
+ terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
1457
+ truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
1458
+ batch_size=torch.Size([10]),
1459
+ device=cpu,
1460
+ is_shared=False),
1461
+ observation: Tensor(shape=torch.Size([10, 17]), device=cpu, dtype=torch.float64, is_shared=False),
1462
+ terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
1463
+ truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
1464
+ batch_size=torch.Size([10]),
1465
+ device=cpu,
1466
+ is_shared=False)
1467
+ >>> check_env_specs(env) # succeeds!
1468
+ >>> env.set_info_dict_reader() # sets the default info_dict reader
1469
+ >>> env.rollout(10) # because the info_dict is empty at reset time, we're missing the root infos!
1470
+ TensorDict(
1471
+ fields={
1472
+ action: Tensor(shape=torch.Size([10, 6]), device=cpu, dtype=torch.float32, is_shared=False),
1473
+ done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
1474
+ next: TensorDict(
1475
+ fields={
1476
+ done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
1477
+ observation: Tensor(shape=torch.Size([10, 17]), device=cpu, dtype=torch.float64, is_shared=False),
1478
+ reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
1479
+ reward_ctrl: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
1480
+ reward_run: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
1481
+ terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
1482
+ truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
1483
+ x_position: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
1484
+ x_velocity: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False)},
1485
+ batch_size=torch.Size([10]),
1486
+ device=cpu,
1487
+ is_shared=False),
1488
+ observation: Tensor(shape=torch.Size([10, 17]), device=cpu, dtype=torch.float64, is_shared=False),
1489
+ terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
1490
+ truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
1491
+ batch_size=torch.Size([10]),
1492
+ device=cpu,
1493
+ is_shared=False)
1494
+ >>> check_env_specs(env) # This check now fails! We should not use an env constructed like this in a parallel env
1495
+ >>> # This ad-hoc fix registers the info-spec for reset. It is wrapped inside `env.auto_register_info_dict()`
1496
+ >>> env_fixed = TransformedEnv(env, TensorDictPrimer(env.info_dict_reader[0].info_spec))
1497
+ >>> env_fixed.rollout(10)
1498
+ TensorDict(
1499
+ fields={
1500
+ action: Tensor(shape=torch.Size([10, 6]), device=cpu, dtype=torch.float32, is_shared=False),
1501
+ done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
1502
+ next: TensorDict(
1503
+ fields={
1504
+ done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
1505
+ observation: Tensor(shape=torch.Size([10, 17]), device=cpu, dtype=torch.float64, is_shared=False),
1506
+ reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
1507
+ reward_ctrl: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
1508
+ reward_run: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
1509
+ terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
1510
+ truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
1511
+ x_position: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
1512
+ x_velocity: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False)},
1513
+ batch_size=torch.Size([10]),
1514
+ device=cpu,
1515
+ is_shared=False),
1516
+ observation: Tensor(shape=torch.Size([10, 17]), device=cpu, dtype=torch.float64, is_shared=False),
1517
+ reward_ctrl: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
1518
+ reward_run: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
1519
+ terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
1520
+ truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
1521
+ x_position: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
1522
+ x_velocity: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False)},
1523
+ batch_size=torch.Size([10]),
1524
+ device=cpu,
1525
+ is_shared=False)
1526
+ >>> check_env_specs(env_fixed) # Succeeds! This env can be used within a parallel env!
1527
+
1528
+ Related classes and methods: :meth:`~torchrl.envs.GymLikeEnv.auto_register_info_dict`
1529
+ and :class:`~torchrl.envs.gym_like.default_info_dict_reader`.
1530
+
1531
+ .. warning::
1532
+ The choice of the devices where ParallelEnv needs to be executed can
1533
+ drastically influence its performance. The rules of thumb are:
1534
+
1535
+ - If the base environment (backend, e.g., Gym) is executed on CPU, the
1536
+ sub-environments should be executed on CPU and the data should be
1537
+ passed via shared physical memory.
1538
+ - If the base environment is (or can be) executed on CUDA, the sub-environments
1539
+ should be placed on CUDA too.
1540
+ - If a CUDA device is available and the policy is to be executed on CUDA,
1541
+ the ParallelEnv device should be set to CUDA.
1542
+
1543
+ Therefore, supposing a CUDA device is available, we have the following scenarios:
1544
+
1545
+ >>> # The sub-envs are executed on CPU, but the policy is on GPU
1546
+ >>> env = ParallelEnv(N, MyEnv(..., device="cpu"), device="cuda")
1547
+ >>> # The sub-envs are executed on CUDA
1548
+ >>> env = ParallelEnv(N, MyEnv(..., device="cuda"), device="cuda")
1549
+ >>> # this will create the exact same environment
1550
+ >>> env = ParallelEnv(N, MyEnv(..., device="cuda"))
1551
+ >>> # If no cuda device is available
1552
+ >>> env = ParallelEnv(N, MyEnv(..., device="cpu"))
1553
+
1554
+ .. warning::
1555
+ ParallelEnv disables gradients in all operations (:meth:`step`,
1556
+ :meth:`reset` and :meth:`step_and_maybe_reset`) because gradients
1557
+ cannot be passed through :class:`multiprocessing.Pipe` objects.
1558
+ Only :class:`~torchrl.envs.SerialEnv` will support backpropagation.
1559
+
1560
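+ A sketch of the practical consequence (``make_env`` is assumed, as above, to build
+ a single sub-environment whose outputs carry gradients):
+
+ >>> SerialEnv(2, make_env, share_individual_td=True).rollout(2)  # gradients can propagate
+ >>> ParallelEnv(2, make_env).rollout(2)  # outputs are detached from the graph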
+ """
1561
+
1562
+ def _start_workers(self) -> None:
1563
+ import torchrl
1564
+
1565
+ self._timeout = 10.0
1566
+ self.BATCHED_PIPE_TIMEOUT = torchrl._utils.BATCHED_PIPE_TIMEOUT
1567
+
1568
+ num_threads = max(
1569
+ 1, torch.get_num_threads() - self.num_workers
1570
+ ) # leave at least one thread for this process
1571
+
1572
+ if self.num_threads is None:
1573
+ self.num_threads = num_threads
1574
+
1575
+ if self.num_threads != torch.get_num_threads():
1576
+ torch.set_num_threads(self.num_threads)
1577
+
1578
+ if self._mp_start_method is not None:
1579
+ ctx = mp.get_context(self._mp_start_method)
1580
+ else:
1581
+ ctx = mp.get_context(_get_default_mp_start_method())
1582
+ # Use ctx.Process directly to ensure all multiprocessing primitives
1583
+ # (Queue, Pipe, Process, Event) come from the same context.
1584
+ # Warning filtering and num_threads are handled in the worker functions.
1585
+ proc_fun = ctx.Process
1586
+ num_sub_threads = self.num_sub_threads
1587
+
1588
+ _num_workers = self.num_workers
1589
+
1590
+ self.parent_channels = []
1591
+ self._workers = []
1592
+ if self._use_buffers:
1593
+ func = _run_worker_pipe_shared_mem
1594
+ else:
1595
+ func = _run_worker_pipe_direct
1596
+ # We look for cuda tensors through the leaves
1597
+ # because the shared tensordict could be partially on cuda
1598
+ # and some leaves may be inaccessible through get (e.g., LazyStacked)
1599
+ has_cuda = [False]
1600
+
1601
+ def look_for_cuda(tensor, has_cuda=has_cuda):
1602
+ has_cuda[0] = has_cuda[0] or tensor.is_cuda
1603
+
1604
+ if self._use_buffers:
1605
+ self.shared_tensordict_parent.apply(look_for_cuda, filter_empty=True)
1606
+ has_cuda = has_cuda[0]
1607
+ if has_cuda:
1608
+ self.event = torch.cuda.Event()
1609
+ else:
1610
+ self.event = None
1611
+ self._events = [ctx.Event() for _ in range(_num_workers)]
1612
+ kwargs = [{"mp_event": self._events[i]} for i in range(_num_workers)]
1613
+ with clear_mpi_env_vars():
1614
+ for idx in range(_num_workers):
1615
+ if self._verbose:
1616
+ torchrl_logger.info(f"initiating worker {idx}")
1617
+ # We cannot be certain which module the multiprocessing context comes from
1618
+ parent_pipe, child_pipe = ctx.Pipe()
1619
+ env_fun = self.create_env_fn[idx]
1620
+ if not isinstance(env_fun, (EnvCreator, CloudpickleWrapper)):
1621
+ env_fun = CloudpickleWrapper(env_fun)
1622
+
1623
+ kwargs[idx].update(
1624
+ {
1625
+ "parent_pipe": parent_pipe,
1626
+ "child_pipe": child_pipe,
1627
+ "env_fun": env_fun,
1628
+ "env_fun_kwargs": self.create_env_kwargs[idx],
1629
+ "has_lazy_inputs": self.has_lazy_inputs,
1630
+ "num_threads": num_sub_threads,
1631
+ "non_blocking": self.non_blocking,
1632
+ "filter_warnings": self._filter_warnings_subprocess(),
1633
+ }
1634
+ )
1635
+ if self._use_buffers:
1636
+ kwargs[idx].update(
1637
+ {
1638
+ "shared_tensordict": self.shared_tensordicts[idx],
1639
+ "_selected_input_keys": self._selected_input_keys,
1640
+ "_selected_reset_keys": self._selected_reset_keys,
1641
+ "_selected_step_keys": self._selected_step_keys,
1642
+ "_non_tensor_keys": self._non_tensor_keys,
1643
+ }
1644
+ )
1645
+ else:
1646
+ kwargs[idx].update(
1647
+ {
1648
+ "consolidate": self.consolidate,
1649
+ }
1650
+ )
1651
+ process = proc_fun(target=func, kwargs=kwargs[idx])
1652
+ process.daemon = self.daemon
1653
+ process.start()
1654
+ child_pipe.close()
1655
+ self.parent_channels.append(parent_pipe)
1656
+ self._workers.append(process)
1657
+
1658
+ for parent_pipe in self.parent_channels:
1659
+ # use msg as sync point
1660
+ parent_pipe.recv()
1661
+
1662
+ # send shared tensordict to workers
1663
+ for channel in self.parent_channels:
1664
+ channel.send(("init", None))
1665
+ self.is_closed = False
1666
+ self.set_spec_lock_()
1667
+
1668
+ def _filter_warnings_subprocess(self) -> bool:
1669
+ from torchrl import filter_warnings_subprocess
1670
+
1671
+ return filter_warnings_subprocess
1672
+
1673
+ @_check_start
1674
+ def state_dict(self) -> OrderedDict:
1675
+ state_dict = OrderedDict()
1676
+ for channel in self.parent_channels:
1677
+ channel.send(("state_dict", None))
1678
+ for idx, channel in enumerate(self.parent_channels):
1679
+ msg, _state_dict = channel.recv()
1680
+ if msg != "state_dict":
1681
+ raise RuntimeError(f"Expected 'state_dict' but received {msg}")
1682
+ state_dict[f"worker{idx}"] = _state_dict
1683
+
1684
+ return state_dict
1685
+
1686
+ @_check_start
1687
+ def load_state_dict(self, state_dict: OrderedDict) -> None:
1688
+ if "worker0" not in state_dict:
1689
+ state_dict = OrderedDict(
1690
+ **{f"worker{idx}": state_dict for idx in range(self.num_workers)}
1691
+ )
1692
+ for i, channel in enumerate(self.parent_channels):
1693
+ channel.send(("load_state_dict", state_dict[f"worker{i}"]))
1694
+ for event in self._events:
1695
+ event.wait(self._timeout)
1696
+ event.clear()
1697
+
1698
+ def _step_and_maybe_reset_no_buffers(
1699
+ self, tensordict: TensorDictBase
1700
+ ) -> tuple[TensorDictBase, TensorDictBase]:
1701
+ partial_steps = tensordict.get("_step", None)
1702
+ tensordict_save = tensordict
1703
+ if partial_steps is not None and partial_steps.all():
1704
+ partial_steps = None
1705
+ if partial_steps is not None:
1706
+ partial_steps = partial_steps.view(tensordict.shape)
1707
+ tensordict = tensordict[partial_steps]
1708
+ workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
1709
+ else:
1710
+ workers_range = range(self.num_workers)
1711
+
1712
+ if self.consolidate:
1713
+ try:
1714
+ td = tensordict.consolidate(
1715
+ # share_memory=False: avoid resource_sharer which causes
1716
+ # progressive slowdown with fork on Linux
1717
+ share_memory=False,
1718
+ inplace=True,
1719
+ num_threads=1,
1720
+ )
1721
+ except Exception as err:
1722
+ raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
1723
+ else:
1724
+ td = tensordict
1725
+
1726
+ for i in workers_range:
1727
+ # We send the same td multiple times as it is in shared mem and we just need to index it
1728
+ # in each process.
1729
+ # If we don't do this, we need to unbind it but then the custom pickler will require
1730
+ # some extra metadata to be collected.
1731
+ self.parent_channels[i].send(("step_and_maybe_reset", (td, i)))
1732
+
1733
+ results = [None] * len(workers_range)
1734
+
1735
+ self._wait_for_workers(workers_range)
1736
+
1737
+ for i, w in enumerate(workers_range):
1738
+ results[i] = self.parent_channels[w].recv()
1739
+
1740
+ out_next, out_root = zip(*(future for future in results))
1741
+ out = TensorDict.maybe_dense_stack(out_next), TensorDict.maybe_dense_stack(
1742
+ out_root
1743
+ )
1744
+ if partial_steps is not None:
1745
+ result = out.new_zeros(tensordict_save.shape)
1746
+
1747
+ def select_and_clone(x, y):
1748
+ if y is not None:
1749
+ if x.device != y.device:
1750
+ x = x.to(y.device)
1751
+ else:
1752
+ x = x.clone()
1753
+ return x
1754
+
1755
+ prev = tensordict_save._fast_apply(
1756
+ select_and_clone,
1757
+ result,
1758
+ filter_empty=True,
1759
+ device=result.device,
1760
+ batch_size=result.batch_size,
1761
+ is_leaf=_is_leaf_nontensor,
1762
+ default=None,
1763
+ )
1764
+
1765
+ result.update(prev)
1766
+
1767
+ if partial_steps.any():
1768
+ result[partial_steps] = out
1769
+ return result
1770
+ return out
1771
+
1772
+ @torch.no_grad()
1773
+ @_check_start
1774
+ def step_and_maybe_reset(
1775
+ self, tensordict: TensorDictBase
1776
+ ) -> tuple[TensorDictBase, TensorDictBase]:
1777
+ if not self._use_buffers:
1778
+ # Simply dispatch the input to the workers
1779
+ # return self._step_and_maybe_reset_no_buffers(tensordict)
1780
+ return super().step_and_maybe_reset(tensordict)
1781
+
1782
+ partial_steps = tensordict.get("_step")
1783
+ tensordict_save = tensordict
1784
+ if partial_steps is not None and partial_steps.all():
1785
+ partial_steps = None
1786
+ if partial_steps is not None:
1787
+ partial_steps = partial_steps.view(tensordict.shape)
1788
+ workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
1789
+ shared_tensordict_parent = TensorDict.lazy_stack(
1790
+ [self.shared_tensordict_parent[i] for i in workers_range]
1791
+ )
1792
+ next_td = TensorDict.lazy_stack(
1793
+ [self._shared_tensordict_parent_next[i] for i in workers_range]
1794
+ )
1795
+ tensordict_ = TensorDict.lazy_stack(
1796
+ [self._shared_tensordict_parent_root[i] for i in workers_range]
1797
+ )
1798
+ if self.shared_tensordict_parent.device is None:
1799
+ tensordict = tensordict._fast_apply(
1800
+ lambda x, y: x[partial_steps].to(y.device)
1801
+ if y is not None
1802
+ else x[partial_steps],
1803
+ self.shared_tensordict_parent,
1804
+ default=None,
1805
+ device=None,
1806
+ batch_size=shared_tensordict_parent.shape,
1807
+ )
1808
+ else:
1809
+ tensordict = tensordict[partial_steps].to(
1810
+ self.shared_tensordict_parent.device
1811
+ )
1812
+ else:
1813
+ workers_range = range(self.num_workers)
1814
+ shared_tensordict_parent = self.shared_tensordict_parent
1815
+ next_td = self._shared_tensordict_parent_next
1816
+ tensordict_ = self._shared_tensordict_parent_root
1817
+
1818
+ # We must use the in_keys and nothing else for the following reasons:
1819
+ # - efficiency: copying all the keys will in practice mean doing a lot
1820
+ # of writing operations since the input tensordict may (and often will)
1821
+ # contain all the previous output data.
1822
+ # - value mismatch: if the batched env is placed within a transform
1823
+ # and this transform overrides an observation key (eg, CatFrames)
1824
+ # the shape, dtype or device may not necessarily match and writing
1825
+ # the value in-place will fail.
1826
+ shared_tensordict_parent.update_(
1827
+ tensordict,
1828
+ keys_to_update=self._env_input_keys,
1829
+ non_blocking=self.non_blocking,
1830
+ )
1831
+ next_td_passthrough = tensordict.get("next", default=None)
1832
+ if next_td_passthrough is not None:
1833
+ # if we have input "next" data (eg, RNNs which pass the next state)
1834
+ # the sub-envs will need to process them through step_and_maybe_reset.
1835
+ # We keep track of which keys are present to let the worker know what
1836
+ # should be passed to the env (we don't want to pass done states for instance)
1837
+ next_td_keys = list(next_td_passthrough.keys(True, True))
1838
+ data = [{"next_td_passthrough_keys": next_td_keys} for _ in workers_range]
1839
+ shared_tensordict_parent.get("next").update_(
1840
+ next_td_passthrough, non_blocking=self.non_blocking
1841
+ )
1842
+ else:
1843
+ # next_td_keys = None
1844
+ data = [{} for _ in workers_range]
1845
+
1846
+ if self._non_tensor_keys:
1847
+ for i, td in zip(
1848
+ workers_range,
1849
+ tensordict.select(*self._non_tensor_keys, strict=False).unbind(0),
1850
+ ):
1851
+ data[i]["non_tensor_data"] = td
1852
+
1853
+ self._sync_m2w()
1854
+ for i, _data in zip(workers_range, data):
1855
+ self.parent_channels[i].send(("step_and_maybe_reset", _data))
1856
+
1857
+ self._wait_for_workers(workers_range)
1858
+ if self._non_tensor_keys:
1859
+ non_tensor_tds = []
1860
+ for i in workers_range:
1861
+ msg, non_tensor_td = self.parent_channels[i].recv()
1862
+ non_tensor_tds.append(non_tensor_td)
1863
+
1864
+ # We must pass a clone of the tensordict, as the values of this tensordict
1865
+ # will be modified in-place at further steps
1866
+ device = self.device
1867
+ if shared_tensordict_parent.device == device:
1868
+ next_td = next_td.clone()
1869
+ tensordict_ = tensordict_.clone()
1870
+ elif device is not None:
1871
+ next_td = next_td._fast_apply(
1872
+ lambda x: x.to(device, non_blocking=self.non_blocking)
1873
+ if x.device != device
1874
+ else x.clone(),
1875
+ device=device,
1876
+ filter_empty=True,
1877
+ )
1878
+ tensordict_ = tensordict_._fast_apply(
1879
+ lambda x: x.to(device, non_blocking=self.non_blocking)
1880
+ if x.device != device
1881
+ else x.clone(),
1882
+ device=device,
1883
+ filter_empty=True,
1884
+ )
1885
+ if tensordict.device != device:
1886
+ tensordict = tensordict._fast_apply(
1887
+ lambda x: x.to(device, non_blocking=self.non_blocking)
1888
+ if x.device != device
1889
+ else x,
1890
+ device=device,
1891
+ filter_empty=True,
1892
+ )
1893
+ self._sync_w2m()
1894
+ else:
1895
+ next_td = next_td.clone().clear_device_()
1896
+ tensordict_ = tensordict_.clone().clear_device_()
1897
+ tensordict.set("next", next_td)
1898
+ if self._non_tensor_keys:
1899
+ non_tensor_tds = LazyStackedTensorDict(*non_tensor_tds)
1900
+ tensordict.update(
1901
+ non_tensor_tds,
1902
+ keys_to_update=[("next", key) for key in self._non_tensor_keys],
1903
+ )
1904
+ tensordict_.update(non_tensor_tds, keys_to_update=self._non_tensor_keys)
1905
+
1906
+ if partial_steps is not None:
1907
+ result = tensordict.new_zeros(tensordict_save.shape)
1908
+ result_ = tensordict_.new_zeros(tensordict_save.shape)
1909
+
1910
+ def select_and_transfer(x, y):
1911
+ if y is not None:
1912
+ return (
1913
+ x.to(y.device, non_blocking=self.non_blocking)
1914
+ if x.device != y.device
1915
+ else x.clone()
1916
+ )
1917
+
1918
+ old_r_copy = tensordict_save._fast_apply(
1919
+ select_and_transfer,
1920
+ result,
1921
+ filter_empty=True,
1922
+ device=device,
1923
+ default=None,
1924
+ )
1925
+ old_r_copy.set(
1926
+ "next",
1927
+ tensordict_save._fast_apply(
1928
+ select_and_transfer,
1929
+ next_td,
1930
+ filter_empty=True,
1931
+ device=device,
1932
+ default=None,
1933
+ ),
1934
+ )
1935
+ result.update(old_r_copy)
1936
+ result_.update(
1937
+ tensordict_save._fast_apply(
1938
+ select_and_transfer,
1939
+ result_,
1940
+ filter_empty=True,
1941
+ device=device,
1942
+ default=None,
1943
+ )
1944
+ )
1945
+ self._sync_w2m()
1946
+
1947
+ if partial_steps.any():
1948
+ result[partial_steps] = tensordict
1949
+ result_[partial_steps] = tensordict_
1950
+ return result, result_
1951
+
1952
+ return tensordict, tensordict_
1953
+
1954
+ def _wait_for_workers(self, workers_range):
1955
+ """Wait for all workers to signal completion via their events.
1956
+
1957
+ Uses multiprocessing.connection.wait() for efficient OS-level
1958
+ waiting on multiple pipes simultaneously.
1959
+ """
1960
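+ # ``multiprocessing.connection.wait`` behaves like ``select`` on pipes: it blocks
+ # until at least one connection is readable or the timeout expires, and returns
+ # the ready connections.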
+ timeout = self.BATCHED_PIPE_TIMEOUT
1961
+ t0 = time.time()
1962
+
1963
+ # In shared-memory/buffer mode, workers signal completion by setting
1964
+ # their `mp_event` (they may not send anything back on the pipe).
1965
+ if self._use_buffers:
1966
+ pending = set(workers_range)
1967
+ n_iter = 0
1968
+ while pending:
1969
+ n_iter += 1
1970
+ remaining = timeout - (time.time() - t0)
1971
+ if remaining <= 0:
1972
+ raise RuntimeError(
1973
+ f"Failed to run all workers within the {timeout} sec time limit. This "
1974
+ f"threshold can be increased via the BATCHED_PIPE_TIMEOUT env variable."
1975
+ )
1976
+
1977
+ # Wait in short slices so we can both harvest multiple events and
1978
+ # periodically check for dead workers without blocking forever.
1979
+ slice_timeout = min(0.1, remaining)
1980
+ progressed = False
1981
+ for wi in tuple(pending):
1982
+ if self._events[wi].wait(timeout=slice_timeout):
1983
+ self._events[wi].clear()
1984
+ pending.remove(wi)
1985
+ progressed = True
1986
+
1987
+ if not progressed and (n_iter % 50) == 0:
1988
+ for wi in pending:
1989
+ if not self._workers[wi].is_alive():
1990
+ try:
1991
+ self._shutdown_workers()
1992
+ finally:
1993
+ raise RuntimeError(f"Cannot proceed, worker {wi} dead.")
1994
+ return
1995
+
1996
+ # No-buffer mode: workers send back data on the pipe, so we can efficiently
1997
+ # block on readability.
1998
+ pipes_pending = {self.parent_channels[i]: i for i in workers_range}
1999
+ i = 0
2000
+ while pipes_pending:
2001
+ i += 1
2002
+ should_check_for_dead_workers = (i % 20) == 0
2003
+ remaining = timeout - (time.time() - t0)
2004
+ if remaining <= 0:
2005
+ raise RuntimeError(
2006
+ f"Failed to run all workers within the {timeout} sec time limit. This "
2007
+ f"threshold can be increased via the BATCHED_PIPE_TIMEOUT env variable."
2008
+ )
2009
+
2010
+ # Wait for any pipes to become readable (OS-level select/poll)
2011
+ ready = connection_wait(list(pipes_pending.keys()), timeout=remaining)
2012
+
2013
+ if not ready and should_check_for_dead_workers:
2014
+ # Timeout with no pipes ready - check for dead workers
2015
+ for wi in pipes_pending.values():
2016
+ if not self._workers[wi].is_alive():
2017
+ try:
2018
+ self._shutdown_workers()
2019
+ finally:
2020
+ raise RuntimeError(f"Cannot proceed, worker {wi} dead.")
2021
+ continue
2022
+
2023
+ # Clear events for ready workers (best-effort)
2024
+ for pipe in ready:
2025
+ wi = pipes_pending.pop(pipe)
2026
+ self._events[wi].clear()
2027
+
2028
+ def _step_no_buffers(
2029
+ self, tensordict: TensorDictBase
2030
+ ) -> tuple[TensorDictBase, TensorDictBase]:
2031
+ partial_steps = tensordict.get("_step")
2032
+ tensordict_save = tensordict
2033
+ if partial_steps is not None and partial_steps.all():
2034
+ partial_steps = None
2035
+ if partial_steps is not None:
2036
+ partial_steps = partial_steps.view(tensordict.shape)
2037
+ tensordict = tensordict[partial_steps]
2038
+ workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
2039
+ else:
2040
+ workers_range = range(self.num_workers)
2041
+
2042
+ if self.consolidate:
2043
+ try:
2044
+ data = tensordict.consolidate(
2045
+ # share_memory=False: avoid resource_sharer which causes
2046
+ # progressive slowdown with fork on Linux
2047
+ share_memory=False,
2048
+ inplace=False,
2049
+ num_threads=1,
2050
+ )
2051
+ except Exception as err:
2052
+ raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
2053
+ else:
2054
+ data = tensordict
2055
+
2056
+ for i, local_data in zip(workers_range, data.unbind(0)):
2057
+ env_device = (
2058
+ self.meta_data[i].device
2059
+ if isinstance(self.meta_data, list)
2060
+ else self.meta_data.device
2061
+ )
2062
+ if data.device != env_device:
2063
+ if env_device is None:
2064
+ local_data.clear_device_()
2065
+ else:
2066
+ local_data = local_data.to(env_device)
2067
+ self.parent_channels[i].send(("step", local_data))
2068
+
2069
+ self._wait_for_workers(workers_range)
2070
+
2071
+ out_tds = []
2072
+ for i in workers_range:
2073
+ channel = self.parent_channels[i]
2074
+ td = channel.recv()
2075
+ out_tds.append(td)
2076
+
2077
+ out = LazyStackedTensorDict.maybe_dense_stack(out_tds)
2078
+ if self.device is not None and out.device != self.device:
2079
+ out = out.to(self.device, non_blocking=self.non_blocking)
2080
+ if partial_steps is not None:
2081
+ result = out.new_zeros(tensordict_save.shape)
2082
+
2083
+ def select_and_clone(x, y):
2084
+ if y is not None:
2085
+ if x.device != y.device:
2086
+ x = x.to(y.device)
2087
+ else:
2088
+ x = x.clone()
2089
+ return x
2090
+
2091
+ prev = tensordict_save._fast_apply(
2092
+ select_and_clone,
2093
+ result,
2094
+ filter_empty=True,
2095
+ device=result.device,
2096
+ batch_size=result.batch_size,
2097
+ is_leaf=_is_leaf_nontensor,
2098
+ default=None,
2099
+ )
2100
+
2101
+ result.update(prev)
2102
+
2103
+ if partial_steps.any():
2104
+ result[partial_steps] = out
2105
+ return result
2106
+ return out
2107
+
2108
+ @torch.no_grad()
2109
+ @_check_start
2110
+ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
2111
+ if not self._use_buffers:
2112
+ return self._step_no_buffers(tensordict)
2113
+ # We must use the in_keys and nothing else for the following reasons:
2114
+ # - efficiency: copying all the keys will in practice mean doing a lot
2115
+ # of writing operations since the input tensordict may (and often will)
2116
+ # contain all the previous output data.
2117
+ # - value mismatch: if the batched env is placed within a transform
2118
+ # and this transform overrides an observation key (eg, CatFrames)
2119
+ # the shape, dtype or device may not necessarily match and writing
2120
+ # the value in-place will fail.
2121
+ partial_steps = tensordict.get("_step")
2122
+ tensordict_save = tensordict
2123
+ if partial_steps is not None and partial_steps.all():
2124
+ partial_steps = None
2125
+ if partial_steps is not None:
2126
+ partial_steps = partial_steps.view(tensordict.shape)
2127
+ workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
2128
+ shared_tensordict_parent = TensorDict.lazy_stack(
2129
+ [self.shared_tensordicts[i] for i in workers_range]
2130
+ )
2131
+ if self.shared_tensordict_parent.device is None:
2132
+ tensordict = tensordict._fast_apply(
2133
+ lambda x, y: x[partial_steps].to(y.device)
2134
+ if y is not None
2135
+ else x[partial_steps],
2136
+ self.shared_tensordict_parent,
2137
+ default=None,
2138
+ device=None,
2139
+ batch_size=shared_tensordict_parent.shape,
2140
+ )
2141
+ else:
2142
+ tensordict = tensordict[partial_steps].to(
2143
+ self.shared_tensordict_parent.device
2144
+ )
2145
+ else:
2146
+ workers_range = range(self.num_workers)
2147
+ shared_tensordict_parent = self.shared_tensordict_parent
2148
+
2149
+ shared_tensordict_parent.update_(
2150
+ tensordict,
2151
+ # We also update the output keys because they can be implicitly used, eg
2152
+ # during partial steps to fill in values
2153
+ keys_to_update=list(self._env_input_keys),
2154
+ non_blocking=self.non_blocking,
2155
+ )
2156
+ next_td_passthrough = tensordict.get("next", None)
2157
+ if next_td_passthrough is not None:
2158
+ # if we have input "next" data (eg, RNNs which pass the next state)
2159
+ # the sub-envs will need to process them through step_and_maybe_reset.
2160
+ # We keep track of which keys are present to let the worker know what
2161
+ # should be passed to the env (we don't want to pass done states for instance)
2162
+ next_td_keys = list(next_td_passthrough.keys(True, True))
2163
+ next_shared_tensordict_parent = shared_tensordict_parent.get("next")
2164
+
2165
+ # We separate keys that are and are not present in the buffer here and not in step_and_maybe_reset.
2166
+ # The reason we do that is that the policy may write stuff in 'next' that is not part of the specs of
2167
+ # the batched env but part of the specs of a transformed batched env.
2168
+ # If that is the case, `update_` will fail to find the entries to update.
2169
+ # What we do instead is keeping the tensors on the side and putting them back after completing _step.
2170
+ keys_to_update, keys_to_copy = zip(
2171
+ *[
2172
+ (key, None)
2173
+ if key in next_shared_tensordict_parent.keys(True, True)
2174
+ else (None, key)
2175
+ for key in next_td_keys
2176
+ ]
2177
+ )
2178
+ keys_to_update = [key for key in keys_to_update if key is not None]
2179
+ keys_to_copy = [key for key in keys_to_copy if key is not None]
2180
+ data = [
2181
+ {"next_td_passthrough_keys": keys_to_update}
2182
+ for _ in range(self.num_workers)
2183
+ ]
2184
+ if keys_to_update:
2185
+ next_shared_tensordict_parent.update_(
2186
+ next_td_passthrough,
2187
+ non_blocking=self.non_blocking,
2188
+ keys_to_update=keys_to_update,
2189
+ )
2190
+ if keys_to_copy:
2191
+ next_td_passthrough = next_td_passthrough.select(*keys_to_copy)
2192
+ else:
2193
+ next_td_passthrough = None
2194
+ else:
2195
+ next_td_passthrough = None
2196
+ data = [{} for _ in range(self.num_workers)]
2197
+
2198
+ if self._non_tensor_keys:
2199
+ for i, td in zip(
2200
+ workers_range,
2201
+ tensordict.select(*self._non_tensor_keys, strict=False).unbind(0),
2202
+ ):
2203
+ data[i]["non_tensor_data"] = td
2204
+
2205
+ self._sync_m2w()
2206
+
2207
+ if self.event is not None:
2208
+ self.event.record()
2209
+ self.event.synchronize()
2210
+
2211
+ for i in workers_range:
2212
+ self.parent_channels[i].send(("step", data[i]))
2213
+
2214
+ self._wait_for_workers(workers_range)
2215
+
2216
+ if self._non_tensor_keys:
2217
+ non_tensor_tds = []
2218
+ for i in workers_range:
2219
+ msg, non_tensor_td = self.parent_channels[i].recv()
2220
+ non_tensor_tds.append(non_tensor_td)
2221
+
2222
+ # We must pass a clone of the tensordict, as the values of this tensordict
2223
+ # will be modified in-place at further steps
2224
+ next_td = shared_tensordict_parent.get("next")
2225
+ device = self.device
2226
+
2227
+ out = next_td.named_apply(
2228
+ self.select_and_clone,
2229
+ nested_keys=True,
2230
+ filter_empty=True,
2231
+ device=device,
2232
+ )
2233
+ if self._non_tensor_keys:
2234
+ out.update(
2235
+ LazyStackedTensorDict(*non_tensor_tds),
2236
+ keys_to_update=self._non_tensor_keys,
2237
+ )
2238
+ if next_td_passthrough is not None:
2239
+ out.update(next_td_passthrough)
2240
+
2241
+ self._sync_w2m()
2242
+ if partial_steps is not None:
2243
+ result = out.new_zeros(tensordict_save.shape)
2244
+
2245
+ def select_and_clone(x, y):
2246
+ if y is not None:
2247
+ if x.device != y.device:
2248
+ x = x.to(y.device)
2249
+ else:
2250
+ x = x.clone()
2251
+ return x
2252
+
2253
+ prev = tensordict_save._fast_apply(
2254
+ select_and_clone,
2255
+ result,
2256
+ filter_empty=True,
2257
+ device=result.device,
2258
+ batch_size=result.batch_size,
2259
+ is_leaf=_is_leaf_nontensor,
2260
+ default=None,
2261
+ )
2262
+
2263
+ result.update(prev)
2264
+ if partial_steps.any():
2265
+ result[partial_steps] = out
2266
+ return result
2267
+ return out
2268
+
2269
+ def _reset_no_buffers(
2270
+ self,
2271
+ tensordict: TensorDictBase,
2272
+ reset_kwargs_list,
2273
+ needs_resetting,
2274
+ ) -> tuple[TensorDictBase, TensorDictBase]:
2275
+ if is_tensor_collection(tensordict):
2276
+ if self.consolidate:
2277
+ try:
2278
+ tensordict = tensordict.consolidate(
2279
+ # share_memory=False: avoid resource_sharer which causes
2280
+ # progressive slowdown with fork on Linux
2281
+ share_memory=False,
2282
+ num_threads=1,
2283
+ )
2284
+ except Exception as err:
2285
+ raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
2286
+ tensordict = tensordict.unbind(0)
2287
+ else:
2288
+ tensordict = [None] * self.num_workers
2289
+ out_tds = [None] * self.num_workers
2290
+ needs_resetting_int = []
2291
+ for i, (local_data, reset_kwargs) in enumerate(
2292
+ zip(tensordict, reset_kwargs_list)
2293
+ ):
2294
+ if not needs_resetting[i]:
2295
+ localtd = local_data
2296
+ if localtd is not None:
2297
+ localtd = localtd.exclude(*self.reset_keys)
2298
+ out_tds[i] = localtd
2299
+ continue
2300
+ needs_resetting_int.append(i)
2301
+ self.parent_channels[i].send(("reset", (local_data, reset_kwargs)))
2302
+
2303
+ self._wait_for_workers(needs_resetting_int)
2304
+
2305
+ for i, channel in enumerate(self.parent_channels):
2306
+ if not needs_resetting[i]:
2307
+ continue
2308
+ td = channel.recv()
2309
+ out_tds[i] = td
2310
+ result = LazyStackedTensorDict.maybe_dense_stack(out_tds)
2311
+ device = self.device
2312
+ if device is not None and result.device != device:
2313
+ return result.to(self.device, non_blocking=self.non_blocking)
2314
+ return result
2315
+
2316
+ @torch.no_grad()
2317
+ @_check_start
2318
+ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
2319
+
2320
+ list_of_kwargs = kwargs.pop("list_of_kwargs", [kwargs] * self.num_workers)
2321
+ if kwargs is not list_of_kwargs[0] and kwargs:
2322
+ # this means that kwargs had more than one element and that a list was provided
2323
+ for elt in list_of_kwargs:
2324
+ elt.update(kwargs)
2325
+
2326
+ if tensordict is not None:
2327
+ if "_reset" in tensordict.keys():
2328
+ needs_resetting = tensordict["_reset"]
2329
+ else:
2330
+ needs_resetting = _aggregate_end_of_traj(
2331
+ tensordict, reset_keys=self.reset_keys
2332
+ )
2333
+ if needs_resetting.ndim > 2:
2334
+ needs_resetting = needs_resetting.flatten(1, needs_resetting.ndim - 1)
2335
+ if needs_resetting.ndim > 1:
2336
+ needs_resetting = needs_resetting.any(-1)
2337
+ elif not needs_resetting.ndim:
2338
+ needs_resetting = needs_resetting.expand((self.num_workers,))
2339
+ else:
2340
+ needs_resetting = torch.ones(
2341
+ (self.num_workers,), device=self.device, dtype=torch.bool
2342
+ )
2343
+
2344
+ if not self._use_buffers:
2345
+ return self._reset_no_buffers(tensordict, list_of_kwargs, needs_resetting)
2346
+
2347
+ outs = []
2348
+ for i in range(self.num_workers):
2349
+ if tensordict is not None:
2350
+ tensordict_ = tensordict[i]
2351
+ if tensordict_.is_empty():
2352
+ tensordict_ = None
2353
+ elif self.device is not None and self.device.type == "mps":
2354
+ # copy_ can fail when moving mps->cpu
2355
+ # in some cases when a view of an mps tensor is used.
2356
+ # We know the shared tensors are not MPS, so we can
2357
+ # safely assume that the shared tensors are on cpu
2358
+ tensordict_ = tensordict_.to("cpu")
2359
+ else:
2360
+ tensordict_ = None
2361
+ if not needs_resetting[i]:
2362
+ # We update the stored tensordict with the value of the "next"
2363
+ # key as one may be surprised to receive data that is not up-to-date
2364
+ # If we don't do this, the result of calling reset and skipping one env
2365
+ # will be that the env will have the data from the previous
2366
+ # step at the root (since the shared_tensordict did not go through
2367
+ # step_mdp).
2368
+ self.shared_tensordicts[i].update_(
2369
+ self.shared_tensordicts[i].get("next"),
2370
+ keys_to_update=list(self._selected_reset_keys),
2371
+ non_blocking=self.non_blocking,
2372
+ )
2373
+ if tensordict_ is not None:
2374
+ self.shared_tensordicts[i].update_(
2375
+ tensordict_,
2376
+ keys_to_update=list(self._selected_reset_keys),
2377
+ non_blocking=self.non_blocking,
2378
+ )
2379
+ continue
2380
+ if tensordict_ is not None:
2381
+ tdkeys = list(tensordict_.keys(True, True))
2382
+
2383
+ # This way we can avoid calling select over all the keys in the shared tensordict
2384
+ def tentative_update(val, other):
2385
+ if other is not None:
2386
+ val.copy_(other, non_blocking=self.non_blocking)
2387
+ return val
2388
+
2389
+ self.shared_tensordicts[i].apply_(
2390
+ tentative_update, tensordict_, default=None
2391
+ )
2392
+ out = ("reset", (tdkeys, list_of_kwargs[i]))
2393
+ else:
2394
+ out = ("reset", (False, list_of_kwargs[i]))
2395
+ outs.append((i, out))
2396
+
2397
+ self._sync_m2w()
2398
+
2399
+ for i, out in outs:
2400
+ self.parent_channels[i].send(out)
2401
+
2402
+ self._wait_for_workers(list(zip(*outs))[0])
2403
+
2404
+ workers_nontensor = []
2405
+ if self._non_tensor_keys:
2406
+ for i, _ in outs:
2407
+ msg, non_tensor_td = self.parent_channels[i].recv()
2408
+ workers_nontensor.append((i, non_tensor_td))
2409
+
2410
+ selected_output_keys = self._selected_reset_keys_filt
2411
+ device = self.device
2412
+
2413
+ out = self.shared_tensordict_parent.named_apply(
2414
+ lambda *args: self.select_and_clone(
2415
+ *args, selected_keys=selected_output_keys
2416
+ ),
2417
+ nested_keys=True,
2418
+ filter_empty=True,
2419
+ device=device,
2420
+ )
2421
+ if self._non_tensor_keys:
2422
+ workers, nontensor = zip(*workers_nontensor)
2423
+ out[torch.tensor(workers)] = LazyStackedTensorDict(*nontensor).select(
2424
+ *self._non_tensor_keys
2425
+ )
2426
+ self._sync_w2m()
2427
+ return out
2428
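As a usage-level illustration of the `"_reset"` handling above, a small sketch of a partial reset on a `ParallelEnv`; it assumes a Gymnasium backend is installed and that the done/reset spec carries a trailing singleton dimension:

```python
import torch
from torchrl.envs import GymEnv, ParallelEnv

env = ParallelEnv(3, lambda: GymEnv("Pendulum-v1"))
td = env.reset()
# Only workers 0 and 2 are asked to reset; worker 1 keeps its current state
# (its stored data is refreshed from the "next" entry instead, as done above).
td["_reset"] = torch.tensor([True, False, True]).unsqueeze(-1)
td = env.reset(td)
env.close()
```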
+
2429
+ @_check_start
2430
+ def _shutdown_workers(self) -> None:
2431
+ try:
2432
+ if self.is_closed:
2433
+ raise RuntimeError(
2434
+ "calling {self.__class__.__name__}._shutdown_workers only allowed when env.is_closed = False"
2435
+ )
2436
+ for i, channel in enumerate(self.parent_channels):
2437
+ if self._verbose:
2438
+ torchrl_logger.info(f"closing {i}")
2439
+ channel.send(("close", None))
2440
+ for i in range(self.num_workers):
2441
+ self._events[i].wait(self._timeout)
2442
+ self._events[i].clear()
2443
+ if self._use_buffers:
2444
+ del self.shared_tensordicts, self.shared_tensordict_parent
2445
+
2446
+ for channel in self.parent_channels:
2447
+ channel.close()
2448
+ start_time = time.time()
2449
+ while (
2450
+ any(proc.is_alive() for proc in self._workers)
2451
+ and (time.time() - start_time) < self._timeout
2452
+ ):
2453
+ time.sleep(0.01)
2454
+ for proc in self._workers:
2455
+ proc.join()
2456
+ finally:
2457
+ for proc in self._workers:
2458
+ if proc.is_alive():
2459
+ proc.terminate()
2460
+ del self._workers
2461
+ del self.parent_channels
2462
+ self._cuda_events = None
2463
+ self._events = None
2464
+ self.event = None
2465
+
2466
+ @_check_start
2467
+ def set_seed(
2468
+ self, seed: int | None = None, static_seed: bool = False
2469
+ ) -> int | None:
2470
+ self._seeds = []
2471
+ for channel in self.parent_channels:
2472
+ channel.send(("seed", (seed, static_seed)))
2473
+ self._seeds.append(seed)
2474
+ msg, new_seed = channel.recv()
2475
+ if msg != "seeded":
2476
+ raise RuntimeError(f"Expected 'seeded' but received {msg}")
2477
+ seed = new_seed
2478
+ return seed
2479
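A short sketch of how the chained seeding above looks from the caller's side (each worker hands back the seed to use for the next one), again assuming a Gymnasium backend is available:

```python
from torchrl.envs import GymEnv, ParallelEnv

env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))
# Worker 0 is seeded with 42 and returns the seed for worker 1; the value
# returned here is the seed that a hypothetical next worker would receive.
next_seed = env.set_seed(42)
env.close()
```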
+
2480
+ def __reduce__(self):
2481
+ if not self.is_closed:
2482
+ # ParallelEnv contains non-instantiated envs, thus it can be
2483
+ # closed and serialized if the environment building functions
2484
+ # permit it
2485
+ self.close()
2486
+ return super().__reduce__()
2487
+
2488
+ def __getattr__(self, attr: str) -> Any:
2489
+ if attr in self.__dir__():
2490
+ return super().__getattr__(
2491
+ attr
2492
+ ) # make sure that appropriate exceptions are raised
2493
+ elif attr.startswith("__"):
2494
+ raise AttributeError(
2495
+ "dispatching built-in private methods is not permitted."
2496
+ )
2497
+ else:
2498
+ if attr in self._excluded_wrapped_keys:
2499
+ raise AttributeError(f"Getting {attr} resulted in an exception")
2500
+ try:
2501
+ # _ = getattr(self._dummy_env, attr)
2502
+ if self.is_closed:
2503
+ self.start()
2504
+ raise RuntimeError(
2505
+ "Trying to access attributes of closed/non started "
2506
+ "environments. Check that the batched environment "
2507
+ "has been started (e.g. by calling env.reset)"
2508
+ )
2509
+ # dispatch to workers
2510
+ return _dispatch_caller_parallel(attr, self)
2511
+ except AttributeError:
2512
+ raise AttributeError(
2513
+ f"attribute {attr} not found in " f"{self._dummy_env_str}"
2514
+ )
2515
+
2516
+ def to(self, device: DEVICE_TYPING):
2517
+ device = _make_ordinal_device(torch.device(device))
2518
+ if device == self.device:
2519
+ return self
2520
+ super().to(device)
2521
+ if self._seeds is not None:
2522
+ warn(
2523
+ "Sending a seeded ParallelEnv to another device requires "
2524
+ f"re-seeding it. Re-seeding envs to {self._seeds}."
2525
+ )
2526
+ self.set_seed(self._seeds[0])
2527
+ return self
2528
+
2529
+ @classmethod
2530
+ def make_parallel(cls, *args, num_envs: int = 1, **parallel_kwargs) -> EnvBase:
2531
+ """Backward-compatible factory matching EnvBase.make_parallel signature.
2532
+
2533
+ Supports calls like:
2534
+ ParallelEnv.make_parallel(create_env_fn, num_envs=4, ...)
2535
+ or the constructor form:
2536
+ ParallelEnv.make_parallel(num_workers, create_env_fn, ...)
2537
+ """
2538
+ if len(args) >= 1 and isinstance(args[0], int):
2539
+ return cls(*args, **parallel_kwargs)
2540
+ if len(args) >= 1:
2541
+ create_env_fn = args[0]
2542
+ other_args = args[1:]
2543
+ return cls(int(num_envs), create_env_fn, *other_args, **parallel_kwargs)
2544
+ return cls(int(num_envs), **parallel_kwargs)
2545
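Both call forms accepted by `make_parallel` above, shown as a brief sketch (the environment constructor is a placeholder and assumes a Gymnasium backend):

```python
from torchrl.envs import GymEnv, ParallelEnv

make_env = lambda: GymEnv("Pendulum-v1")
# Factory form: env constructor first, worker count as a keyword.
env_a = ParallelEnv.make_parallel(make_env, num_envs=4)
# Constructor form: worker count first, as in ParallelEnv(4, make_env).
env_b = ParallelEnv.make_parallel(4, make_env)
env_a.close()
env_b.close()
```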
+
2546
+
2547
+ def _recursively_strip_locks_from_state_dict(state_dict: OrderedDict) -> OrderedDict:
2548
+ return OrderedDict(
2549
+ **{
2550
+ k: _recursively_strip_locks_from_state_dict(item)
2551
+ if isinstance(item, OrderedDict)
2552
+ else None
2553
+ if isinstance(item, MpLock)
2554
+ else item
2555
+ for k, item in state_dict.items()
2556
+ }
2557
+ )
2558
+
2559
+
2560
+ def _run_worker_pipe_shared_mem(
2561
+ parent_pipe: connection.Connection,
2562
+ child_pipe: connection.Connection,
2563
+ env_fun: EnvBase | Callable,
2564
+ env_fun_kwargs: dict[str, Any],
2565
+ mp_event: mp.Event = None,
2566
+ shared_tensordict: TensorDictBase = None,
2567
+ _selected_input_keys=None,
2568
+ _selected_reset_keys=None,
2569
+ _selected_step_keys=None,
2570
+ _non_tensor_keys=None,
2571
+ non_blocking: bool = False,
2572
+ has_lazy_inputs: bool = False,
2573
+ verbose: bool = False,
2574
+ num_threads: int | None = None, # for fork start method
2575
+ filter_warnings: bool = False,
2576
+ ) -> None:
2577
+ pid = os.getpid()
2578
+ # Handle warning filtering (moved from _ProcessNoWarn)
2579
+ if filter_warnings:
2580
+ warnings.filterwarnings("ignore")
2581
+ if num_threads is not None:
2582
+ torch.set_num_threads(num_threads)
2583
+ device = shared_tensordict.device
2584
+ if device is None or device.type != "cuda":
2585
+ # Check if some tensors are shared on cuda
2586
+ has_cuda = [False]
2587
+
2588
+ def look_for_cuda(tensor, has_cuda=has_cuda):
2589
+ has_cuda[0] = has_cuda[0] or tensor.is_cuda
2590
+
2591
+ shared_tensordict.apply(look_for_cuda, filter_empty=True)
2592
+ has_cuda = has_cuda[0]
2593
+ else:
2594
+ has_cuda = device.type == "cuda"
2595
+ if has_cuda:
2596
+ event = torch.cuda.Event()
2597
+ else:
2598
+ event = None
2599
+ parent_pipe.close()
2600
+ if not isinstance(env_fun, EnvBase):
2601
+ env = env_fun(**env_fun_kwargs)
2602
+ else:
2603
+ if env_fun_kwargs:
2604
+ raise RuntimeError(
2605
+ "env_fun_kwargs must be empty if an environment is passed to a process."
2606
+ )
2607
+ env = env_fun
2608
+ del env_fun
2609
+ env.set_spec_lock_()
2610
+
2611
+ i = -1
2612
+ import torchrl
2613
+
2614
+ _timeout = torchrl._utils.BATCHED_PIPE_TIMEOUT
2615
+
2616
+ initialized = False
2617
+
2618
+ child_pipe.send("started")
2619
+ next_shared_tensordict, root_shared_tensordict = (None,) * 2
2620
+ _cmd_count = 0
2621
+ _last_cmd = "N/A"
2622
+ # Create a timeit instance to track elapsed time since worker start
2623
+ _worker_timer = timeit(f"batched_env_worker/{pid}/lifetime").start()
2624
+ while True:
2625
+ try:
2626
+ if child_pipe.poll(_timeout):
2627
+ cmd, data = child_pipe.recv()
2628
+ _cmd_count += 1
2629
+ _last_cmd = cmd
2630
+ # Log every 1000 commands
2631
+ if _cmd_count % 1000 == 0:
2632
+ torchrl_logger.debug(
2633
+ f"batched_env worker {pid}: cmd_count={_cmd_count}, "
2634
+ f"elapsed={_worker_timer.elapsed():.1f}s, last_cmd={cmd}"
2635
+ )
2636
+ else:
2637
+ torchrl_logger.debug(
2638
+ f"batched_env worker {pid}: TIMEOUT after {_timeout}s waiting for cmd, "
2639
+ f"elapsed_since_start={_worker_timer.elapsed():.1f}s, "
2640
+ f"last_cmd={_last_cmd}, cmd_count={_cmd_count}"
2641
+ )
2642
+ raise TimeoutError(
2643
+ f"Worker timed out after {_timeout}s, "
2644
+ f"increase timeout if needed through the BATCHED_PIPE_TIMEOUT environment variable."
2645
+ )
2646
+ except EOFError as err:
2647
+ torchrl_logger.debug(
2648
+ f"batched_env worker {pid}: EOFError - pipe closed, "
2649
+ f"elapsed_since_start={_worker_timer.elapsed():.1f}s, "
2650
+ f"last_cmd={_last_cmd}, cmd_count={_cmd_count}"
2651
+ )
2652
+ raise EOFError(f"proc {pid} failed, last command: {_last_cmd}.") from err
2653
+ if cmd == "seed":
2654
+ if not initialized:
2655
+ raise RuntimeError("call 'init' before closing")
2656
+ torch.manual_seed(data[0])
2657
+ new_seed = env.set_seed(data[0], static_seed=data[1])
2658
+ child_pipe.send(("seeded", new_seed))
2659
+
2660
+ elif cmd == "init":
2661
+ if verbose:
2662
+ torchrl_logger.info(f"initializing {pid}")
2663
+ if initialized:
2664
+ raise RuntimeError("worker already initialized")
2665
+ i = 0
2666
+ next_shared_tensordict = shared_tensordict.get("next")
2667
+ root_shared_tensordict = shared_tensordict.exclude("next")
2668
+ # TODO: restore this
2669
+ # if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()):
2670
+ # raise RuntimeError(
2671
+ # "tensordict must be placed in shared memory (share_memory_() or memmap_())"
2672
+ # )
2673
+ shared_tensordict = shared_tensordict.clone(False).unlock_()
2674
+
2675
+ initialized = True
2676
+
2677
+ elif cmd == "reset":
2678
+ if verbose:
2679
+ torchrl_logger.info(f"resetting worker {pid}")
2680
+ if not initialized:
2681
+ raise RuntimeError("call 'init' before resetting")
2682
+ # we use 'data' to pass the keys that we need to pass to reset,
2683
+ # because passing the entire buffer may have unwanted consequences
2684
+ selected_reset_keys, reset_kwargs = data
2685
+ cur_td = env.reset(
2686
+ tensordict=root_shared_tensordict.select(
2687
+ *selected_reset_keys, strict=False
2688
+ )
2689
+ if selected_reset_keys
2690
+ else None,
2691
+ **reset_kwargs,
2692
+ )
2693
+ shared_tensordict.update_(
2694
+ cur_td,
2695
+ keys_to_update=list(_selected_reset_keys),
2696
+ non_blocking=non_blocking,
2697
+ )
2698
+ if event is not None:
2699
+ event.record()
2700
+ event.synchronize()
2701
+
2702
+ if _non_tensor_keys:
2703
+ # Set event BEFORE sending to avoid deadlocks when the pipe buffer
2704
+ # is full (the parent will start reading as soon as it observes
2705
+ # the event).
2706
+ mp_event.set()
2707
+ child_pipe.send(
2708
+ ("non_tensor", cur_td.select(*_non_tensor_keys, strict=False))
2709
+ )
2710
+ else:
2711
+ mp_event.set()
2712
+
2713
+ del cur_td
2714
+
2715
+ elif cmd == "step":
2716
+ if not initialized:
2717
+ raise RuntimeError("called 'init' before step")
2718
+ i += 1
2719
+ # No need to copy here since we don't write in-place
2720
+ input = root_shared_tensordict.copy()
2721
+ if data:
2722
+ next_td_passthrough_keys = data.get("next_td_passthrough_keys")
2723
+ if next_td_passthrough_keys is not None:
2724
+ input = input.set(
2725
+ "next", next_shared_tensordict.select(*next_td_passthrough_keys)
2726
+ )
2727
+ non_tensor_data = data.get("non_tensor_data")
2728
+ if non_tensor_data is not None:
2729
+ input.update(non_tensor_data)
2730
+
2731
+ input = env.step(input)
2732
+ next_td = input.get("next")
2733
+ next_shared_tensordict.update_(next_td, non_blocking=non_blocking)
2734
+
2735
+ if event is not None:
2736
+ event.record()
2737
+ event.synchronize()
2738
+
2739
+ # Make sure the root is updated
2740
+ root_shared_tensordict.update_(env._step_mdp(input))
2741
+
2742
+ if _non_tensor_keys:
2743
+ # Set event BEFORE sending to avoid deadlocks when the pipe buffer
2744
+ # is full (the parent will start reading as soon as it observes
2745
+ # the event).
2746
+ mp_event.set()
2747
+ child_pipe.send(
2748
+ ("non_tensor", next_td.select(*_non_tensor_keys, strict=False))
2749
+ )
2750
+ else:
2751
+ mp_event.set()
2752
+
2753
+ del next_td
2754
+
2755
+ elif cmd == "step_and_maybe_reset":
2756
+ if not initialized:
2757
+ raise RuntimeError("called 'init' before step")
2758
+ i += 1
2759
+ # We must copy the root shared td here, or at least get rid of done:
2760
+ # if we don't, `td is root_shared_tensordict`,
2761
+ # which means that root_shared_tensordict will carry the content of next
2762
+ # in the next iteration. When using StepCounter, it will look for an
2763
+ # existing done state, find it and consider the env as done by input (not
2764
+ # by output) of the step!
2765
+ # Caveat: for RNN we may need some keys of the "next" TD so we pass the list
2766
+ # through data
2767
+ input = root_shared_tensordict.copy()
2768
+ if data:
2769
+ next_td_passthrough_keys = data.get("next_td_passthrough_keys", None)
2770
+ if next_td_passthrough_keys is not None:
2771
+ input = input.set(
2772
+ "next", next_shared_tensordict.select(*next_td_passthrough_keys)
2773
+ )
2774
+ non_tensor_data = data.get("non_tensor_data", None)
2775
+ if non_tensor_data is not None:
2776
+ input.update(non_tensor_data)
2777
+ td, root_next_td = env.step_and_maybe_reset(input)
2778
+ td_next = td.pop("next")
2779
+ next_shared_tensordict.update_(td_next, non_blocking=non_blocking)
2780
+ root_shared_tensordict.update_(root_next_td, non_blocking=non_blocking)
2781
+
2782
+ if event is not None:
2783
+ event.record()
2784
+ event.synchronize()
2785
+
2786
+ if _non_tensor_keys:
2787
+ ntd = root_next_td.select(*_non_tensor_keys)
2788
+ ntd.set("next", td_next.select(*_non_tensor_keys))
2789
+ # Set event BEFORE sending to avoid deadlocks when the pipe buffer
2790
+ # is full (the parent will start reading as soon as it observes
2791
+ # the event).
2792
+ mp_event.set()
2793
+ child_pipe.send(("non_tensor", ntd))
2794
+ else:
2795
+ mp_event.set()
2796
+
2797
+ del td, root_next_td
2798
+
2799
+ elif cmd == "close":
2800
+ if not initialized:
2801
+ raise RuntimeError("call 'init' before closing")
2802
+ env.close()
2803
+ del (
2804
+ env,
2805
+ shared_tensordict,
2806
+ data,
2807
+ next_shared_tensordict,
2808
+ root_shared_tensordict,
2809
+ )
2810
+ mp_event.set()
2811
+ child_pipe.close()
2812
+ if verbose:
2813
+ torchrl_logger.info(f"{pid} closed")
2814
+ gc.collect()
2815
+ break
2816
+
2817
+ elif cmd == "load_state_dict":
2818
+ env.load_state_dict(data)
2819
+ mp_event.set()
2820
+
2821
+ elif cmd == "state_dict":
2822
+ state_dict = _recursively_strip_locks_from_state_dict(env.state_dict())
2823
+ msg = "state_dict"
2824
+ child_pipe.send((msg, state_dict))
2825
+ del state_dict
2826
+
2827
+ else:
2828
+ err_msg = f"{cmd} from env"
2829
+ try:
2830
+ attr = getattr(env, cmd)
2831
+ if callable(attr):
2832
+ args, kwargs = data
2833
+ args_replace = []
2834
+ for _arg in args:
2835
+ if isinstance(_arg, str) and _arg == "_self":
2836
+ continue
2837
+ else:
2838
+ args_replace.append(_arg)
2839
+ result = attr(*args_replace, **kwargs)
2840
+ else:
2841
+ result = attr
2842
+ except Exception as err:
2843
+ raise AttributeError(
2844
+ f"querying {err_msg} resulted in an error."
2845
+ ) from err
2846
+ if cmd not in ("to",):
2847
+ child_pipe.send(("_".join([cmd, "done"]), result))
2848
+ else:
2849
+ # don't send env through pipe
2850
+ child_pipe.send(("_".join([cmd, "done"]), None))
2851
+
2852
+
2853
+ def _run_worker_pipe_direct(
2854
+ parent_pipe: connection.Connection,
2855
+ child_pipe: connection.Connection,
2856
+ env_fun: EnvBase | Callable,
2857
+ env_fun_kwargs: dict[str, Any],
2858
+ mp_event: mp.Event = None,
2859
+ non_blocking: bool = False,
2860
+ has_lazy_inputs: bool = False,
2861
+ verbose: bool = False,
2862
+ num_threads: int | None = None, # for fork start method
2863
+ consolidate: bool = True,
2864
+ filter_warnings: bool = False,
2865
+ ) -> None:
2866
+ # Handle warning filtering (moved from _ProcessNoWarn)
2867
+ if filter_warnings:
2868
+ warnings.filterwarnings("ignore")
2869
+ if num_threads is not None:
2870
+ torch.set_num_threads(num_threads)
2871
+
2872
+ parent_pipe.close()
2873
+ pid = os.getpid()
2874
+ if not isinstance(env_fun, EnvBase):
2875
+ env = env_fun(**env_fun_kwargs)
2876
+ else:
2877
+ if env_fun_kwargs:
2878
+ raise RuntimeError(
2879
+ "env_fun_kwargs must be empty if an environment is passed to a process."
2880
+ )
2881
+ env = env_fun
2882
+ del env_fun
2883
+ for spec in env.output_spec.values(True, True):
2884
+ if spec.device is not None and spec.device.type == "cuda":
2885
+ has_cuda = True
2886
+ break
2887
+ else:
2888
+ for spec in env.input_spec.values(True, True):
2889
+ if spec.device is not None and spec.device.type == "cuda":
2890
+ has_cuda = True
2891
+ break
2892
+ else:
2893
+ has_cuda = False
2894
+ if has_cuda:
2895
+ event = torch.cuda.Event()
2896
+ else:
2897
+ event = None
2898
+
2899
+ i = -1
2900
+ import torchrl
2901
+
2902
+ _timeout = torchrl._utils.BATCHED_PIPE_TIMEOUT
2903
+
2904
+ initialized = False
2905
+
2906
+ child_pipe.send("started")
2907
+ while True:
2908
+ try:
2909
+ if child_pipe.poll(_timeout):
2910
+ cmd, data = child_pipe.recv()
2911
+ else:
2912
+ raise TimeoutError(
2913
+ f"Worker timed out after {_timeout}s, "
2914
+ f"increase timeout if needed through the BATCHED_PIPE_TIMEOUT environment variable."
2915
+ )
2916
+ except EOFError as err:
2917
+ raise EOFError(f"proc {pid} failed, last command: {cmd}.") from err
2918
+ if cmd == "seed":
2919
+ if not initialized:
2920
+ raise RuntimeError("call 'init' before closing")
2921
+ # torch.manual_seed(data)
2922
+ # np.random.seed(data)
2923
+ new_seed = env.set_seed(data[0], static_seed=data[1])
2924
+ child_pipe.send(("seeded", new_seed))
2925
+
2926
+ elif cmd == "init":
2927
+ if verbose:
2928
+ torchrl_logger.info(f"initializing {pid}")
2929
+ if initialized:
2930
+ raise RuntimeError("worker already initialized")
2931
+ i = 0
2932
+
2933
+ initialized = True
2934
+
2935
+ elif cmd == "reset":
2936
+ if verbose:
2937
+ torchrl_logger.info(f"resetting worker {pid}")
2938
+ if not initialized:
2939
+ raise RuntimeError("call 'init' before resetting")
2940
+ # we use 'data' to pass the keys that we need to pass to reset,
2941
+ # because passing the entire buffer may have unwanted consequences
2942
+ data, reset_kwargs = data
2943
+ if data is not None:
2944
+ data.unlock_()
2945
+ data._fast_apply(
2946
+ lambda x: x.clone() if x.device.type == "cuda" else x, out=data
2947
+ )
2948
+ cur_td = env.reset(
2949
+ tensordict=data,
2950
+ **reset_kwargs,
2951
+ )
2952
+ if event is not None:
2953
+ event.record()
2954
+ event.synchronize()
2955
+ if consolidate:
2956
+ try:
2957
+ cur_td = cur_td.consolidate(
2958
+ # share_memory=False: avoid resource_sharer which causes
2959
+ # progressive slowdown with fork on Linux
2960
+ share_memory=False,
2961
+ inplace=True,
2962
+ num_threads=1,
2963
+ )
2964
+ except Exception as err:
2965
+ raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
2966
+ # Set event BEFORE send so parent starts reading, which unblocks send
2967
+ # if pipe buffer was full (prevents deadlock)
2968
+ mp_event.set()
2969
+ child_pipe.send(cur_td)
2970
+
2971
+ del cur_td
2972
+
2973
+ elif cmd == "step":
2974
+ if not initialized:
2975
+ raise RuntimeError("called 'init' before step")
2976
+ i += 1
2977
+ next_td = env._step(data)
2978
+ if event is not None:
2979
+ event.record()
2980
+ event.synchronize()
2981
+ if consolidate:
2982
+ try:
2983
+ next_td = next_td.consolidate(
2984
+ # share_memory=False: avoid resource_sharer which causes
2985
+ # progressive slowdown with fork on Linux
2986
+ share_memory=False,
2987
+ inplace=True,
2988
+ num_threads=1,
2989
+ )
2990
+ except Exception as err:
2991
+ raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
2992
+ # Set event BEFORE send so parent starts reading, which unblocks send
2993
+ # if pipe buffer was full (prevents deadlock)
2994
+ mp_event.set()
2995
+ child_pipe.send(next_td)
2996
+
2997
+ del next_td
2998
+
2999
+ elif cmd == "step_and_maybe_reset":
3000
+ if not initialized:
3001
+ raise RuntimeError("called 'init' before step")
3002
+ i += 1
3003
+ # data, idx = data
3004
+ # data = data[idx]
3005
+ data._fast_apply(
3006
+ lambda x: x.clone() if x.device.type == "cuda" else x, out=data
3007
+ )
3008
+ td, root_next_td = env.step_and_maybe_reset(data)
3009
+
3010
+ if event is not None:
3011
+ event.record()
3012
+ event.synchronize()
3013
+ child_pipe.send((td, root_next_td))
3014
+ mp_event.set()
3015
+ del td, root_next_td
3016
+
3017
+ elif cmd == "close":
3018
+ if not initialized:
3019
+ raise RuntimeError("call 'init' before closing")
3020
+ env.close()
3021
+ mp_event.set()
3022
+ child_pipe.close()
3023
+ if verbose:
3024
+ torchrl_logger.info(f"{pid} closed")
3025
+ del (env, data, child_pipe, mp_event)
3026
+ gc.collect()
3027
+ return
3028
+
3029
+ elif cmd == "load_state_dict":
3030
+ env.load_state_dict(data)
3031
+ mp_event.set()
3032
+
3033
+ elif cmd == "state_dict":
3034
+ state_dict = _recursively_strip_locks_from_state_dict(env.state_dict())
3035
+ msg = "state_dict"
3036
+ child_pipe.send((msg, state_dict))
3037
+ del state_dict
3038
+
3039
+ else:
3040
+ err_msg = f"{cmd} from env"
3041
+ try:
3042
+ attr = getattr(env, cmd)
3043
+ if callable(attr):
3044
+ args, kwargs = data
3045
+ args_replace = []
3046
+ for _arg in args:
3047
+ if isinstance(_arg, str) and _arg == "_self":
3048
+ continue
3049
+ else:
3050
+ args_replace.append(_arg)
3051
+ result = attr(*args_replace, **kwargs)
3052
+ else:
3053
+ result = attr
3054
+ except Exception as err:
3055
+ raise AttributeError(
3056
+ f"querying {err_msg} resulted in an error."
3057
+ ) from err
3058
+ if cmd not in ("to",):
3059
+ child_pipe.send(("_".join([cmd, "done"]), result))
3060
+ else:
3061
+ # don't send env through pipe
3062
+ child_pipe.send(("_".join([cmd, "done"]), None))
3063
+
3064
+
3065
+ def _filter_empty(tensordict):
3066
+ return tensordict.select(*tensordict.keys(True, True))
3067
+
3068
+
3069
+ def _stackable(*tensordicts):
3070
+ try:
3071
+ ls = LazyStackedTensorDict(*tensordicts, stack_dim=0)
3072
+ ls.contiguous()
3073
+ return not ls._has_exclusive_keys
3074
+ except RuntimeError:
3075
+ return False
3076
+
3077
+
3078
+ def _cuda_sync(device):
3079
+ return functools.partial(torch.cuda.synchronize, device=device)
3080
+
3081
+
3082
+ def _mps_sync(device):
3083
+ return torch.mps.synchronize
3084
+
3085
+
3086
+ # Create an alias for possible imports
3087
+ _BatchedEnv = BatchedEnvBase
3088
+
3089
+ # legacy re-exports (must be at end of file to avoid circular imports)
3090
+ from torchrl.envs.libs.envpool import ( # noqa: F401, E402
3091
+ MultiThreadedEnv,
3092
+ MultiThreadedEnvWrapper,
3093
+ )