torchrl-0.11.0-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (394)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/__init__.py +144 -0
  97. torchrl/_extension.py +74 -0
  98. torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
  99. torchrl/_utils.py +1431 -0
  100. torchrl/collectors/__init__.py +48 -0
  101. torchrl/collectors/_base.py +1058 -0
  102. torchrl/collectors/_constants.py +88 -0
  103. torchrl/collectors/_multi_async.py +324 -0
  104. torchrl/collectors/_multi_base.py +1805 -0
  105. torchrl/collectors/_multi_sync.py +464 -0
  106. torchrl/collectors/_runner.py +581 -0
  107. torchrl/collectors/_single.py +2009 -0
  108. torchrl/collectors/_single_async.py +259 -0
  109. torchrl/collectors/collectors.py +62 -0
  110. torchrl/collectors/distributed/__init__.py +32 -0
  111. torchrl/collectors/distributed/default_configs.py +133 -0
  112. torchrl/collectors/distributed/generic.py +1306 -0
  113. torchrl/collectors/distributed/ray.py +1092 -0
  114. torchrl/collectors/distributed/rpc.py +1006 -0
  115. torchrl/collectors/distributed/sync.py +731 -0
  116. torchrl/collectors/distributed/utils.py +160 -0
  117. torchrl/collectors/llm/__init__.py +10 -0
  118. torchrl/collectors/llm/base.py +494 -0
  119. torchrl/collectors/llm/ray_collector.py +275 -0
  120. torchrl/collectors/llm/utils.py +36 -0
  121. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  122. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  123. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  124. torchrl/collectors/utils.py +433 -0
  125. torchrl/collectors/weight_update.py +591 -0
  126. torchrl/csrc/numpy_utils.h +38 -0
  127. torchrl/csrc/pybind.cpp +27 -0
  128. torchrl/csrc/segment_tree.h +458 -0
  129. torchrl/csrc/torch_utils.h +34 -0
  130. torchrl/csrc/utils.cpp +48 -0
  131. torchrl/csrc/utils.h +31 -0
  132. torchrl/data/__init__.py +187 -0
  133. torchrl/data/datasets/__init__.py +58 -0
  134. torchrl/data/datasets/atari_dqn.py +878 -0
  135. torchrl/data/datasets/common.py +281 -0
  136. torchrl/data/datasets/d4rl.py +489 -0
  137. torchrl/data/datasets/d4rl_infos.py +187 -0
  138. torchrl/data/datasets/gen_dgrl.py +375 -0
  139. torchrl/data/datasets/minari_data.py +643 -0
  140. torchrl/data/datasets/openml.py +177 -0
  141. torchrl/data/datasets/openx.py +798 -0
  142. torchrl/data/datasets/roboset.py +363 -0
  143. torchrl/data/datasets/utils.py +11 -0
  144. torchrl/data/datasets/vd4rl.py +432 -0
  145. torchrl/data/llm/__init__.py +34 -0
  146. torchrl/data/llm/dataset.py +491 -0
  147. torchrl/data/llm/history.py +1378 -0
  148. torchrl/data/llm/prompt.py +198 -0
  149. torchrl/data/llm/reward.py +225 -0
  150. torchrl/data/llm/topk.py +186 -0
  151. torchrl/data/llm/utils.py +543 -0
  152. torchrl/data/map/__init__.py +21 -0
  153. torchrl/data/map/hash.py +185 -0
  154. torchrl/data/map/query.py +204 -0
  155. torchrl/data/map/tdstorage.py +363 -0
  156. torchrl/data/map/tree.py +1434 -0
  157. torchrl/data/map/utils.py +103 -0
  158. torchrl/data/postprocs/__init__.py +8 -0
  159. torchrl/data/postprocs/postprocs.py +391 -0
  160. torchrl/data/replay_buffers/__init__.py +99 -0
  161. torchrl/data/replay_buffers/checkpointers.py +622 -0
  162. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  163. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  164. torchrl/data/replay_buffers/samplers.py +2578 -0
  165. torchrl/data/replay_buffers/scheduler.py +265 -0
  166. torchrl/data/replay_buffers/storages.py +2412 -0
  167. torchrl/data/replay_buffers/utils.py +1042 -0
  168. torchrl/data/replay_buffers/writers.py +781 -0
  169. torchrl/data/tensor_specs.py +7101 -0
  170. torchrl/data/utils.py +334 -0
  171. torchrl/envs/__init__.py +265 -0
  172. torchrl/envs/async_envs.py +1105 -0
  173. torchrl/envs/batched_envs.py +3093 -0
  174. torchrl/envs/common.py +4241 -0
  175. torchrl/envs/custom/__init__.py +11 -0
  176. torchrl/envs/custom/chess.py +617 -0
  177. torchrl/envs/custom/llm.py +214 -0
  178. torchrl/envs/custom/pendulum.py +401 -0
  179. torchrl/envs/custom/san_moves.txt +29274 -0
  180. torchrl/envs/custom/tictactoeenv.py +288 -0
  181. torchrl/envs/env_creator.py +263 -0
  182. torchrl/envs/gym_like.py +752 -0
  183. torchrl/envs/libs/__init__.py +68 -0
  184. torchrl/envs/libs/_gym_utils.py +326 -0
  185. torchrl/envs/libs/brax.py +846 -0
  186. torchrl/envs/libs/dm_control.py +544 -0
  187. torchrl/envs/libs/envpool.py +447 -0
  188. torchrl/envs/libs/gym.py +2239 -0
  189. torchrl/envs/libs/habitat.py +138 -0
  190. torchrl/envs/libs/isaac_lab.py +87 -0
  191. torchrl/envs/libs/isaacgym.py +203 -0
  192. torchrl/envs/libs/jax_utils.py +166 -0
  193. torchrl/envs/libs/jumanji.py +963 -0
  194. torchrl/envs/libs/meltingpot.py +599 -0
  195. torchrl/envs/libs/openml.py +153 -0
  196. torchrl/envs/libs/openspiel.py +652 -0
  197. torchrl/envs/libs/pettingzoo.py +1042 -0
  198. torchrl/envs/libs/procgen.py +351 -0
  199. torchrl/envs/libs/robohive.py +429 -0
  200. torchrl/envs/libs/smacv2.py +645 -0
  201. torchrl/envs/libs/unity_mlagents.py +891 -0
  202. torchrl/envs/libs/utils.py +147 -0
  203. torchrl/envs/libs/vmas.py +813 -0
  204. torchrl/envs/llm/__init__.py +63 -0
  205. torchrl/envs/llm/chat.py +730 -0
  206. torchrl/envs/llm/datasets/README.md +4 -0
  207. torchrl/envs/llm/datasets/__init__.py +17 -0
  208. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  209. torchrl/envs/llm/datasets/ifeval.py +274 -0
  210. torchrl/envs/llm/envs.py +789 -0
  211. torchrl/envs/llm/libs/README.md +3 -0
  212. torchrl/envs/llm/libs/__init__.py +8 -0
  213. torchrl/envs/llm/libs/mlgym.py +869 -0
  214. torchrl/envs/llm/reward/__init__.py +10 -0
  215. torchrl/envs/llm/reward/gsm8k.py +324 -0
  216. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  217. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  218. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  222. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  223. torchrl/envs/llm/transforms/__init__.py +55 -0
  224. torchrl/envs/llm/transforms/browser.py +292 -0
  225. torchrl/envs/llm/transforms/dataloading.py +859 -0
  226. torchrl/envs/llm/transforms/format.py +73 -0
  227. torchrl/envs/llm/transforms/kl.py +1544 -0
  228. torchrl/envs/llm/transforms/policy_version.py +189 -0
  229. torchrl/envs/llm/transforms/reason.py +323 -0
  230. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  231. torchrl/envs/llm/transforms/tools.py +1955 -0
  232. torchrl/envs/model_based/__init__.py +9 -0
  233. torchrl/envs/model_based/common.py +180 -0
  234. torchrl/envs/model_based/dreamer.py +112 -0
  235. torchrl/envs/transforms/__init__.py +147 -0
  236. torchrl/envs/transforms/functional.py +48 -0
  237. torchrl/envs/transforms/gym_transforms.py +203 -0
  238. torchrl/envs/transforms/module.py +341 -0
  239. torchrl/envs/transforms/r3m.py +372 -0
  240. torchrl/envs/transforms/ray_service.py +663 -0
  241. torchrl/envs/transforms/rb_transforms.py +214 -0
  242. torchrl/envs/transforms/transforms.py +11835 -0
  243. torchrl/envs/transforms/utils.py +94 -0
  244. torchrl/envs/transforms/vc1.py +307 -0
  245. torchrl/envs/transforms/vecnorm.py +845 -0
  246. torchrl/envs/transforms/vip.py +407 -0
  247. torchrl/envs/utils.py +1718 -0
  248. torchrl/envs/vec_envs.py +11 -0
  249. torchrl/modules/__init__.py +206 -0
  250. torchrl/modules/distributions/__init__.py +73 -0
  251. torchrl/modules/distributions/continuous.py +830 -0
  252. torchrl/modules/distributions/discrete.py +908 -0
  253. torchrl/modules/distributions/truncated_normal.py +187 -0
  254. torchrl/modules/distributions/utils.py +233 -0
  255. torchrl/modules/llm/__init__.py +62 -0
  256. torchrl/modules/llm/backends/__init__.py +65 -0
  257. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  258. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  259. torchrl/modules/llm/backends/vllm/base.py +72 -0
  260. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  261. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  262. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  263. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  264. torchrl/modules/llm/policies/__init__.py +28 -0
  265. torchrl/modules/llm/policies/common.py +1809 -0
  266. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  267. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  268. torchrl/modules/llm/utils.py +23 -0
  269. torchrl/modules/mcts/__init__.py +21 -0
  270. torchrl/modules/mcts/scores.py +579 -0
  271. torchrl/modules/models/__init__.py +86 -0
  272. torchrl/modules/models/batchrenorm.py +119 -0
  273. torchrl/modules/models/decision_transformer.py +179 -0
  274. torchrl/modules/models/exploration.py +731 -0
  275. torchrl/modules/models/llm.py +156 -0
  276. torchrl/modules/models/model_based.py +596 -0
  277. torchrl/modules/models/models.py +1712 -0
  278. torchrl/modules/models/multiagent.py +1067 -0
  279. torchrl/modules/models/recipes/impala.py +185 -0
  280. torchrl/modules/models/utils.py +162 -0
  281. torchrl/modules/planners/__init__.py +10 -0
  282. torchrl/modules/planners/cem.py +228 -0
  283. torchrl/modules/planners/common.py +73 -0
  284. torchrl/modules/planners/mppi.py +265 -0
  285. torchrl/modules/tensordict_module/__init__.py +89 -0
  286. torchrl/modules/tensordict_module/actors.py +2457 -0
  287. torchrl/modules/tensordict_module/common.py +529 -0
  288. torchrl/modules/tensordict_module/exploration.py +814 -0
  289. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  290. torchrl/modules/tensordict_module/rnn.py +1639 -0
  291. torchrl/modules/tensordict_module/sequence.py +132 -0
  292. torchrl/modules/tensordict_module/world_models.py +34 -0
  293. torchrl/modules/utils/__init__.py +38 -0
  294. torchrl/modules/utils/mappings.py +9 -0
  295. torchrl/modules/utils/utils.py +89 -0
  296. torchrl/objectives/__init__.py +78 -0
  297. torchrl/objectives/a2c.py +659 -0
  298. torchrl/objectives/common.py +753 -0
  299. torchrl/objectives/cql.py +1346 -0
  300. torchrl/objectives/crossq.py +710 -0
  301. torchrl/objectives/ddpg.py +453 -0
  302. torchrl/objectives/decision_transformer.py +371 -0
  303. torchrl/objectives/deprecated.py +516 -0
  304. torchrl/objectives/dqn.py +683 -0
  305. torchrl/objectives/dreamer.py +488 -0
  306. torchrl/objectives/functional.py +48 -0
  307. torchrl/objectives/gail.py +258 -0
  308. torchrl/objectives/iql.py +996 -0
  309. torchrl/objectives/llm/__init__.py +30 -0
  310. torchrl/objectives/llm/grpo.py +846 -0
  311. torchrl/objectives/llm/sft.py +482 -0
  312. torchrl/objectives/multiagent/__init__.py +8 -0
  313. torchrl/objectives/multiagent/qmixer.py +396 -0
  314. torchrl/objectives/ppo.py +1669 -0
  315. torchrl/objectives/redq.py +683 -0
  316. torchrl/objectives/reinforce.py +530 -0
  317. torchrl/objectives/sac.py +1580 -0
  318. torchrl/objectives/td3.py +570 -0
  319. torchrl/objectives/td3_bc.py +625 -0
  320. torchrl/objectives/utils.py +782 -0
  321. torchrl/objectives/value/__init__.py +28 -0
  322. torchrl/objectives/value/advantages.py +1956 -0
  323. torchrl/objectives/value/functional.py +1459 -0
  324. torchrl/objectives/value/utils.py +360 -0
  325. torchrl/record/__init__.py +17 -0
  326. torchrl/record/loggers/__init__.py +23 -0
  327. torchrl/record/loggers/common.py +48 -0
  328. torchrl/record/loggers/csv.py +226 -0
  329. torchrl/record/loggers/mlflow.py +142 -0
  330. torchrl/record/loggers/tensorboard.py +139 -0
  331. torchrl/record/loggers/trackio.py +163 -0
  332. torchrl/record/loggers/utils.py +78 -0
  333. torchrl/record/loggers/wandb.py +214 -0
  334. torchrl/record/recorder.py +554 -0
  335. torchrl/services/__init__.py +79 -0
  336. torchrl/services/base.py +109 -0
  337. torchrl/services/ray_service.py +453 -0
  338. torchrl/testing/__init__.py +107 -0
  339. torchrl/testing/assertions.py +179 -0
  340. torchrl/testing/dist_utils.py +122 -0
  341. torchrl/testing/env_creators.py +227 -0
  342. torchrl/testing/env_helper.py +35 -0
  343. torchrl/testing/gym_helpers.py +156 -0
  344. torchrl/testing/llm_mocks.py +119 -0
  345. torchrl/testing/mocking_classes.py +2720 -0
  346. torchrl/testing/modules.py +295 -0
  347. torchrl/testing/mp_helpers.py +15 -0
  348. torchrl/testing/ray_helpers.py +293 -0
  349. torchrl/testing/utils.py +190 -0
  350. torchrl/trainers/__init__.py +42 -0
  351. torchrl/trainers/algorithms/__init__.py +11 -0
  352. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  353. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  354. torchrl/trainers/algorithms/configs/common.py +41 -0
  355. torchrl/trainers/algorithms/configs/data.py +308 -0
  356. torchrl/trainers/algorithms/configs/envs.py +104 -0
  357. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  358. torchrl/trainers/algorithms/configs/logging.py +80 -0
  359. torchrl/trainers/algorithms/configs/modules.py +570 -0
  360. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  361. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  362. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  363. torchrl/trainers/algorithms/configs/utils.py +252 -0
  364. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  365. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  366. torchrl/trainers/algorithms/ppo.py +373 -0
  367. torchrl/trainers/algorithms/sac.py +308 -0
  368. torchrl/trainers/helpers/__init__.py +40 -0
  369. torchrl/trainers/helpers/collectors.py +416 -0
  370. torchrl/trainers/helpers/envs.py +573 -0
  371. torchrl/trainers/helpers/logger.py +33 -0
  372. torchrl/trainers/helpers/losses.py +132 -0
  373. torchrl/trainers/helpers/models.py +658 -0
  374. torchrl/trainers/helpers/replay_buffer.py +59 -0
  375. torchrl/trainers/helpers/trainers.py +301 -0
  376. torchrl/trainers/trainers.py +2052 -0
  377. torchrl/weight_update/__init__.py +33 -0
  378. torchrl/weight_update/_distributed.py +749 -0
  379. torchrl/weight_update/_mp.py +624 -0
  380. torchrl/weight_update/_noupdate.py +102 -0
  381. torchrl/weight_update/_ray.py +1032 -0
  382. torchrl/weight_update/_rpc.py +284 -0
  383. torchrl/weight_update/_shared.py +891 -0
  384. torchrl/weight_update/llm/__init__.py +32 -0
  385. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  386. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  387. torchrl/weight_update/utils.py +73 -0
  388. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  389. torchrl-0.11.0.dist-info/LICENSE +21 -0
  390. torchrl-0.11.0.dist-info/METADATA +1307 -0
  391. torchrl-0.11.0.dist-info/RECORD +394 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/_utils.py ADDED
@@ -0,0 +1,1431 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
7
+ import collections
8
+ import functools
9
+ import inspect
10
+ import logging
11
+ import math
12
+ import os
13
+ import pickle
14
+ import sys
15
+ import threading
16
+ import time
17
+ import traceback
18
+ import warnings
19
+ from collections.abc import Callable
20
+ from contextlib import nullcontext
21
+ from functools import wraps
22
+ from textwrap import indent
23
+ from typing import Any, cast, TypeVar
24
+
25
+ import numpy as np
26
+ import torch
27
+
28
+ from pyvers import implement_for # noqa: F401
29
+ from tensordict import unravel_key
30
+ from tensordict.utils import NestedKey
31
+ from torch import multiprocessing as mp, Tensor
32
+ from torch.autograd.profiler import record_function
33
+
34
+ try:
35
+ from torch.compiler import is_compiling
36
+ except ImportError:
37
+ from torch._dynamo import is_compiling
38
+
39
+
40
+ def _get_default_mp_start_method() -> str:
41
+ """Returns TorchRL's preferred multiprocessing start method.
42
+
43
+ If the user has explicitly set a global start method via ``mp.set_start_method()``,
44
+ that method is returned. Otherwise, defaults to ``"spawn"`` for improved safety
45
+ across backends and to avoid known issues with ``fork`` in multi-threaded programs.
46
+ """
47
+ # Check if user has explicitly set a global start method
48
+ try:
49
+ current = mp.get_start_method(allow_none=True)
50
+ if current is not None:
51
+ return current
52
+ except (TypeError, RuntimeError):
53
+ pass
54
+ return "spawn"
55
+
56
+
57
+ def _get_mp_ctx(start_method: str | None = None):
58
+ """Return a multiprocessing context with TorchRL's preferred start method.
59
+
60
+ This is intentionally context-based (instead of relying on global
61
+ ``mp.set_start_method``) so that TorchRL components can consistently allocate
62
+ primitives (Queue/Pipe/Lock/Process) with a matching context.
63
+ """
64
+ if start_method is None:
65
+ start_method = _get_default_mp_start_method()
66
+ try:
67
+ return mp.get_context(start_method)
68
+ except ValueError:
69
+ # Best effort fallback if a start method isn't supported on this platform.
70
+ return mp.get_context("spawn")
71
+
72
+
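As an aside, here is a minimal usage sketch (not part of the wheel contents) of _get_mp_ctx, showing multiprocessing primitives allocated from a single matching context; the worker function and the if-main guard are illustrative assumptions:

    from torchrl._utils import _get_mp_ctx

    def _worker(q):
        q.put("hello")  # hypothetical payload

    if __name__ == "__main__":
        ctx = _get_mp_ctx()              # "spawn" unless a global start method was already set
        queue = ctx.Queue()              # Queue and Process come from the same context
        proc = ctx.Process(target=_worker, args=(queue,))
        proc.start()
        proc.join()
        print(queue.get())               # -> "hello"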
73
+ def _set_mp_start_method_if_unset(start_method: str | None = None) -> str | None:
74
+ """Set the global start method only if it hasn't been set yet.
75
+
76
+ Returns the (possibly pre-existing) start method, or ``None`` if it cannot be
77
+ determined.
78
+ """
79
+ if start_method is None:
80
+ start_method = _get_default_mp_start_method()
81
+
82
+ current = None
83
+ try:
84
+ current = mp.get_start_method(allow_none=True)
85
+ except TypeError:
86
+ # Older python/torch wrappers may not accept allow_none.
87
+ try:
88
+ current = mp.get_start_method()
89
+ except Exception:
90
+ current = None
91
+ except Exception:
92
+ current = None
93
+
94
+ if current is None:
95
+ try:
96
+ mp.set_start_method(start_method, force=False)
97
+ current = start_method
98
+ except Exception:
99
+ # If another library already touched the context, we should not
100
+ # override it here.
101
+ pass
102
+ return current
103
+
104
+
105
+ @implement_for("torch", None, "2.8")
106
+ def _mp_sharing_strategy_for_spawn() -> str | None:
107
+ # On older torch stacks, pickling Process objects for "spawn" can end up
108
+ # passing file descriptors for shared storages; using "file_system" reduces
109
+ # FD passing and avoids spawn-time failures on some old Python versions.
110
+ return "file_system"
111
+
112
+
113
+ @implement_for("torch", "2.8")
114
+ def _mp_sharing_strategy_for_spawn() -> str | None: # noqa: F811
115
+ return None
116
+
117
+
118
+ def strtobool(val: Any) -> bool:
119
+ """Convert a string representation of truth to a boolean.
120
+
121
+ True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'.
122
+ Raises ValueError if 'val' is anything else.
123
+ """
124
+ val = val.lower()
125
+ if val in ("y", "yes", "t", "true", "on", "1"):
126
+ return True
127
+ if val in ("n", "no", "f", "false", "off", "0"):
128
+ return False
129
+ raise ValueError(f"Invalid truth value {val!r}")
130
+
131
+
132
+ LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO")
133
+ logger = logging.getLogger("torchrl")
134
+ logger.setLevel(LOGGING_LEVEL)
135
+ logger.propagate = False
136
+ # Clear existing handlers
137
+ while logger.hasHandlers():
138
+ logger.removeHandler(logger.handlers[0])
139
+ stream_handlers = {
140
+ "stdout": sys.stdout,
141
+ "stderr": sys.stderr,
142
+ }
143
+ TORCHRL_CONSOLE_STREAM = os.getenv("TORCHRL_CONSOLE_STREAM")
144
+ stream_handler = stream_handlers.get(TORCHRL_CONSOLE_STREAM, sys.stdout)
145
+
146
+
147
+ # Create colored handler
148
+ class _CustomFormatter(logging.Formatter):
149
+ def format(self, record):
150
+ # Format the initial part in green
151
+ green_format = "\033[92m%(asctime)s [%(name)s][%(levelname)s]\033[0m"
152
+ # Format the message part
153
+ message_format = "%(message)s"
154
+ # End marker in green
155
+ end_marker = "\033[92m [END]\033[0m"
156
+ # Combine all parts
157
+ formatted_message = logging.Formatter(
158
+ green_format + indent(message_format, " " * 4) + end_marker
159
+ ).format(record)
160
+
161
+ return formatted_message
162
+
163
+
164
+ console_handler = logging.StreamHandler(stream=stream_handler)
165
+ console_handler.setFormatter(_CustomFormatter())
166
+ logger.addHandler(console_handler)
167
+
168
+ console_handler.setLevel(LOGGING_LEVEL)
169
+ logger.debug(f"Logging level: {logger.getEffectiveLevel()}")
170
+
171
+ VERBOSE = strtobool(os.environ.get("VERBOSE", str(logger.isEnabledFor(logging.DEBUG))))
172
+ _os_is_windows = sys.platform == "win32"
173
+ RL_WARNINGS = strtobool(os.environ.get("RL_WARNINGS", "1"))
174
+ if RL_WARNINGS:
175
+ warnings.filterwarnings("once", category=DeprecationWarning, module="torchrl")
176
+
177
+ BATCHED_PIPE_TIMEOUT = float(os.environ.get("BATCHED_PIPE_TIMEOUT", "10000.0"))
178
+ WEIGHT_SYNC_TIMEOUT = float(os.environ.get("WEIGHT_SYNC_TIMEOUT", "60.0"))
179
+
180
+ _TORCH_DTYPES = (
181
+ torch.bfloat16,
182
+ torch.bool,
183
+ torch.complex128,
184
+ torch.complex32,
185
+ torch.complex64,
186
+ torch.float16,
187
+ torch.float32,
188
+ torch.float64,
189
+ torch.int16,
190
+ torch.int32,
191
+ torch.int64,
192
+ torch.int8,
193
+ torch.qint32,
194
+ torch.qint8,
195
+ torch.quint4x2,
196
+ torch.quint8,
197
+ torch.uint8,
198
+ )
199
+ if hasattr(torch, "uint16"):
200
+ _TORCH_DTYPES = _TORCH_DTYPES + (torch.uint16,)
201
+ if hasattr(torch, "uint32"):
202
+ _TORCH_DTYPES = _TORCH_DTYPES + (torch.uint32,)
203
+ if hasattr(torch, "uint64"):
204
+ _TORCH_DTYPES = _TORCH_DTYPES + (torch.uint64,)
205
+ _STR_DTYPE_TO_DTYPE = {str(dtype): dtype for dtype in _TORCH_DTYPES}
206
+ _STRDTYPE2DTYPE = _STR_DTYPE_TO_DTYPE
207
+ _DTYPE_TO_STR_DTYPE = {
208
+ dtype: str_dtype for str_dtype, dtype in _STR_DTYPE_TO_DTYPE.items()
209
+ }
210
+ _DTYPE2STRDTYPE = _DTYPE_TO_STR_DTYPE
211
+
212
+
213
+ class timeit:
214
+ """A dirty but easy to use decorator for profiling code.
215
+
216
+ Args:
217
+ name (str): The name of the timer.
218
+
219
+ Examples:
220
+ >>> from torchrl import timeit
221
+ >>> @timeit("my_function")
222
+ >>> def my_function():
223
+ ...
224
+ >>> my_function()
225
+ >>> with timeit("my_other_function"):
226
+ ... my_other_function()
227
+ >>> timeit.print() # prints the state of the timer for each function
228
+
229
+ The timer can also be queried mid-execution using the :meth:`elapsed` method:
230
+
231
+ >>> with timeit("my_function") as timer:
232
+ ... # do some work
233
+ ... print(f"Elapsed so far: {timer.elapsed():.3f}s")
234
+ ... # do more work
235
+
236
+ For long-running processes where a context manager isn't practical,
237
+ use the :meth:`start` method:
238
+
239
+ >>> timer = timeit("long_process").start()
240
+ >>> for i in range(100):
241
+ ... # do work
242
+ ... print(f"Elapsed: {timer.elapsed():.3f}s")
243
+ """
244
+
245
+ _REG = {}
246
+
247
+ def __init__(self, name):
248
+ self.name = name
249
+
250
+ def __call__(self, fn: Callable) -> Callable:
251
+ @wraps(fn)
252
+ def decorated_fn(*args, **kwargs):
253
+ with self:
254
+ out = fn(*args, **kwargs)
255
+ return out
256
+
257
+ return decorated_fn
258
+
259
+ def __enter__(self) -> timeit:
260
+ self.t0 = time.time()
261
+ return self
262
+
263
+ def start(self) -> timeit:
264
+ """Starts the timer without using a context manager.
265
+
266
+ This is useful when you need to track elapsed time over a long-running
267
+ loop or process where a context manager isn't practical.
268
+
269
+ Returns:
270
+ timeit: Returns self for method chaining.
271
+
272
+ Examples:
273
+ >>> timer = timeit("my_long_process").start()
274
+ >>> for i in range(100):
275
+ ... # do work
276
+ ... if i % 10 == 0:
277
+ ... print(f"Elapsed: {timer.elapsed():.3f}s")
278
+ """
279
+ self.t0 = time.time()
280
+ return self
281
+
282
+ def elapsed(self) -> float:
283
+ """Returns the elapsed time in seconds since the timer was started.
284
+
285
+ This can be called during execution to query the current elapsed time.
286
+
287
+ Returns:
288
+ float: Elapsed time in seconds.
289
+
290
+ Examples:
291
+ >>> with timeit("my_function") as timer:
292
+ ... # do some work
293
+ ... print(f"Elapsed so far: {timer.elapsed():.3f}s")
294
+ ... # do more work
295
+ """
296
+ return time.time() - self.t0
297
+
298
+ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
299
+ t = self.elapsed()
300
+ val = self._REG.setdefault(self.name, [0.0, 0.0, 0])
301
+
302
+ count = val[2]
303
+ N = count + 1
304
+ val[0] = val[0] * (count / N) + t / N
305
+ val[1] += t
306
+ val[2] = N
307
+
308
+ @staticmethod
309
+ def print(prefix: str | None = None) -> str: # noqa: T202
310
+ """Prints the state of the timer.
311
+
312
+ Args:
313
+ prefix (str): The prefix to add to the keys. If `None`, no prefix is added.
314
+
315
+ Returns:
316
+ the string printed using the logger.
317
+ """
318
+ keys = list(timeit._REG)
319
+ keys.sort()
320
+ string = []
321
+ for name in keys:
322
+ strings = []
323
+ if prefix:
324
+ strings.append(prefix)
325
+ strings.append(
326
+ f"{name} took {timeit._REG[name][0] * 1000:4.4f} msec (total = {timeit._REG[name][1]: 4.4f} sec since last reset)."
327
+ )
328
+ string.append(" -- ".join(strings))
329
+ logger.info(string[-1])
330
+ return "\n".join(string)
331
+
332
+ _printevery_count = 0
333
+
334
+ @classmethod
335
+ def printevery(
336
+ cls,
337
+ num_prints: int,
338
+ total_count: int,
339
+ *,
340
+ prefix: str | None = None,
341
+ erase: bool = False,
342
+ ) -> None:
343
+ """Prints the state of the timer at regular intervals.
344
+
345
+ Args:
346
+ num_prints (int): The number of times to print the state of the timer, given the total_count.
347
+ total_count (int): The total number of times ``printevery`` is expected to be called; used with ``num_prints`` to derive the print interval.
348
+ prefix (str): The prefix to add to the keys. If `None`, no prefix is added.
349
+ erase (bool): If True, erase the timer after printing. Default is `False`.
350
+
351
+ """
352
+ interval = max(1, total_count // num_prints)
353
+ if cls._printevery_count % interval == 0:
354
+ cls.print(prefix=prefix)
355
+ if erase:
356
+ cls.erase()
357
+ cls._printevery_count += 1
358
+
359
+ @classmethod
360
+ def todict(
361
+ cls, percall: bool = True, prefix: str | None = None
362
+ ) -> dict[str, float]:
363
+ """Convert the timer to a dictionary.
364
+
365
+ Args:
366
+ percall (bool): If True, return the average time per call.
367
+ prefix (str): The prefix to add to the keys.
368
+ """
369
+
370
+ def _make_key(key):
371
+ if prefix:
372
+ return f"{prefix}/{key}"
373
+ return key
374
+
375
+ if percall:
376
+ return {_make_key(key): val[0] for key, val in cls._REG.items()}
377
+ return {_make_key(key): val[1] for key, val in cls._REG.items()}
378
+
379
+ @staticmethod
380
+ def erase():
381
+ """Erase the timer.
382
+
383
+ .. seealso:: :meth:`reset`
384
+ """
385
+ for k in timeit._REG:
386
+ timeit._REG[k] = [0.0, 0.0, 0]
387
+
388
+ @classmethod
389
+ def reset(cls):
390
+ """Reset the timer.
391
+
392
+ .. seealso:: :meth:`erase`
393
+ """
394
+ cls.erase()
395
+
396
+
397
+ # Global flag to enable detailed profiling instrumentation.
398
+ # When False (default), _maybe_record_function returns nullcontext() immediately
399
+ # to avoid overhead in hot code paths.
400
+ _PROFILING_ENABLED = False
401
+
402
+ # Singleton nullcontext to avoid repeated object creation
403
+ _NULL_CONTEXT = nullcontext()
404
+
405
+
406
+ def set_profiling_enabled(enabled: bool) -> None:
407
+ """Enable or disable detailed profiling instrumentation.
408
+
409
+ When disabled (default), `_maybe_record_function` and `_maybe_timeit`
410
+ return immediately with minimal overhead. Enable only when actively
411
+ profiling to avoid performance regression.
412
+
413
+ Args:
414
+ enabled: If True, enable profiling instrumentation.
415
+ """
416
+ global _PROFILING_ENABLED
417
+ _PROFILING_ENABLED = enabled
418
+
419
+
420
+ def _maybe_timeit(name):
421
+ """Return timeit context if not compiling, nullcontext otherwise.
422
+
423
+ torch.compiler.is_compiling() returns True when inside a compiled region,
424
+ and timeit uses time.time() which dynamo cannot trace.
425
+ """
426
+ if is_compiling():
427
+ return _NULL_CONTEXT
428
+ return timeit(name)
429
+
430
+
431
+ def _maybe_record_function(name):
432
+ """Return record_function context if profiling enabled and not compiling.
433
+
434
+ When _PROFILING_ENABLED is False (default), returns immediately with
435
+ minimal overhead to avoid performance regression in hot code paths.
436
+ """
437
+ if not _PROFILING_ENABLED:
438
+ return _NULL_CONTEXT
439
+ if is_compiling():
440
+ return _NULL_CONTEXT
441
+
442
+ return record_function(name)
443
+
444
+
445
+ def _maybe_record_function_decorator(name: str) -> Callable[[Callable], Callable]:
446
+ """Decorator version of :func:`_maybe_record_function`.
447
+
448
+ This is preferred over sprinkling many context managers in hot code paths,
449
+ as it reduces Python overhead while keeping a useful profiler structure.
450
+
451
+ When _PROFILING_ENABLED is False (default), the decorator is a no-op.
452
+ """
453
+
454
+ def decorator(fn: Callable) -> Callable:
455
+ @wraps(fn)
456
+ def wrapped(*args, **kwargs):
457
+ if not _PROFILING_ENABLED:
458
+ return fn(*args, **kwargs)
459
+ with _maybe_record_function(name):
460
+ return fn(*args, **kwargs)
461
+
462
+ return wrapped
463
+
464
+ return decorator
465
+
466
+
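For illustration only (not part of the file): one way the profiling switches above could be combined with torch.profiler; the "rollout/step" label and the matmul workload are hypothetical:

    import torch
    from torch.profiler import profile
    from torchrl._utils import set_profiling_enabled, _maybe_record_function

    set_profiling_enabled(True)          # opt in; the default keeps hot paths overhead-free
    with profile() as prof:
        with _maybe_record_function("rollout/step"):   # hypothetical region label
            torch.randn(512, 512) @ torch.randn(512, 512)
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))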
467
+ def _check_for_faulty_process(processes):
468
+ terminate = False
469
+ for p in processes:
470
+ if not p._closed and not p.is_alive():
471
+ terminate = True
472
+ for _p in processes:
473
+ _p: mp.Process
474
+ if not _p._closed and _p.is_alive():
475
+ try:
476
+ _p.terminate()
477
+ except Exception:
478
+ _p.kill()
479
+ finally:
480
+ time.sleep(0.1)
481
+ _p.close()
482
+ if terminate:
483
+ break
484
+ if terminate:
485
+ raise RuntimeError(
486
+ "At least one process failed. Check for more infos in the log."
487
+ )
488
+
489
+
490
+ def seed_generator(seed):
491
+ """A seed generator function.
492
+
493
+ Given a seeding integer, generates a deterministic next seed to be used in a
494
+ seeding sequence.
495
+
496
+ Args:
497
+ seed (int): initial seed.
498
+
499
+ Returns: Next seed of the chain.
500
+
501
+ """
502
+ max_seed_val = (
503
+ 2**32 - 1
504
+ ) # https://discuss.pytorch.org/t/what-is-the-max-seed-you-can-set-up/145688
505
+ rng = np.random.default_rng(seed)
506
+ seed = int.from_bytes(rng.bytes(8), "big")
507
+ return seed % max_seed_val
508
+
509
+
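A short sketch (illustrative, not part of the wheel) of the seeding chain that seed_generator supports, e.g. deriving one seed per worker:

    from torchrl._utils import seed_generator

    seed = 42
    worker_seeds = []
    for _ in range(4):                   # one derived seed per hypothetical worker
        seed = seed_generator(seed)
        worker_seeds.append(seed)
    print(worker_seeds)                  # deterministic for a given initial seed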
510
+ class KeyDependentDefaultDict(collections.defaultdict):
511
+ """A key-dependent default dict.
512
+
513
+ Examples:
514
+ >>> my_dict = KeyDependentDefaultDict(lambda key: "foo_" + key)
515
+ >>> print(my_dict["bar"])
516
+ foo_bar
517
+ """
518
+
519
+ def __init__(self, fun):
520
+ self.fun = fun
521
+ super().__init__()
522
+
523
+ def __missing__(self, key):
524
+ value = self.fun(key)
525
+ self[key] = value
526
+ return value
527
+
528
+
529
+ def prod(sequence):
530
+ """General prod function, that generalised usage across math and np.
531
+
532
+ Created for compatibility across multiple Python versions.
533
+
534
+ """
535
+ if hasattr(math, "prod"):
536
+ return math.prod(sequence)
537
+ else:
538
+ return int(np.prod(sequence))
539
+
540
+
541
+ def get_binary_env_var(key):
542
+ """Parses and returns the binary environment variable value.
543
+
544
+ If not present in environment, it is considered `False`.
545
+
546
+ Args:
547
+ key (str): name of the environment variable.
548
+ """
549
+ val = os.environ.get(key, "False")
550
+ if val in ("0", "False", "false"):
551
+ val = False
552
+ elif val in ("1", "True", "true"):
553
+ val = True
554
+ else:
555
+ raise ValueError(
556
+ f"Environment variable {key} should be in 'True', 'False', '0' or '1'. "
557
+ f"Got {val} instead."
558
+ )
559
+ return val
560
+
561
+
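Illustrative usage of get_binary_env_var (the variable names below are hypothetical, not defined by the package):

    import os
    from torchrl._utils import get_binary_env_var

    os.environ["TORCHRL_EXAMPLE_FLAG"] = "true"               # hypothetical flag
    assert get_binary_env_var("TORCHRL_EXAMPLE_FLAG") is True
    assert get_binary_env_var("TORCHRL_UNSET_FLAG") is False  # absent -> False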
562
+ class _Dynamic_CKPT_BACKEND:
563
+ """Allows CKPT_BACKEND to be changed on-the-fly."""
564
+
565
+ backends = ["torch", "torchsnapshot"]
566
+
567
+ def _get_backend(self):
568
+ backend = os.environ.get("CKPT_BACKEND", "torch")
569
+ if backend == "torchsnapshot":
570
+ try:
571
+ import torchsnapshot # noqa: F401
572
+ except ImportError as err:
573
+ raise ImportError(
574
+ f"torchsnapshot not found, but the backend points to this library. "
575
+ f"Consider installing torchsnapshot or choose another backend (available backends: {self.backends})"
576
+ ) from err
577
+ return backend
578
+
579
+ def __getattr__(self, item):
580
+ return getattr(self._get_backend(), item)
581
+
582
+ def __eq__(self, other):
583
+ return self._get_backend() == other
584
+
585
+ def __ne__(self, other):
586
+ return self._get_backend() != other
587
+
588
+ def __repr__(self):
589
+ return self._get_backend()
590
+
591
+
592
+ _CKPT_BACKEND = _Dynamic_CKPT_BACKEND()
593
+
594
+
595
+ def accept_remote_rref_invocation(func):
596
+ """Decorator that allows a method to be invoked remotely.
597
+
598
+ Passes the `rpc.RRef` associated with the remote object construction as first argument in place of the object reference.
599
+
600
+ """
601
+
602
+ @wraps(func)
603
+ def unpack_rref_and_invoke_function(self, *args, **kwargs):
604
+ # windows does not know torch._C._distributed_rpc.PyRRef
605
+ if not _os_is_windows and isinstance(self, torch._C._distributed_rpc.PyRRef):
606
+ self = self.local_value()
607
+ return func(self, *args, **kwargs)
608
+
609
+ return unpack_rref_and_invoke_function
610
+
611
+
612
+ def accept_remote_rref_udf_invocation(decorated_class):
613
+ """Class decorator that applies `accept_remote_rref_invocation` to all public methods."""
614
+ # ignores private methods
615
+ for name in dir(decorated_class):
616
+ method = getattr(decorated_class, name, None)
617
+ if method is None:
618
+ continue
619
+ if callable(method) and not name.startswith("_"):
620
+ setattr(decorated_class, name, accept_remote_rref_invocation(method))
621
+ return decorated_class
622
+
623
+
624
+ # We copy this from torch as older versions do not have it
625
+ # see torch.utils._contextlib
626
+
627
+ # Extra utilities for working with context managers that should have been
628
+ # in the standard library but are not
629
+
630
+ # Used for annotating the decorator usage of _DecoratorContextManager (e.g.,
631
+ # 'no_grad' and 'enable_grad').
632
+ # See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
633
+ FuncType = Callable[..., Any]
634
+ F = TypeVar("F", bound=FuncType)
635
+
636
+
637
+ def _wrap_generator(ctx_factory, func):
638
+ """Wrap each generator invocation with the context manager factory.
639
+
640
+ The input should be a function that returns a context manager,
641
+ not a context manager itself, to handle one-shot context managers.
642
+ """
643
+
644
+ @functools.wraps(func)
645
+ def generator_context(*args, **kwargs):
646
+ gen = func(*args, **kwargs)
647
+
648
+ # Generators are suspended and unsuspended at `yield`, hence we
649
+ # make sure the grad mode is properly set every time the execution
650
+ # flow returns into the wrapped generator and restored when it
651
+ # returns through our `yield` to our caller (see PR #49017).
652
+ try:
653
+ # Issuing `None` to a generator fires it up
654
+ with ctx_factory():
655
+ response = gen.send(None)
656
+
657
+ while True:
658
+ try:
659
+ # Forward the response to our caller and get its next request
660
+ request = yield response
661
+
662
+ except GeneratorExit:
663
+ # Inform the still active generator about its imminent closure
664
+ with ctx_factory():
665
+ gen.close()
666
+ raise
667
+
668
+ except BaseException:
669
+ # Propagate the exception thrown at us by the caller
670
+ with ctx_factory():
671
+ response = gen.throw(*sys.exc_info())
672
+
673
+ else:
674
+ # Pass the last request to the generator and get its response
675
+ with ctx_factory():
676
+ response = gen.send(request)
677
+
678
+ # We let the exceptions raised above by the generator's `.throw` or
679
+ # `.send` methods bubble up to our caller, except for StopIteration
680
+ except StopIteration as e:
681
+ # The generator informed us that it is done: take whatever its
682
+ # returned value (if any) was and indicate that we're done too
683
+ # by returning it (see docs for python's return-statement).
684
+ return e.value
685
+
686
+ return generator_context
687
+
688
+
689
+ def context_decorator(ctx, func):
690
+ """Context decorator.
691
+
692
+ Like contextlib.ContextDecorator, but:
693
+
694
+ 1. Is done by wrapping, rather than inheritance, so it works with context
695
+ managers that are implemented from C and thus cannot easily inherit from
696
+ Python classes
697
+ 2. Wraps generators in the intuitive way (c.f. https://bugs.python.org/issue37743)
698
+ 3. Errors out if you try to wrap a class, because it is ambiguous whether
699
+ or not you intended to wrap only the constructor
700
+
701
+ The input argument can either be a context manager (in which case it must
702
+ be a multi-shot context manager that can be directly invoked multiple times)
703
+ or a callable that produces a context manager.
704
+ """
705
+ if callable(ctx) and hasattr(ctx, "__enter__"):
706
+ raise RuntimeError(
707
+ f"Passed in {ctx} is both callable and also a valid context manager "
708
+ "(has __enter__), making it ambiguous which interface to use. If you "
709
+ "intended to pass a context manager factory, rewrite your call as "
710
+ "context_decorator(lambda: ctx()); if you intended to pass a context "
711
+ "manager directly, rewrite your call as context_decorator(lambda: ctx)"
712
+ )
713
+
714
+ if not callable(ctx):
715
+
716
+ def ctx_factory():
717
+ return ctx
718
+
719
+ else:
720
+ ctx_factory = ctx
721
+
722
+ if inspect.isclass(func):
723
+ raise RuntimeError(
724
+ "Cannot decorate classes; it is ambiguous whether only the "
725
+ "constructor or all methods should have the context manager applied; "
726
+ "additionally, decorating a class at definition-site will prevent "
727
+ "use of the identifier as a conventional type. "
728
+ "To specify which methods to decorate, decorate each of them "
729
+ "individually."
730
+ )
731
+
732
+ if inspect.isgeneratorfunction(func):
733
+ return _wrap_generator(ctx_factory, func)
734
+
735
+ @functools.wraps(func)
736
+ def decorate_context(*args, **kwargs):
737
+ with ctx_factory():
738
+ return func(*args, **kwargs)
739
+
740
+ return decorate_context
741
+
742
+
743
+ class _DecoratorContextManager:
744
+ """Allow a context manager to be used as a decorator."""
745
+
746
+ def __call__(self, orig_func: F) -> F:
747
+ if inspect.isclass(orig_func):
748
+ warnings.warn(
749
+ "Decorating classes is deprecated and will be disabled in "
750
+ "future versions. You should only decorate functions or methods. "
751
+ "To preserve the current behavior of class decoration, you can "
752
+ "directly decorate the `__init__` method and nothing else."
753
+ )
754
+ func = cast(F, lambda *args, **kwargs: orig_func(*args, **kwargs))
755
+ else:
756
+ func = orig_func
757
+
758
+ return cast(F, context_decorator(self.clone, func))
759
+
760
+ def __enter__(self) -> None:
761
+ raise NotImplementedError
762
+
763
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
764
+ raise NotImplementedError
765
+
766
+ def clone(self):
767
+ # override this method if your children class takes __init__ parameters
768
+ return self.__class__()
769
+
770
+
771
+ def get_trace():
772
+ """A simple debugging util to spot where a function is being called."""
773
+ traceback.print_stack()
774
+
775
+
776
+ def _make_process_no_warn_cls(ctx=None):
777
+ """Create a _ProcessNoWarn class that inherits from the appropriate Process class.
778
+
779
+ When using multiprocessing contexts (e.g., fork or spawn), the Process class
780
+ used must match the context to ensure synchronization primitives like locks
781
+ work correctly. This factory function creates a _ProcessNoWarn class that
782
+ inherits from the context's Process class.
783
+
784
+ Args:
785
+ ctx: A multiprocessing context (e.g., from mp.get_context('fork')).
786
+ If None, uses the default mp.Process.
787
+
788
+ Returns:
789
+ A _ProcessNoWarn class that inherits from the appropriate Process base.
790
+
791
+ .. note::
792
+ For the "spawn" start method, this returns pre-defined module-level classes
793
+ to ensure they can be pickled correctly.
794
+ """
795
+ if ctx is None:
796
+ return _ProcessNoWarn
797
+
798
+ start_method = ctx.get_start_method()
799
+ if start_method == "fork":
800
+ return _ProcessNoWarnFork
801
+ elif start_method == "spawn":
802
+ return _ProcessNoWarnSpawn
803
+ elif start_method == "forkserver":
804
+ return _ProcessNoWarnForkserver
805
+ else:
806
+ # For unknown start methods, fall back to default
807
+ return _ProcessNoWarn
808
+
809
+
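A minimal sketch (not part of the file) of picking a context-matched Process subclass via _make_process_no_warn_cls; the spawn context and print target are assumptions for illustration:

    from torch import multiprocessing as mp
    from torchrl._utils import _make_process_no_warn_cls

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        ProcCls = _make_process_no_warn_cls(ctx)      # returns the spawn-compatible variant
        p = ProcCls(target=print, args=("quiet subprocess",), num_threads=1)
        p.start()
        p.join()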
810
+ # Keep the old class name as a default for backwards compatibility
811
+ class _ProcessNoWarn(mp.Process):
812
+ """A private Process class that shuts down warnings on the subprocess and controls the number of threads in the subprocess.
813
+
814
+ .. note::
815
+ When using multiprocessing contexts with synchronization primitives (locks, etc.),
816
+ use :func:`_make_process_no_warn_cls` with the context to ensure compatibility.
817
+ """
818
+
819
+ @wraps(mp.Process.__init__)
820
+ def __init__(self, *args, num_threads=None, _start_method=None, **kwargs):
821
+ import torchrl
822
+
823
+ self.filter_warnings_subprocess = torchrl.filter_warnings_subprocess
824
+ self.num_threads = num_threads
825
+ if _start_method is not None:
826
+ self._start_method = _start_method
827
+ super().__init__(*args, **kwargs)
828
+
829
+ def run(self, *args, **kwargs):
830
+ if self.num_threads is not None:
831
+ torch.set_num_threads(self.num_threads)
832
+ if self.filter_warnings_subprocess:
833
+ import warnings
834
+
835
+ with warnings.catch_warnings():
836
+ warnings.simplefilter("ignore")
837
+ return mp.Process.run(self, *args, **kwargs)
838
+ return mp.Process.run(self, *args, **kwargs)
839
+
840
+
841
+ # Pre-defined _ProcessNoWarn classes for different multiprocessing start methods.
842
+ # These must be defined at module level to be picklable with the "spawn" start method.
843
+ #
844
+ # We use a mixin pattern to avoid code duplication while still having
845
+ # distinct module-level classes that can be pickled.
846
+
847
+
848
+ class _ProcessNoWarnMixin:
849
+ """Mixin class providing the common functionality for _ProcessNoWarn variants."""
850
+
851
+ def _init_process_no_warn(self, num_threads=None, _start_method=None):
852
+ import torchrl
853
+
854
+ self.filter_warnings_subprocess = torchrl.filter_warnings_subprocess
855
+ self.num_threads = num_threads
856
+ if _start_method is not None:
857
+ self._start_method = _start_method
858
+
859
+ def run(self, *args, **kwargs):
860
+ if self.num_threads is not None:
861
+ torch.set_num_threads(self.num_threads)
862
+ if self.filter_warnings_subprocess:
863
+ import warnings
864
+
865
+ with warnings.catch_warnings():
866
+ warnings.simplefilter("ignore")
867
+ return super().run(*args, **kwargs)
868
+ return super().run(*args, **kwargs)
869
+
870
+
871
+ # Spawn-specific class (for macOS default and Windows)
872
+ try:
873
+ _spawn_ctx = mp.get_context("spawn")
874
+
875
+ class _ProcessNoWarnSpawn(_ProcessNoWarnMixin, _spawn_ctx.Process):
876
+ """_ProcessNoWarn for the 'spawn' multiprocessing context."""
877
+
878
+ def __init__(self, *args, num_threads=None, _start_method=None, **kwargs):
879
+ self._init_process_no_warn(num_threads, _start_method)
880
+ super().__init__(*args, **kwargs)
881
+
882
+ except ValueError:
883
+ _ProcessNoWarnSpawn = _ProcessNoWarn
884
+
885
+
886
+ # Fork-specific class (for Linux default, not available on Windows)
887
+ try:
888
+ _fork_ctx = mp.get_context("fork")
889
+
890
+ class _ProcessNoWarnFork(_ProcessNoWarnMixin, _fork_ctx.Process):
891
+ """_ProcessNoWarn for the 'fork' multiprocessing context."""
892
+
893
+ def __init__(self, *args, num_threads=None, _start_method=None, **kwargs):
894
+ self._init_process_no_warn(num_threads, _start_method)
895
+ super().__init__(*args, **kwargs)
896
+
897
+ except ValueError:
898
+ _ProcessNoWarnFork = _ProcessNoWarn
899
+
900
+
901
+ # Forkserver-specific class (not available on Windows)
902
+ try:
903
+ _forkserver_ctx = mp.get_context("forkserver")
904
+
905
+ class _ProcessNoWarnForkserver(_ProcessNoWarnMixin, _forkserver_ctx.Process):
906
+ """_ProcessNoWarn for the 'forkserver' multiprocessing context."""
907
+
908
+ def __init__(self, *args, num_threads=None, _start_method=None, **kwargs):
909
+ self._init_process_no_warn(num_threads, _start_method)
910
+ super().__init__(*args, **kwargs)
911
+
912
+ except ValueError:
913
+ _ProcessNoWarnForkserver = _ProcessNoWarn
914
+
915
+
916
+ def print_directory_tree(path, indent="", display_metadata=True):
917
+ """Prints the directory tree starting from the specified path.
918
+
919
+ Args:
920
+ path (str): The path of the directory to print.
921
+ indent (str): The current indentation level for formatting.
922
+ display_metadata (bool): if ``True``, metadata of the dir will be
923
+ displayed too.
924
+
925
+ """
926
+ if display_metadata:
927
+
928
+ def get_directory_size(path="."):
929
+ total_size = 0
930
+
931
+ for dirpath, _, filenames in os.walk(path):
932
+ for filename in filenames:
933
+ file_path = os.path.join(dirpath, filename)
934
+ total_size += os.path.getsize(file_path)
935
+
936
+ return total_size
937
+
938
+ def format_size(size):
939
+ # Convert size to a human-readable format
940
+ for unit in ["B", "KB", "MB", "GB", "TB"]:
941
+ if size < 1024.0:
942
+ return f"{size:.2f} {unit}"
943
+ size /= 1024.0
944
+
945
+ total_size_bytes = get_directory_size(path)
946
+ formatted_size = format_size(total_size_bytes)
947
+ logger.info(f"Directory size: {formatted_size}")
948
+
949
+ if os.path.isdir(path):
950
+ logger.info(indent + os.path.basename(path) + "/")
951
+ indent += " "
952
+ for item in os.listdir(path):
953
+ print_directory_tree(
954
+ os.path.join(path, item), indent=indent, display_metadata=False
955
+ )
956
+ else:
957
+ logger.info(indent + os.path.basename(path))
958
+
959
+
960
+ def _ends_with(key, match):
961
+ if isinstance(key, str):
962
+ return key == match
963
+ return key[-1] == match
964
+
965
+
966
+ def _replace_last(key: NestedKey, new_ending: str) -> NestedKey:
967
+ if isinstance(key, str):
968
+ return new_ending
969
+ else:
970
+ return key[:-1] + (new_ending,)
971
+
972
+
973
+ def _append_last(key: NestedKey, new_suffix: str) -> NestedKey:
974
+ key = unravel_key(key)
975
+ if isinstance(key, str):
976
+ return key + new_suffix
977
+ else:
978
+ return key[:-1] + (key[-1] + new_suffix,)
979
+
980
+
981
+ class _rng_decorator(_DecoratorContextManager):
982
+ """Temporarily sets the seed and sets back the rng state when exiting."""
983
+
984
+ def __init__(self, seed, device=None):
985
+ self.seed = seed
986
+ self.device = device
987
+ self.has_cuda = torch.cuda.is_available()
988
+
989
+ def __enter__(self):
990
+ self._get_state()
991
+ torch.manual_seed(self.seed)
992
+
993
+ def _get_state(self):
994
+ if self.has_cuda:
995
+ if self.device is None:
996
+ self._state = (torch.random.get_rng_state(), torch.cuda.get_rng_state())
997
+ else:
998
+ self._state = (
999
+ torch.random.get_rng_state(),
1000
+ torch.cuda.get_rng_state(self.device),
1001
+ )
1002
+
1003
+ else:
1004
+ self._state = torch.random.get_rng_state()
1005
+
1006
+ def __exit__(self, exc_type, exc_val, exc_tb):
1007
+ if self.has_cuda:
1008
+ torch.random.set_rng_state(self._state[0])
1009
+ if self.device is not None:
1010
+ torch.cuda.set_rng_state(self._state[1], device=self.device)
1011
+ else:
1012
+ torch.cuda.set_rng_state(self._state[1])
1013
+
1014
+ else:
1015
+ torch.random.set_rng_state(self._state)
1016
+
1017
+
1018
+ def _can_be_pickled(obj):
1019
+ try:
1020
+ pickle.dumps(obj)
1021
+ return True
1022
+ except (pickle.PickleError, AttributeError, TypeError):
1023
+ return False
1024
+
1025
+
1026
+ def _make_ordinal_device(device: torch.device):
1027
+ if device is None:
1028
+ return device
1029
+ device = torch.device(device)
1030
+ if device.type == "cuda" and device.index is None:
1031
+ return torch.device("cuda", index=torch.cuda.current_device())
1032
+ if device.type == "mps" and device.index is None:
1033
+ return torch.device("mps", index=0)
1034
+ return device
1035
+
1036
+
1037
+ def get_available_device(return_str: bool = False) -> torch.device | str:
1038
+ """Return the available accelerator device, or CPU if none is found.
1039
+
1040
+ Checks for accelerator availability in the following order: CUDA, NPU, MPS.
1041
+ Returns the first available accelerator, or CPU if none are present.
1042
+
1043
+ .. note::
1044
+ PyTorch generally assumes a single accelerator type per system.
1045
+ Running with multiple accelerator types (e.g., both CUDA and NPU)
1046
+ is not officially supported. This function simply returns the first
1047
+ available accelerator it finds.
1048
+
1049
+ Args:
1050
+ return_str: If ``True``, returns a string representation of the device
1051
+ instead of a :class:`~torch.device` object. Defaults to ``False``.
1052
+
1053
+ Returns:
1054
+ The available accelerator device as a :class:`~torch.device` object,
1055
+ or as a string if ``return_str`` is ``True``. Falls back to CPU if
1056
+ no accelerator is available.
1057
+
1058
+ Examples:
1059
+ >>> from torchrl._utils import get_available_device
1060
+ >>> device = get_available_device()
1061
+ >>> # Use with config fallback:
1062
+ >>> device = cfg.device or get_available_device()
1063
+ """
1064
+ if torch.cuda.is_available():
1065
+ device = "cuda:0"
1066
+ elif hasattr(torch, "npu") and torch.npu.is_available():
1067
+ device = "npu:0"
1068
+ elif torch.backends.mps.is_available():
1069
+ device = "mps:0"
1070
+ else:
1071
+ device = "cpu"
1072
+ if return_str:
1073
+ return device
1074
+ return torch.device(device)
1075
+
1076
+
1077
+ class _ContextManager:
1078
+ def __init__(self):
1079
+ self._mode: Any | None = None
1080
+ self._lock = threading.Lock()
1081
+
1082
+ def get_mode(self) -> Any | None:
1083
+ cm = self._lock if not is_compiling() else nullcontext()
1084
+ with cm:
1085
+ return self._mode
1086
+
1087
+ def set_mode(self, type: Any | None) -> None:
1088
+ cm = self._lock if not is_compiling() else nullcontext()
1089
+ with cm:
1090
+ self._mode = type
1091
+
1092
+
1093
+ def _standardize(
+     input: Tensor,
+     exclude_dims: tuple[int] = (),
+     mean: Tensor | None = None,
+     std: Tensor | None = None,
+     eps: float | None = None,
+ ):
+     """Standardizes the input tensor with the possibility of excluding specific dims from the statistics.
+
+     Useful when processing multi-agent data to keep the agent dimensions independent.
+
+     Args:
+         input (Tensor): the input tensor to be standardized.
+         exclude_dims (Tuple[int]): dimensions to exclude from the statistics, can be negative. Default: ().
+         mean (Tensor): a mean to be used for standardization. Must be of shape broadcastable to input. Default: None.
+         std (Tensor): a standard deviation to be used for standardization. Must be of shape broadcastable to input. Default: None.
+         eps (:obj:`float`): epsilon to be used for numerical stability. Default: float32 resolution.
+
+     """
+     if eps is None:
+         if input.dtype.is_floating_point:
+             eps = torch.finfo(torch.float).resolution
+         else:
+             eps = 1e-6
+
+     len_exclude_dims = len(exclude_dims)
+     if not len_exclude_dims:
+         if mean is None:
+             mean = input.mean()
+         else:
+             # Assume dtypes are compatible
+             mean = torch.as_tensor(mean, device=input.device)
+         if std is None:
+             std = input.std()
+         else:
+             # Assume dtypes are compatible
+             std = torch.as_tensor(std, device=input.device)
+         return (input - mean) / std.clamp_min(eps)
+
+     input_shape = input.shape
+     exclude_dims = [
+         d if d >= 0 else d + len(input_shape) for d in exclude_dims
+     ]  # Make negative dims positive
+
+     if len(set(exclude_dims)) != len_exclude_dims:
+         raise ValueError("Exclude dims has repeating elements")
+     if any(dim < 0 or dim >= len(input_shape) for dim in exclude_dims):
+         raise ValueError(
+             f"exclude_dims={exclude_dims} provided outside bounds for input of shape={input_shape}"
+         )
+     if len_exclude_dims == len(input_shape):
+         warnings.warn(
+             "_standardize called but all dims were excluded from the statistics, returning unprocessed input"
+         )
+         return input
+
+     included_dims = tuple(d for d in range(len(input_shape)) if d not in exclude_dims)
+     if mean is None:
+         mean = torch.mean(input, keepdim=True, dim=included_dims)
+     if std is None:
+         std = torch.std(input, keepdim=True, dim=included_dims)
+     return (input - mean) / std.clamp_min(eps)
+
+
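# --- Editor's sketch (not part of the package diff): standardizing per-agent values with
# the private helper above. Shapes are illustrative; `exclude_dims=(-1,)` keeps the last
# (agent) dimension out of the statistics, so each agent is normalized independently.
import torch

from torchrl._utils import _standardize

returns = torch.randn(128, 3)                        # [batch, n_agents]
normed = _standardize(returns, exclude_dims=(-1,))   # mean/std computed over the batch only
print(normed.mean(dim=0))                            # ~0 for each of the 3 agents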
+ @wraps(torch.compile)
+ def compile_with_warmup(*args, warmup: int = 1, **kwargs):
+     """Compile a model with warm-up.
+
+     This function wraps :func:`~torch.compile` to add a warm-up phase. During the warm-up phase,
+     the original model is used. After the warm-up phase, the model is compiled using
+     `torch.compile`.
+
+     Args:
+         *args: Arguments to be passed to `torch.compile`.
+         warmup (int): Number of calls to the model before compiling it. Defaults to 1.
+         **kwargs: Keyword arguments to be passed to `torch.compile`.
+
+     Returns:
+         A callable that wraps the original model. If no model is provided, returns a
+         lambda function that takes a model as input and returns the wrapped model.
+
+     Notes:
+         If no model is provided, this function returns a lambda function that can be
+         used to wrap a model later. This allows for delayed compilation of the model.
+
+     Example:
+         >>> model = torch.nn.Linear(5, 3)
+         >>> compiled_model = compile_with_warmup(model, warmup=10)
+         >>> # First 10 calls use the original model
+         >>> # After 10 calls, the model is compiled and used
+     """
+     if len(args):
+         model = args[0]
+         args = ()
+     else:
+         model = kwargs.pop("model", None)
+     if model is None:
+         return lambda model: compile_with_warmup(model, warmup=warmup, **kwargs)
+     else:
+         count = -1
+         compiled_model = model
+
+         @wraps(model)
+         def count_and_compile(*model_args, **model_kwargs):
+             nonlocal count
+             nonlocal compiled_model
+             count += 1
+             if count == warmup:
+                 compiled_model = torch.compile(model, *args, **kwargs)
+             return compiled_model(*model_args, **model_kwargs)
+
+         return count_and_compile
+
+
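# --- Editor's sketch (not part of the package diff): the delayed-compilation form described
# in the Notes above. `mode` is simply forwarded to torch.compile.
import torch

from torchrl._utils import compile_with_warmup

net = compile_with_warmup(warmup=3, mode="reduce-overhead")(torch.nn.Linear(5, 3))
for _ in range(5):
    net(torch.randn(2, 5))  # the first 3 calls run the eager module, later calls the compiled one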
+ # auto unwrap control
+ _DEFAULT_AUTO_UNWRAP = True
+ _AUTO_UNWRAP = os.environ.get("AUTO_UNWRAP_TRANSFORMED_ENV")
+
+
+ class set_auto_unwrap_transformed_env(_DecoratorContextManager):
+     """A context manager or decorator to control whether TransformedEnv should automatically unwrap nested TransformedEnv instances.
+
+     Args:
+         mode (bool): Whether to automatically unwrap nested :class:`~torchrl.envs.TransformedEnv`
+             instances. If ``False``, :class:`~torchrl.envs.TransformedEnv` will not unwrap nested instances.
+             Defaults to ``True``.
+
+     .. note:: Until v0.9, this will raise a warning if :class:`~torchrl.envs.TransformedEnv` are nested
+         and the value is not set explicitly (`auto_unwrap=True` default behavior).
+         You can set the value of :func:`~torchrl.envs.auto_unwrap_transformed_env`
+         through:
+
+         - The ``AUTO_UNWRAP_TRANSFORMED_ENV`` environment variable;
+         - By setting ``torchrl.set_auto_unwrap_transformed_env(val: bool).set()`` at the
+           beginning of your script;
+         - By using ``torchrl.set_auto_unwrap_transformed_env(val: bool)`` as a context
+           manager or a decorator.
+
+     .. seealso:: :class:`~torchrl.envs.TransformedEnv`
+
+     Examples:
+         >>> with set_auto_unwrap_transformed_env(False):
+         ...     env = TransformedEnv(TransformedEnv(env))
+         ...     assert not isinstance(env.base_env, TransformedEnv)
+         >>> @set_auto_unwrap_transformed_env(False)
+         ... def my_function():
+         ...     env = TransformedEnv(TransformedEnv(env))
+         ...     assert not isinstance(env.base_env, TransformedEnv)
+         ...     return env
+
+     """
+
+     def __init__(self, mode: bool) -> None:
+         super().__init__()
+         self.mode = mode
+
+     def clone(self) -> set_auto_unwrap_transformed_env:
+         # override this method if your child class takes __init__ parameters
+         return type(self)(self.mode)
+
+     def __enter__(self) -> None:
+         self.set()
+
+     def set(self) -> None:
+         global _AUTO_UNWRAP
+         self._old_mode = _AUTO_UNWRAP
+         _AUTO_UNWRAP = bool(self.mode)
+         # we do this so that sub-processes see the same setting as the main process
+         os.environ["AUTO_UNWRAP_TRANSFORMED_ENV"] = str(_AUTO_UNWRAP)
+
+     def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+         global _AUTO_UNWRAP
+         _AUTO_UNWRAP = self._old_mode
+         os.environ["AUTO_UNWRAP_TRANSFORMED_ENV"] = str(_AUTO_UNWRAP)
+
+
+ def auto_unwrap_transformed_env(allow_none=False):
+     """Get the current setting for automatically unwrapping TransformedEnv instances.
+
+     Args:
+         allow_none (bool, optional): If True, returns ``None`` if no setting has been
+             specified. Otherwise, returns the default setting. Defaults to ``False``.
+
+     seealso: :func:`~torchrl.set_auto_unwrap_transformed_env`
+
+     Returns:
+         bool or None: The current setting for automatically unwrapping TransformedEnv
+             instances.
+     """
+     global _AUTO_UNWRAP  # noqa: F824
+     if _AUTO_UNWRAP is None and allow_none:
+         return None
+     elif _AUTO_UNWRAP is None:
+         return _DEFAULT_AUTO_UNWRAP
+     return strtobool(_AUTO_UNWRAP) if isinstance(_AUTO_UNWRAP, str) else _AUTO_UNWRAP
+
+
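# --- Editor's sketch (not part of the package diff): reading the setting toggled by the
# context manager above (assumes the AUTO_UNWRAP_TRANSFORMED_ENV variable is unset).
from torchrl._utils import auto_unwrap_transformed_env, set_auto_unwrap_transformed_env

assert auto_unwrap_transformed_env() is True         # default
with set_auto_unwrap_transformed_env(False):
    assert auto_unwrap_transformed_env() is False    # overridden inside the context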
+ def safe_is_current_stream_capturing():
+     """A safe proxy to torch.cuda.is_current_stream_capturing."""
+     if not torch.cuda.is_available():
+         return False
+     try:
+         return torch.cuda.is_current_stream_capturing()
+     except Exception as error:
+         warnings.warn(
+             f"torch.cuda.is_current_stream_capturing() exited unexpectedly with the error message {error=}. "
+             f"Returning False by default."
+         )
+         return False
+
+
+ @classmethod
+ def as_remote(cls, remote_config: dict[str, Any] | None = None):
+     """Creates an instance of a remote ray class.
+
+     Args:
+         cls (Python Class): class to be remotely instantiated.
+         remote_config (dict): keyword arguments forwarded to ``ray.remote``
+             (e.g., the number of CPU cores to reserve for this class).
+
+     Returns:
+         The Ray remote version of the class; instances are created through its
+         ``remote()`` constructor.
+     """
+     import ray
+
+     if remote_config is None:
+         remote_config = {}
+
+     remote_collector = ray.remote(**remote_config)(cls)
+     remote_collector.is_remote = True
+     return remote_collector
+
+
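# --- Editor's sketch (not part of the package diff): `as_remote` is a classmethod, so it is
# meant to be attached to a class whose instances should run as Ray actors. `MyWorker` is a
# hypothetical user class; a running Ray session (ray.init()) is assumed.
import ray

from torchrl._utils import as_remote as _as_remote


class MyWorker:
    as_remote = _as_remote

    def ping(self):
        return "pong"


ray.init()
RemoteWorker = MyWorker.as_remote({"num_cpus": 1})
handle = RemoteWorker.remote()
print(ray.get(handle.ping.remote()))  # "pong"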
+ def get_ray_default_runtime_env() -> dict[str, Any]:
+     """Get the default Ray runtime environment configuration for TorchRL.
+
+     This function returns a runtime environment configuration that excludes
+     large directories and files that should not be uploaded to Ray workers.
+     This helps prevent issues with Ray's working_dir size limits (512MB default).
+
+     Returns:
+         dict: A dictionary containing the default runtime_env configuration with
+             excludes for common large directories.
+
+     Examples:
+         >>> import ray
+         >>> from torchrl._utils import get_ray_default_runtime_env
+         >>> ray_init_config = {"num_cpus": 4}
+         >>> ray_init_config["runtime_env"] = get_ray_default_runtime_env()
+         >>> ray.init(**ray_init_config)
+
+     Note:
+         The excludes list includes:
+         - Virtual environments (.venv/, venv/, etc.)
+         - Test files and caches
+         - Documentation builds
+         - Benchmarks
+         - Examples and tutorials
+         - CI/CD configurations
+         - IDE configurations
+
+     """
+     return {
+         "excludes": [
+             ".venv/",
+             "venv/",
+             "env/",
+             "ENV/",
+             "env.bak/",
+             "venv.bak/",
+             "test/",
+             "tests/",
+             "docs/",
+             "benchmarks/",
+             "tutorials/",
+             "examples/",
+             ".github/",
+             ".pytest_cache/",
+             ".mypy_cache/",
+             ".ruff_cache/",
+             "__pycache__/",
+             "*.pyc",
+             "*.pyo",
+             "*.egg-info/",
+             ".idea/",
+             ".vscode/",
+             "dev/",
+             "main/",
+             "*.html",
+         ]
+     }
+
+
+ def merge_ray_runtime_env(ray_init_config: dict[str, Any]) -> dict[str, Any]:
+     """Merge user-provided ray_init_config with default runtime_env excludes.
+
+     This function ensures that the default TorchRL runtime_env excludes are applied
+     to prevent large directories from being uploaded to Ray workers, while preserving
+     any user-provided configuration.
+
+     Args:
+         ray_init_config (dict): The ray init configuration dictionary to merge.
+
+     Returns:
+         dict: The merged configuration with default runtime_env excludes applied.
+
+     Examples:
+         >>> from torchrl._utils import merge_ray_runtime_env
+         >>> ray_init_config = {"num_cpus": 4}
+         >>> ray_init_config = merge_ray_runtime_env(ray_init_config)
+         >>> ray.init(**ray_init_config)
+
+     """
+     default_runtime_env = get_ray_default_runtime_env()
+     runtime_env = ray_init_config.get("runtime_env")
+
+     # Handle None or missing runtime_env
+     if runtime_env is None:
+         runtime_env = {}
+         ray_init_config["runtime_env"] = runtime_env
+     elif not isinstance(runtime_env, dict):
+         runtime_env = dict(runtime_env)
+         ray_init_config["runtime_env"] = runtime_env
+
+     # Merge excludes lists
+     excludes = runtime_env.get("excludes", [])
+     runtime_env["excludes"] = list(set(default_runtime_env["excludes"] + excludes))
+
+     # Ensure env_vars exists
+     if "env_vars" not in runtime_env:
+         runtime_env["env_vars"] = {}
+     elif not isinstance(runtime_env["env_vars"], dict):
+         runtime_env["env_vars"] = dict(runtime_env["env_vars"])
+
+     return ray_init_config
+
+
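# --- Editor's sketch (not part of the package diff): user-supplied excludes are preserved
# and merged with the TorchRL defaults, and env_vars is guaranteed to exist.
from torchrl._utils import merge_ray_runtime_env

cfg = merge_ray_runtime_env({"num_cpus": 4, "runtime_env": {"excludes": ["data/"]}})
assert "data/" in cfg["runtime_env"]["excludes"]     # user entry kept
assert ".venv/" in cfg["runtime_env"]["excludes"]    # default entry added
assert cfg["runtime_env"]["env_vars"] == {}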
+ def rl_warnings():
+     """Checks the status of the RL_WARNINGS env variable."""
+     return RL_WARNINGS