torchrl-0.11.0-cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (395)
  1. benchmarks/benchmark_batched_envs.py +104 -0
  2. benchmarks/conftest.py +91 -0
  3. benchmarks/ecosystem/gym_env_throughput.py +321 -0
  4. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
  5. benchmarks/requirements.txt +7 -0
  6. benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
  7. benchmarks/test_collectors_benchmark.py +240 -0
  8. benchmarks/test_compressed_storage_benchmark.py +145 -0
  9. benchmarks/test_envs_benchmark.py +133 -0
  10. benchmarks/test_llm.py +101 -0
  11. benchmarks/test_non_tensor_env_benchmark.py +70 -0
  12. benchmarks/test_objectives_benchmarks.py +1199 -0
  13. benchmarks/test_replaybuffer_benchmark.py +254 -0
  14. sota-check/README.md +35 -0
  15. sota-implementations/README.md +142 -0
  16. sota-implementations/a2c/README.md +39 -0
  17. sota-implementations/a2c/a2c_atari.py +291 -0
  18. sota-implementations/a2c/a2c_mujoco.py +273 -0
  19. sota-implementations/a2c/utils_atari.py +240 -0
  20. sota-implementations/a2c/utils_mujoco.py +160 -0
  21. sota-implementations/bandits/README.md +7 -0
  22. sota-implementations/bandits/dqn.py +126 -0
  23. sota-implementations/cql/cql_offline.py +198 -0
  24. sota-implementations/cql/cql_online.py +249 -0
  25. sota-implementations/cql/discrete_cql_offline.py +180 -0
  26. sota-implementations/cql/discrete_cql_online.py +227 -0
  27. sota-implementations/cql/utils.py +471 -0
  28. sota-implementations/crossq/crossq.py +271 -0
  29. sota-implementations/crossq/utils.py +320 -0
  30. sota-implementations/ddpg/ddpg.py +231 -0
  31. sota-implementations/ddpg/utils.py +325 -0
  32. sota-implementations/decision_transformer/dt.py +163 -0
  33. sota-implementations/decision_transformer/lamb.py +167 -0
  34. sota-implementations/decision_transformer/online_dt.py +178 -0
  35. sota-implementations/decision_transformer/utils.py +562 -0
  36. sota-implementations/discrete_sac/discrete_sac.py +243 -0
  37. sota-implementations/discrete_sac/utils.py +324 -0
  38. sota-implementations/dqn/README.md +30 -0
  39. sota-implementations/dqn/dqn_atari.py +272 -0
  40. sota-implementations/dqn/dqn_cartpole.py +236 -0
  41. sota-implementations/dqn/utils_atari.py +132 -0
  42. sota-implementations/dqn/utils_cartpole.py +90 -0
  43. sota-implementations/dreamer/README.md +129 -0
  44. sota-implementations/dreamer/dreamer.py +586 -0
  45. sota-implementations/dreamer/dreamer_utils.py +1107 -0
  46. sota-implementations/expert-iteration/README.md +352 -0
  47. sota-implementations/expert-iteration/ei_utils.py +770 -0
  48. sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
  49. sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
  50. sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
  51. sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
  52. sota-implementations/gail/gail.py +327 -0
  53. sota-implementations/gail/gail_utils.py +68 -0
  54. sota-implementations/gail/ppo_utils.py +157 -0
  55. sota-implementations/grpo/README.md +273 -0
  56. sota-implementations/grpo/grpo-async.py +437 -0
  57. sota-implementations/grpo/grpo-sync.py +435 -0
  58. sota-implementations/grpo/grpo_utils.py +843 -0
  59. sota-implementations/grpo/requirements_gsm8k.txt +11 -0
  60. sota-implementations/grpo/requirements_ifeval.txt +16 -0
  61. sota-implementations/impala/README.md +33 -0
  62. sota-implementations/impala/impala_multi_node_ray.py +292 -0
  63. sota-implementations/impala/impala_multi_node_submitit.py +284 -0
  64. sota-implementations/impala/impala_single_node.py +261 -0
  65. sota-implementations/impala/utils.py +184 -0
  66. sota-implementations/iql/discrete_iql.py +230 -0
  67. sota-implementations/iql/iql_offline.py +164 -0
  68. sota-implementations/iql/iql_online.py +225 -0
  69. sota-implementations/iql/utils.py +437 -0
  70. sota-implementations/multiagent/README.md +74 -0
  71. sota-implementations/multiagent/iql.py +237 -0
  72. sota-implementations/multiagent/maddpg_iddpg.py +266 -0
  73. sota-implementations/multiagent/mappo_ippo.py +267 -0
  74. sota-implementations/multiagent/qmix_vdn.py +271 -0
  75. sota-implementations/multiagent/sac.py +337 -0
  76. sota-implementations/multiagent/utils/__init__.py +4 -0
  77. sota-implementations/multiagent/utils/logging.py +151 -0
  78. sota-implementations/multiagent/utils/utils.py +43 -0
  79. sota-implementations/ppo/README.md +29 -0
  80. sota-implementations/ppo/ppo_atari.py +305 -0
  81. sota-implementations/ppo/ppo_mujoco.py +293 -0
  82. sota-implementations/ppo/utils_atari.py +238 -0
  83. sota-implementations/ppo/utils_mujoco.py +152 -0
  84. sota-implementations/ppo_trainer/train.py +21 -0
  85. sota-implementations/redq/README.md +7 -0
  86. sota-implementations/redq/redq.py +199 -0
  87. sota-implementations/redq/utils.py +1060 -0
  88. sota-implementations/sac/sac-async.py +266 -0
  89. sota-implementations/sac/sac.py +239 -0
  90. sota-implementations/sac/utils.py +381 -0
  91. sota-implementations/sac_trainer/train.py +16 -0
  92. sota-implementations/td3/td3.py +254 -0
  93. sota-implementations/td3/utils.py +319 -0
  94. sota-implementations/td3_bc/td3_bc.py +177 -0
  95. sota-implementations/td3_bc/utils.py +251 -0
  96. torchrl/.dylibs/libc++.1.0.dylib +0 -0
  97. torchrl/__init__.py +144 -0
  98. torchrl/_extension.py +74 -0
  99. torchrl/_torchrl.cpython-314t-darwin.so +0 -0
  100. torchrl/_utils.py +1431 -0
  101. torchrl/collectors/__init__.py +48 -0
  102. torchrl/collectors/_base.py +1058 -0
  103. torchrl/collectors/_constants.py +88 -0
  104. torchrl/collectors/_multi_async.py +324 -0
  105. torchrl/collectors/_multi_base.py +1805 -0
  106. torchrl/collectors/_multi_sync.py +464 -0
  107. torchrl/collectors/_runner.py +581 -0
  108. torchrl/collectors/_single.py +2009 -0
  109. torchrl/collectors/_single_async.py +259 -0
  110. torchrl/collectors/collectors.py +62 -0
  111. torchrl/collectors/distributed/__init__.py +32 -0
  112. torchrl/collectors/distributed/default_configs.py +133 -0
  113. torchrl/collectors/distributed/generic.py +1306 -0
  114. torchrl/collectors/distributed/ray.py +1092 -0
  115. torchrl/collectors/distributed/rpc.py +1006 -0
  116. torchrl/collectors/distributed/sync.py +731 -0
  117. torchrl/collectors/distributed/utils.py +160 -0
  118. torchrl/collectors/llm/__init__.py +10 -0
  119. torchrl/collectors/llm/base.py +494 -0
  120. torchrl/collectors/llm/ray_collector.py +275 -0
  121. torchrl/collectors/llm/utils.py +36 -0
  122. torchrl/collectors/llm/weight_update/__init__.py +10 -0
  123. torchrl/collectors/llm/weight_update/vllm.py +348 -0
  124. torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
  125. torchrl/collectors/utils.py +433 -0
  126. torchrl/collectors/weight_update.py +591 -0
  127. torchrl/csrc/numpy_utils.h +38 -0
  128. torchrl/csrc/pybind.cpp +27 -0
  129. torchrl/csrc/segment_tree.h +458 -0
  130. torchrl/csrc/torch_utils.h +34 -0
  131. torchrl/csrc/utils.cpp +48 -0
  132. torchrl/csrc/utils.h +31 -0
  133. torchrl/data/__init__.py +187 -0
  134. torchrl/data/datasets/__init__.py +58 -0
  135. torchrl/data/datasets/atari_dqn.py +878 -0
  136. torchrl/data/datasets/common.py +281 -0
  137. torchrl/data/datasets/d4rl.py +489 -0
  138. torchrl/data/datasets/d4rl_infos.py +187 -0
  139. torchrl/data/datasets/gen_dgrl.py +375 -0
  140. torchrl/data/datasets/minari_data.py +643 -0
  141. torchrl/data/datasets/openml.py +177 -0
  142. torchrl/data/datasets/openx.py +798 -0
  143. torchrl/data/datasets/roboset.py +363 -0
  144. torchrl/data/datasets/utils.py +11 -0
  145. torchrl/data/datasets/vd4rl.py +432 -0
  146. torchrl/data/llm/__init__.py +34 -0
  147. torchrl/data/llm/dataset.py +491 -0
  148. torchrl/data/llm/history.py +1378 -0
  149. torchrl/data/llm/prompt.py +198 -0
  150. torchrl/data/llm/reward.py +225 -0
  151. torchrl/data/llm/topk.py +186 -0
  152. torchrl/data/llm/utils.py +543 -0
  153. torchrl/data/map/__init__.py +21 -0
  154. torchrl/data/map/hash.py +185 -0
  155. torchrl/data/map/query.py +204 -0
  156. torchrl/data/map/tdstorage.py +363 -0
  157. torchrl/data/map/tree.py +1434 -0
  158. torchrl/data/map/utils.py +103 -0
  159. torchrl/data/postprocs/__init__.py +8 -0
  160. torchrl/data/postprocs/postprocs.py +391 -0
  161. torchrl/data/replay_buffers/__init__.py +99 -0
  162. torchrl/data/replay_buffers/checkpointers.py +622 -0
  163. torchrl/data/replay_buffers/ray_buffer.py +292 -0
  164. torchrl/data/replay_buffers/replay_buffers.py +2376 -0
  165. torchrl/data/replay_buffers/samplers.py +2578 -0
  166. torchrl/data/replay_buffers/scheduler.py +265 -0
  167. torchrl/data/replay_buffers/storages.py +2412 -0
  168. torchrl/data/replay_buffers/utils.py +1042 -0
  169. torchrl/data/replay_buffers/writers.py +781 -0
  170. torchrl/data/tensor_specs.py +7101 -0
  171. torchrl/data/utils.py +334 -0
  172. torchrl/envs/__init__.py +265 -0
  173. torchrl/envs/async_envs.py +1105 -0
  174. torchrl/envs/batched_envs.py +3093 -0
  175. torchrl/envs/common.py +4241 -0
  176. torchrl/envs/custom/__init__.py +11 -0
  177. torchrl/envs/custom/chess.py +617 -0
  178. torchrl/envs/custom/llm.py +214 -0
  179. torchrl/envs/custom/pendulum.py +401 -0
  180. torchrl/envs/custom/san_moves.txt +29274 -0
  181. torchrl/envs/custom/tictactoeenv.py +288 -0
  182. torchrl/envs/env_creator.py +263 -0
  183. torchrl/envs/gym_like.py +752 -0
  184. torchrl/envs/libs/__init__.py +68 -0
  185. torchrl/envs/libs/_gym_utils.py +326 -0
  186. torchrl/envs/libs/brax.py +846 -0
  187. torchrl/envs/libs/dm_control.py +544 -0
  188. torchrl/envs/libs/envpool.py +447 -0
  189. torchrl/envs/libs/gym.py +2239 -0
  190. torchrl/envs/libs/habitat.py +138 -0
  191. torchrl/envs/libs/isaac_lab.py +87 -0
  192. torchrl/envs/libs/isaacgym.py +203 -0
  193. torchrl/envs/libs/jax_utils.py +166 -0
  194. torchrl/envs/libs/jumanji.py +963 -0
  195. torchrl/envs/libs/meltingpot.py +599 -0
  196. torchrl/envs/libs/openml.py +153 -0
  197. torchrl/envs/libs/openspiel.py +652 -0
  198. torchrl/envs/libs/pettingzoo.py +1042 -0
  199. torchrl/envs/libs/procgen.py +351 -0
  200. torchrl/envs/libs/robohive.py +429 -0
  201. torchrl/envs/libs/smacv2.py +645 -0
  202. torchrl/envs/libs/unity_mlagents.py +891 -0
  203. torchrl/envs/libs/utils.py +147 -0
  204. torchrl/envs/libs/vmas.py +813 -0
  205. torchrl/envs/llm/__init__.py +63 -0
  206. torchrl/envs/llm/chat.py +730 -0
  207. torchrl/envs/llm/datasets/README.md +4 -0
  208. torchrl/envs/llm/datasets/__init__.py +17 -0
  209. torchrl/envs/llm/datasets/gsm8k.py +353 -0
  210. torchrl/envs/llm/datasets/ifeval.py +274 -0
  211. torchrl/envs/llm/envs.py +789 -0
  212. torchrl/envs/llm/libs/README.md +3 -0
  213. torchrl/envs/llm/libs/__init__.py +8 -0
  214. torchrl/envs/llm/libs/mlgym.py +869 -0
  215. torchrl/envs/llm/reward/__init__.py +10 -0
  216. torchrl/envs/llm/reward/gsm8k.py +324 -0
  217. torchrl/envs/llm/reward/ifeval/README.md +13 -0
  218. torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
  219. torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
  220. torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
  221. torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
  222. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
  223. torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
  224. torchrl/envs/llm/transforms/__init__.py +55 -0
  225. torchrl/envs/llm/transforms/browser.py +292 -0
  226. torchrl/envs/llm/transforms/dataloading.py +859 -0
  227. torchrl/envs/llm/transforms/format.py +73 -0
  228. torchrl/envs/llm/transforms/kl.py +1544 -0
  229. torchrl/envs/llm/transforms/policy_version.py +189 -0
  230. torchrl/envs/llm/transforms/reason.py +323 -0
  231. torchrl/envs/llm/transforms/tokenizer.py +321 -0
  232. torchrl/envs/llm/transforms/tools.py +1955 -0
  233. torchrl/envs/model_based/__init__.py +9 -0
  234. torchrl/envs/model_based/common.py +180 -0
  235. torchrl/envs/model_based/dreamer.py +112 -0
  236. torchrl/envs/transforms/__init__.py +147 -0
  237. torchrl/envs/transforms/functional.py +48 -0
  238. torchrl/envs/transforms/gym_transforms.py +203 -0
  239. torchrl/envs/transforms/module.py +341 -0
  240. torchrl/envs/transforms/r3m.py +372 -0
  241. torchrl/envs/transforms/ray_service.py +663 -0
  242. torchrl/envs/transforms/rb_transforms.py +214 -0
  243. torchrl/envs/transforms/transforms.py +11835 -0
  244. torchrl/envs/transforms/utils.py +94 -0
  245. torchrl/envs/transforms/vc1.py +307 -0
  246. torchrl/envs/transforms/vecnorm.py +845 -0
  247. torchrl/envs/transforms/vip.py +407 -0
  248. torchrl/envs/utils.py +1718 -0
  249. torchrl/envs/vec_envs.py +11 -0
  250. torchrl/modules/__init__.py +206 -0
  251. torchrl/modules/distributions/__init__.py +73 -0
  252. torchrl/modules/distributions/continuous.py +830 -0
  253. torchrl/modules/distributions/discrete.py +908 -0
  254. torchrl/modules/distributions/truncated_normal.py +187 -0
  255. torchrl/modules/distributions/utils.py +233 -0
  256. torchrl/modules/llm/__init__.py +62 -0
  257. torchrl/modules/llm/backends/__init__.py +65 -0
  258. torchrl/modules/llm/backends/vllm/__init__.py +94 -0
  259. torchrl/modules/llm/backends/vllm/_models.py +46 -0
  260. torchrl/modules/llm/backends/vllm/base.py +72 -0
  261. torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
  262. torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
  263. torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
  264. torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
  265. torchrl/modules/llm/policies/__init__.py +28 -0
  266. torchrl/modules/llm/policies/common.py +1809 -0
  267. torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
  268. torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
  269. torchrl/modules/llm/utils.py +23 -0
  270. torchrl/modules/mcts/__init__.py +21 -0
  271. torchrl/modules/mcts/scores.py +579 -0
  272. torchrl/modules/models/__init__.py +86 -0
  273. torchrl/modules/models/batchrenorm.py +119 -0
  274. torchrl/modules/models/decision_transformer.py +179 -0
  275. torchrl/modules/models/exploration.py +731 -0
  276. torchrl/modules/models/llm.py +156 -0
  277. torchrl/modules/models/model_based.py +596 -0
  278. torchrl/modules/models/models.py +1712 -0
  279. torchrl/modules/models/multiagent.py +1067 -0
  280. torchrl/modules/models/recipes/impala.py +185 -0
  281. torchrl/modules/models/utils.py +162 -0
  282. torchrl/modules/planners/__init__.py +10 -0
  283. torchrl/modules/planners/cem.py +228 -0
  284. torchrl/modules/planners/common.py +73 -0
  285. torchrl/modules/planners/mppi.py +265 -0
  286. torchrl/modules/tensordict_module/__init__.py +89 -0
  287. torchrl/modules/tensordict_module/actors.py +2457 -0
  288. torchrl/modules/tensordict_module/common.py +529 -0
  289. torchrl/modules/tensordict_module/exploration.py +814 -0
  290. torchrl/modules/tensordict_module/probabilistic.py +321 -0
  291. torchrl/modules/tensordict_module/rnn.py +1639 -0
  292. torchrl/modules/tensordict_module/sequence.py +132 -0
  293. torchrl/modules/tensordict_module/world_models.py +34 -0
  294. torchrl/modules/utils/__init__.py +38 -0
  295. torchrl/modules/utils/mappings.py +9 -0
  296. torchrl/modules/utils/utils.py +89 -0
  297. torchrl/objectives/__init__.py +78 -0
  298. torchrl/objectives/a2c.py +659 -0
  299. torchrl/objectives/common.py +753 -0
  300. torchrl/objectives/cql.py +1346 -0
  301. torchrl/objectives/crossq.py +710 -0
  302. torchrl/objectives/ddpg.py +453 -0
  303. torchrl/objectives/decision_transformer.py +371 -0
  304. torchrl/objectives/deprecated.py +516 -0
  305. torchrl/objectives/dqn.py +683 -0
  306. torchrl/objectives/dreamer.py +488 -0
  307. torchrl/objectives/functional.py +48 -0
  308. torchrl/objectives/gail.py +258 -0
  309. torchrl/objectives/iql.py +996 -0
  310. torchrl/objectives/llm/__init__.py +30 -0
  311. torchrl/objectives/llm/grpo.py +846 -0
  312. torchrl/objectives/llm/sft.py +482 -0
  313. torchrl/objectives/multiagent/__init__.py +8 -0
  314. torchrl/objectives/multiagent/qmixer.py +396 -0
  315. torchrl/objectives/ppo.py +1669 -0
  316. torchrl/objectives/redq.py +683 -0
  317. torchrl/objectives/reinforce.py +530 -0
  318. torchrl/objectives/sac.py +1580 -0
  319. torchrl/objectives/td3.py +570 -0
  320. torchrl/objectives/td3_bc.py +625 -0
  321. torchrl/objectives/utils.py +782 -0
  322. torchrl/objectives/value/__init__.py +28 -0
  323. torchrl/objectives/value/advantages.py +1956 -0
  324. torchrl/objectives/value/functional.py +1459 -0
  325. torchrl/objectives/value/utils.py +360 -0
  326. torchrl/record/__init__.py +17 -0
  327. torchrl/record/loggers/__init__.py +23 -0
  328. torchrl/record/loggers/common.py +48 -0
  329. torchrl/record/loggers/csv.py +226 -0
  330. torchrl/record/loggers/mlflow.py +142 -0
  331. torchrl/record/loggers/tensorboard.py +139 -0
  332. torchrl/record/loggers/trackio.py +163 -0
  333. torchrl/record/loggers/utils.py +78 -0
  334. torchrl/record/loggers/wandb.py +214 -0
  335. torchrl/record/recorder.py +554 -0
  336. torchrl/services/__init__.py +79 -0
  337. torchrl/services/base.py +109 -0
  338. torchrl/services/ray_service.py +453 -0
  339. torchrl/testing/__init__.py +107 -0
  340. torchrl/testing/assertions.py +179 -0
  341. torchrl/testing/dist_utils.py +122 -0
  342. torchrl/testing/env_creators.py +227 -0
  343. torchrl/testing/env_helper.py +35 -0
  344. torchrl/testing/gym_helpers.py +156 -0
  345. torchrl/testing/llm_mocks.py +119 -0
  346. torchrl/testing/mocking_classes.py +2720 -0
  347. torchrl/testing/modules.py +295 -0
  348. torchrl/testing/mp_helpers.py +15 -0
  349. torchrl/testing/ray_helpers.py +293 -0
  350. torchrl/testing/utils.py +190 -0
  351. torchrl/trainers/__init__.py +42 -0
  352. torchrl/trainers/algorithms/__init__.py +11 -0
  353. torchrl/trainers/algorithms/configs/__init__.py +705 -0
  354. torchrl/trainers/algorithms/configs/collectors.py +216 -0
  355. torchrl/trainers/algorithms/configs/common.py +41 -0
  356. torchrl/trainers/algorithms/configs/data.py +308 -0
  357. torchrl/trainers/algorithms/configs/envs.py +104 -0
  358. torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
  359. torchrl/trainers/algorithms/configs/logging.py +80 -0
  360. torchrl/trainers/algorithms/configs/modules.py +570 -0
  361. torchrl/trainers/algorithms/configs/objectives.py +177 -0
  362. torchrl/trainers/algorithms/configs/trainers.py +340 -0
  363. torchrl/trainers/algorithms/configs/transforms.py +955 -0
  364. torchrl/trainers/algorithms/configs/utils.py +252 -0
  365. torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
  366. torchrl/trainers/algorithms/configs/weight_update.py +159 -0
  367. torchrl/trainers/algorithms/ppo.py +373 -0
  368. torchrl/trainers/algorithms/sac.py +308 -0
  369. torchrl/trainers/helpers/__init__.py +40 -0
  370. torchrl/trainers/helpers/collectors.py +416 -0
  371. torchrl/trainers/helpers/envs.py +573 -0
  372. torchrl/trainers/helpers/logger.py +33 -0
  373. torchrl/trainers/helpers/losses.py +132 -0
  374. torchrl/trainers/helpers/models.py +658 -0
  375. torchrl/trainers/helpers/replay_buffer.py +59 -0
  376. torchrl/trainers/helpers/trainers.py +301 -0
  377. torchrl/trainers/trainers.py +2052 -0
  378. torchrl/weight_update/__init__.py +33 -0
  379. torchrl/weight_update/_distributed.py +749 -0
  380. torchrl/weight_update/_mp.py +624 -0
  381. torchrl/weight_update/_noupdate.py +102 -0
  382. torchrl/weight_update/_ray.py +1032 -0
  383. torchrl/weight_update/_rpc.py +284 -0
  384. torchrl/weight_update/_shared.py +891 -0
  385. torchrl/weight_update/llm/__init__.py +32 -0
  386. torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
  387. torchrl/weight_update/llm/vllm_nccl.py +710 -0
  388. torchrl/weight_update/utils.py +73 -0
  389. torchrl/weight_update/weight_sync_schemes.py +1244 -0
  390. torchrl-0.11.0.dist-info/METADATA +1308 -0
  391. torchrl-0.11.0.dist-info/RECORD +395 -0
  392. torchrl-0.11.0.dist-info/WHEEL +5 -0
  393. torchrl-0.11.0.dist-info/entry_points.txt +2 -0
  394. torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
  395. torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/collectors/_single.py
@@ -0,0 +1,2009 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import threading
5
+ import warnings
6
+ import weakref
7
+ from collections import OrderedDict
8
+ from collections.abc import Callable, Iterator, Sequence
9
+ from textwrap import indent
10
+ from typing import Any
11
+
12
+ import torch
13
+
14
+ from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
15
+ from tensordict.nn import CudaGraphModule, TensorDictModule, TensorDictModuleBase
16
+ from torch import nn
17
+ from torchrl import compile_with_warmup
18
+ from torchrl._utils import (
19
+ _ends_with,
20
+ _make_ordinal_device,
21
+ _replace_last,
22
+ accept_remote_rref_udf_invocation,
23
+ prod,
24
+ RL_WARNINGS,
25
+ )
26
+ from torchrl.collectors._base import _LegacyCollectorMeta, BaseCollector, ProfileConfig
27
+ from torchrl.collectors._constants import (
28
+ cudagraph_mark_step_begin,
29
+ DEFAULT_EXPLORATION_TYPE,
30
+ ExplorationType,
31
+ )
32
+ from torchrl.collectors.utils import _TrajectoryPool, split_trajectories
33
+ from torchrl.collectors.weight_update import WeightUpdaterBase
34
+ from torchrl.data import ReplayBuffer
35
+ from torchrl.data.utils import DEVICE_TYPING
36
+ from torchrl.envs import EnvBase, EnvCreator, StepCounter, TransformedEnv
37
+ from torchrl.envs.common import _do_nothing
38
+ from torchrl.envs.llm.transforms import PolicyVersion
39
+ from torchrl.envs.utils import (
40
+ _aggregate_end_of_traj,
41
+ _make_compatible_policy,
42
+ set_exploration_type,
43
+ )
44
+ from torchrl.modules import RandomPolicy, set_exploration_modules_spec_from_env
45
+ from torchrl.weight_update.utils import _resolve_model
46
+ from torchrl.weight_update.weight_sync_schemes import WeightSyncScheme
47
+
48
+
49
+ class _CollectorProfiler:
50
+ """Helper class for profiling collector rollouts in single-process mode.
51
+
52
+ Manages the PyTorch profiler lifecycle for the Collector class.
53
+ """
54
+
55
+ def __init__(self, profile_config: ProfileConfig):
56
+ self.config = profile_config
57
+ self.rollout_count = 0
58
+ self._profiler = None
59
+ self._stopped = False
60
+ self._active = False
61
+
62
+ # Set up profiler schedule
63
+ active_rollouts = self.config.num_rollouts - self.config.warmup_rollouts
64
+ profiler_schedule = torch.profiler.schedule(
65
+ skip_first=self.config.warmup_rollouts,
66
+ wait=0,
67
+ warmup=0,
68
+ active=active_rollouts,
69
+ repeat=1,
70
+ )
71
+
72
+ # Get activities
73
+ activities = self.config.get_activities()
74
+ if not activities:
75
+ return
76
+
77
+ # Determine trace handler
78
+ if self.config.on_trace_ready is not None:
79
+ on_trace_ready = self.config.on_trace_ready
80
+ else:
81
+ save_path = self.config.get_save_path(
82
+ 0
83
+ ) # Use worker_idx 0 for single-process
84
+ save_path.parent.mkdir(parents=True, exist_ok=True)
85
+
86
+ from torchrl import logger as torchrl_logger
87
+
88
+ def on_trace_ready(prof, save_path=save_path):
89
+ prof.export_chrome_trace(str(save_path))
90
+ torchrl_logger.info(f"Collector: Profiling trace saved to {save_path}")
91
+
92
+ self._profiler = torch.profiler.profile(
93
+ activities=activities,
94
+ schedule=profiler_schedule,
95
+ on_trace_ready=on_trace_ready,
96
+ record_shapes=self.config.record_shapes,
97
+ profile_memory=self.config.profile_memory,
98
+ with_stack=self.config.with_stack,
99
+ with_flops=self.config.with_flops,
100
+ )
101
+ self._active = True
102
+
103
+ def start(self) -> None:
104
+ """Start the profiler."""
105
+ from torchrl import logger as torchrl_logger
106
+
107
+ if self._profiler is not None and not self._stopped:
108
+ self._profiler.start()
109
+ torchrl_logger.info(
110
+ f"Collector: Profiling started. "
111
+ f"Will profile rollouts {self.config.warmup_rollouts} to {self.config.num_rollouts - 1}."
112
+ )
113
+
114
+ def step(self) -> bool:
115
+ """Step the profiler after a rollout.
116
+
117
+ Returns:
118
+ True if profiling is complete.
119
+ """
120
+ if self._profiler is None or self._stopped:
121
+ return False
122
+
123
+ self.rollout_count += 1
124
+ self._profiler.step()
125
+
126
+ # Check if profiling is complete
127
+ if self.rollout_count >= self.config.num_rollouts:
128
+ self.stop()
129
+ return True
130
+
131
+ return False
132
+
133
+ def stop(self) -> None:
134
+ """Stop the profiler and export trace."""
135
+ from torchrl import logger as torchrl_logger
136
+
137
+ if self._profiler is not None and not self._stopped:
138
+ self._profiler.stop()
139
+ self._stopped = True
140
+ torchrl_logger.info(
141
+ f"Collector: Profiling complete after {self.rollout_count} rollouts."
142
+ )
143
+
144
+ @property
145
+ def is_active(self) -> bool:
146
+ """Check if profiling is active."""
147
+ return self._active and not self._stopped
148
+
149
+ @contextlib.contextmanager
150
+ def profile_rollout(self):
151
+ """Context manager for profiling a single rollout."""
152
+ if self._profiler is not None and not self._stopped:
153
+ with torch.profiler.record_function("collector_rollout"):
154
+ yield
155
+ else:
156
+ yield
157
+
158
+
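# A minimal sketch (not part of the diff above) of how this helper is driven, inferred only
# from the methods shown; `profile_config` and the rollout loop are assumed placeholders:
#
#     profiler = _CollectorProfiler(profile_config)
#     profiler.start()
#     while True:
#         with profiler.profile_rollout():
#             ...  # run one rollout
#         if profiler.step():  # True once config.num_rollouts rollouts have been profiled
#             break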
159
+ def _cuda_sync_if_initialized():
160
+ """Synchronize CUDA only if it has been initialized.
161
+
162
+ This is a safe alternative to calling `torch.cuda.synchronize()` directly.
163
+ In forked subprocesses on machines with CUDA, calling `synchronize()` will
164
+ fail with "Cannot re-initialize CUDA in forked subprocess" if CUDA was
165
+ initialized in the parent process before fork. By checking
166
+ `is_initialized()` first, we skip the sync in such cases since no CUDA
167
+ operations have occurred in this process.
168
+ """
169
+ if torch.cuda.is_initialized():
170
+ torch.cuda.synchronize()
171
+
172
+
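# A minimal contrast (not part of the diff above), assuming a fork-started CPU-only worker on a
# machine where the parent process already initialized CUDA:
#
#     torch.cuda.synchronize()       # may raise "Cannot re-initialize CUDA in forked subprocess"
#     _cuda_sync_if_initialized()    # safe: no-op unless CUDA was initialized in this process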
173
+ @accept_remote_rref_udf_invocation
174
+ class Collector(BaseCollector):
175
+ """Generic data collector for RL problems. Requires an environment constructor and a policy.
176
+
177
+ Args:
178
+ create_env_fn (Callable or EnvBase): a callable that returns an instance of
179
+ :class:`~torchrl.envs.EnvBase` class, or the env itself.
180
+ policy (Callable): Policy to be executed in the environment.
181
+ Must accept :class:`tensordict.tensordict.TensorDictBase` object as input.
182
+ If ``None`` is provided, the policy used will be a
183
+ :class:`~torchrl.collectors.RandomPolicy` instance with the environment
184
+ ``action_spec``.
185
+ Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`.
186
+ This is the recommended usage of the collector.
187
+ Other callables are accepted too:
188
+ If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module`
189
+ instance), it will be wrapped in an `nn.Module` first.
190
+ Then, the collector will try to assess if these
191
+ modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
192
+
193
+ - If the policy forward signature matches any of ``forward(self, tensordict)``,
194
+ ``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
195
+ any typing with a single argument typed as a subclass of ``TensorDictBase``)
196
+ then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
197
+
198
+ - In all other cases, the policy will be wrapped as ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
199
+
200
+ .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
201
+ pickled directly), the ``policy_factory`` should be used instead.
202
+
203
+ Keyword Args:
204
+ policy_factory (Callable[[], Callable], optional): a callable that returns
205
+ a policy instance. This is mutually exclusive with the `policy` argument.
206
+
207
+ .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
208
+
209
+ frames_per_batch (int): A keyword-only argument representing the total
210
+ number of elements in a batch.
211
+ total_frames (int): A keyword-only argument representing the total
212
+ number of frames returned by the collector
213
+ during its lifespan. If the ``total_frames`` is not divisible by
214
+ ``frames_per_batch``, the count is rounded up to the next multiple and a warning is emitted.
215
+ Endless collectors can be created by passing ``total_frames=-1``.
216
+ Defaults to ``-1`` (endless collector).
217
+ device (int, str or torch.device, optional): The generic device of the
218
+ collector. The ``device`` args fills any non-specified device: if
219
+ ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
220
+ ``env_device`` is not specified, its value will be set to ``device``.
221
+ Defaults to ``None`` (No default device).
222
+ storing_device (int, str or torch.device, optional): The device on which
223
+ the output :class:`~tensordict.TensorDict` will be stored.
224
+ If ``device`` is passed and ``storing_device`` is ``None``, it will
225
+ default to the value indicated by ``device``.
226
+ For long trajectories, it may be necessary to store the data on a different
227
+ device than the one where the policy and env are executed.
228
+ Defaults to ``None`` (the output tensordict isn't on a specific device,
229
+ leaf tensors sit on the device where they were created).
230
+ env_device (int, str or torch.device, optional): The device on which
231
+ the environment should be cast (or executed if that functionality is
232
+ supported). If not specified and the env has a non-``None`` device,
233
+ ``env_device`` will default to that value. If ``device`` is passed
234
+ and ``env_device=None``, it will default to ``device``. If the value
235
+ as such specified of ``env_device`` differs from ``policy_device``
236
+ and one of them is not ``None``, the data will be cast to ``env_device``
237
+ before being passed to the env (i.e., passing different devices to
238
+ policy and env is supported). Defaults to ``None``.
239
+ policy_device (int, str or torch.device, optional): The device on which
240
+ the policy should be cast.
241
+ If ``device`` is passed and ``policy_device=None``, it will default
242
+ to ``device``. If the value as such specified of ``policy_device``
243
+ differs from ``env_device`` and one of them is not ``None``,
244
+ the data will be cast to ``policy_device`` before being passed to
245
+ the policy (i.e., passing different devices to policy and env is
246
+ supported). Defaults to ``None``.
247
+ create_env_kwargs (dict, optional): Dictionary of kwargs for
248
+ ``create_env_fn``.
249
+ max_frames_per_traj (int, optional): Maximum steps per trajectory.
250
+ Note that a trajectory can span across multiple batches (unless
251
+ ``reset_at_each_iter`` is set to ``True``, see below).
252
+ Once a trajectory reaches ``max_frames_per_traj`` steps, the environment is reset.
253
+ If the environment wraps multiple environments together, the number
254
+ of steps is tracked for each environment independently. Negative
255
+ values are allowed, in which case this argument is ignored.
256
+ Defaults to ``None`` (i.e., no maximum number of steps).
257
+ init_random_frames (int, optional): Number of frames for which the
258
+ policy is ignored before it is called. This feature is mainly
259
+ intended to be used in offline/model-based settings, where a
260
+ batch of random trajectories can be used to initialize training.
261
+ If provided, it will be rounded up to the closest multiple of frames_per_batch.
262
+ Defaults to ``None`` (i.e. no random frames).
263
+ reset_at_each_iter (bool, optional): Whether environments should be reset
264
+ at the beginning of a batch collection.
265
+ Defaults to ``False``.
266
+ postproc (Callable, optional): A post-processing transform, such as
267
+ a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
268
+ instance.
269
+
270
+ .. warning:: Postproc is not applied when a replay buffer is used and items are added to the buffer
271
+ as they are produced (`extend_buffer=False`). The recommended setting is `extend_buffer=True`.
272
+
273
+ Defaults to ``None``.
274
+ split_trajs (bool, optional): Boolean indicating whether the resulting
275
+ TensorDict should be split according to the trajectories.
276
+ See :func:`~torchrl.collectors.utils.split_trajectories` for more
277
+ information.
278
+ Defaults to ``False``.
279
+ exploration_type (ExplorationType, optional): interaction mode to be used when
280
+ collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
281
+ ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
282
+ or ``torchrl.envs.utils.ExplorationType.MEAN``.
283
+ return_same_td (bool, optional): if ``True``, the same TensorDict
284
+ will be returned at each iteration, with its values
285
+ updated. This feature should be used cautiously: if the same
286
+ tensordict is added to a replay buffer for instance,
287
+ the whole content of the buffer will be identical.
288
+ Default is ``False``.
289
+ interruptor (_Interruptor, optional):
290
+ An _Interruptor object that can be used from outside the class to control rollout collection.
291
+ The _Interruptor class has methods ``start_collection`` and ``stop_collection``, which allow implementing
292
+ strategies such as preemptively stopping rollout collection.
293
+ Default is ``None``.
294
+ set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding
295
+ ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of
296
+ a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
297
+ Truncated keys can be set through ``env.add_truncated_keys``.
298
+ Defaults to ``False``.
299
+ use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data.
300
+ This isn't compatible with environments with dynamic specs. Defaults to ``True``
301
+ for envs without dynamic specs, ``False`` for others.
302
+ replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
303
+ but populate the buffer instead.
304
+ Defaults to ``None``.
305
+
306
+ .. seealso:: By default (``extend_buffer=True``), the buffer is extended with entire rollouts.
307
+ If the buffer needs to be populated with individual frames as they are collected,
308
+ set ``extend_buffer=False`` (deprecated).
309
+
310
+ .. warning:: Using a replay buffer with a `postproc` or `split_trajs=True` requires
311
+ `extend_buffer=True`, as the whole batch needs to be observed to apply these transforms.
312
+
313
+ extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not
314
+ with single steps. Defaults to `True`.
315
+
316
+ .. note:: Setting this to `False` is deprecated and will be removed in a future version.
317
+ Extending the buffer with entire rollouts is the recommended approach for better
318
+ compatibility with postprocessing and trajectory splitting.
319
+ trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be
320
+ compatible with the collector. This defaults to ``True`` for CudaGraphModules
321
+ and ``False`` otherwise.
322
+ compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled
323
+ using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it
324
+ will be used to compile the policy.
325
+ cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
326
+ in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
327
+ If a dictionary of kwargs is passed, it will be used to wrap the policy.
328
+ no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed.
329
+ For environments running directly on CUDA (`IsaacLab <https://github.com/isaac-sim/IsaacLab/>`_
330
+ or `ManiSkills <https://github.com/haosulab/ManiSkill/>`_), CUDA synchronization may cause unexpected
331
+ crashes.
332
+ Defaults to ``False``.
333
+ weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase`
334
+ or its subclass, responsible for updating the policy weights on remote inference workers.
335
+ This is typically not used in :class:`~torchrl.collectors.Collector` as it operates in a single-process environment.
336
+ Consider using a constructor if the updater needs to be serialized.
337
+ weight_sync_schemes (dict[str, WeightSyncScheme], optional): **Not supported for Collector**.
338
+ Collector is a leaf collector and cannot send weights to sub-collectors.
339
+ Providing this parameter will raise a ValueError.
340
+ Use ``weight_recv_schemes`` if you need to receive weights from a parent collector.
341
+ weight_recv_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for
342
+ RECEIVING weights from parent collectors. Keys are model identifiers (e.g., "policy")
343
+ and values are WeightSyncScheme instances configured to receive weights.
344
+ This enables cascading weight updates in hierarchies like:
345
+ RPCDataCollector -> MultiSyncCollector -> Collector.
346
+ Defaults to ``None``.
347
+ track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
348
+ This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
349
+ Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
350
+ the policy version.
351
+ Defaults to `False`.
352
+
353
+ Examples:
354
+ >>> from torchrl.envs.libs.gym import GymEnv
355
+ >>> from tensordict.nn import TensorDictModule
356
+ >>> from torch import nn
357
+ >>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
358
+ >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
359
+ >>> collector = Collector(
360
+ ... create_env_fn=env_maker,
361
+ ... policy=policy,
362
+ ... total_frames=2000,
363
+ ... max_frames_per_traj=50,
364
+ ... frames_per_batch=200,
365
+ ... init_random_frames=-1,
366
+ ... reset_at_each_iter=False,
367
+ ... device="cpu",
368
+ ... storing_device="cpu",
369
+ ... )
370
+ >>> for i, data in enumerate(collector):
371
+ ... if i == 2:
372
+ ... print(data)
373
+ ... break
374
+ TensorDict(
375
+ fields={
376
+ action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
377
+ collector: TensorDict(
378
+ fields={
379
+ traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
380
+ batch_size=torch.Size([200]),
381
+ device=cpu,
382
+ is_shared=False),
383
+ done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
384
+ next: TensorDict(
385
+ fields={
386
+ done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
387
+ observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
388
+ reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
389
+ step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
390
+ truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
391
+ batch_size=torch.Size([200]),
392
+ device=cpu,
393
+ is_shared=False),
394
+ observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
395
+ step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
396
+ truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
397
+ batch_size=torch.Size([200]),
398
+ device=cpu,
399
+ is_shared=False)
400
+ >>> del collector
401
+
402
+ The collector delivers batches of data that are marked with a ``"time"``
403
+ dimension.
404
+
405
+ Examples:
406
+ >>> assert data.names[-1] == "time"
407
+
408
+ """
409
+
410
+ _ignore_rb: bool = False
411
+
412
+ def __init__(
413
+ self,
414
+ create_env_fn: (
415
+ EnvBase | EnvCreator | Sequence[Callable[[], EnvBase]] # noqa: F821
416
+ ), # noqa: F821
417
+ policy: None
418
+ | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None,
419
+ *,
420
+ policy_factory: Callable[[], Callable] | None = None,
421
+ frames_per_batch: int,
422
+ total_frames: int = -1,
423
+ device: DEVICE_TYPING | None = None,
424
+ storing_device: DEVICE_TYPING | None = None,
425
+ policy_device: DEVICE_TYPING | None = None,
426
+ env_device: DEVICE_TYPING | None = None,
427
+ create_env_kwargs: dict[str, Any] | None = None,
428
+ max_frames_per_traj: int | None = None,
429
+ init_random_frames: int | None = None,
430
+ reset_at_each_iter: bool = False,
431
+ postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
432
+ split_trajs: bool | None = None,
433
+ exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
434
+ return_same_td: bool = False,
435
+ reset_when_done: bool = True,
436
+ interruptor=None,
437
+ set_truncated: bool = False,
438
+ use_buffers: bool | None = None,
439
+ replay_buffer: ReplayBuffer | None = None,
440
+ extend_buffer: bool = True,
441
+ local_init_rb: bool | None = None,
442
+ trust_policy: bool | None = None,
443
+ compile_policy: bool | dict[str, Any] | None = None,
444
+ cudagraph_policy: bool | dict[str, Any] | None = None,
445
+ no_cuda_sync: bool = False,
446
+ weight_updater: WeightUpdaterBase
447
+ | Callable[[], WeightUpdaterBase]
448
+ | None = None,
449
+ weight_sync_schemes: dict[str, WeightSyncScheme] | None = None,
450
+ weight_recv_schemes: dict[str, WeightSyncScheme] | None = None,
451
+ track_policy_version: bool = False,
452
+ worker_idx: int | None = None,
453
+ **kwargs,
454
+ ):
455
+ self.closed = True
456
+ self.worker_idx = worker_idx
457
+
458
+ # Note: weight_sync_schemes can be used to send weights to components
459
+ # within the environment (e.g., RayModuleTransform), not just sub-collectors
460
+
461
+ # Initialize environment
462
+ env = self._init_env(create_env_fn, create_env_kwargs)
463
+
464
+ # Initialize policy
465
+ policy = self._init_policy(policy, policy_factory, env, trust_policy)
466
+ self._read_compile_kwargs(compile_policy, cudagraph_policy)
467
+
468
+ # Handle trajectory pool and validate kwargs
469
+ self._traj_pool_val = kwargs.pop("traj_pool", None)
470
+ if kwargs:
471
+ raise TypeError(
472
+ f"Keys {list(kwargs.keys())} are unknown to {type(self).__name__}."
473
+ )
474
+
475
+ # Set up devices and synchronization
476
+ self._setup_devices(
477
+ device=device,
478
+ storing_device=storing_device,
479
+ policy_device=policy_device,
480
+ env_device=env_device,
481
+ no_cuda_sync=no_cuda_sync,
482
+ )
483
+
484
+ self.env: EnvBase = env
485
+ del env
486
+
487
+ # Set up policy version tracking
488
+ self._setup_policy_version_tracking(track_policy_version)
489
+
490
+ # Set up replay buffer
491
+ self._setup_replay_buffer(
492
+ replay_buffer=replay_buffer,
493
+ extend_buffer=extend_buffer,
494
+ local_init_rb=local_init_rb,
495
+ postproc=postproc,
496
+ split_trajs=split_trajs,
497
+ return_same_td=return_same_td,
498
+ use_buffers=use_buffers,
499
+ )
500
+
501
+ self.closed = False
502
+
503
+ # Validate reset_when_done
504
+ if not reset_when_done:
505
+ raise ValueError("reset_when_done is deprecated.")
506
+ self.reset_when_done = reset_when_done
507
+ self.n_env = self.env.batch_size.numel()
508
+
509
+ # Register collector with policy and env
510
+ if hasattr(policy, "register_collector"):
511
+ policy.register_collector(self)
512
+ if hasattr(self.env, "register_collector"):
513
+ self.env.register_collector(self)
514
+
515
+ # Set up policy and weights
516
+ self._setup_policy_and_weights(policy)
517
+
518
+ # Apply environment device
519
+ self._apply_env_device()
520
+
521
+ # Set up max frames per trajectory
522
+ self._setup_max_frames_per_traj(max_frames_per_traj)
523
+
524
+ # Validate and set total frames
525
+ self.reset_at_each_iter = reset_at_each_iter
526
+ self._setup_total_frames(total_frames, frames_per_batch)
527
+
528
+ # Set up init random frames
529
+ self._setup_init_random_frames(init_random_frames, frames_per_batch)
530
+
531
+ # Set up postproc
532
+ self._setup_postproc(postproc)
533
+
534
+ # Calculate frames per batch
535
+ self._setup_frames_per_batch(frames_per_batch)
536
+
537
+ # Set exploration and other options
538
+ self.exploration_type = (
539
+ exploration_type if exploration_type else DEFAULT_EXPLORATION_TYPE
540
+ )
541
+ self.return_same_td = return_same_td
542
+ self.set_truncated = set_truncated
543
+
544
+ # Create shuttle and rollout buffers
545
+ self._make_shuttle()
546
+ self._maybe_make_final_rollout(make_rollout=self._use_buffers)
547
+ self._set_truncated_keys()
548
+
549
+ # Set split trajectories option
550
+ if split_trajs is None:
551
+ split_trajs = False
552
+ self.split_trajs = split_trajs
553
+ self._exclude_private_keys = True
554
+
555
+ # Set up interruptor and frame tracking
556
+ self.interruptor = interruptor
557
+ self._frames = 0
558
+ self._iter = -1
559
+
560
+ # Set up weight synchronization
561
+ self._setup_weight_sync(weight_updater, weight_sync_schemes)
562
+
563
+ # Set up weight receivers if provided
564
+ if weight_recv_schemes is not None:
565
+ self.register_scheme_receiver(weight_recv_schemes)
566
+
567
+ def _init_env(
568
+ self,
569
+ create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase],
570
+ create_env_kwargs: dict[str, Any] | None,
571
+ ) -> EnvBase:
572
+ """Initialize and configure the environment."""
573
+ from torchrl.envs.batched_envs import BatchedEnvBase
574
+
575
+ if create_env_kwargs is None:
576
+ create_env_kwargs = {}
577
+
578
+ if not isinstance(create_env_fn, EnvBase):
579
+ env = create_env_fn(**create_env_kwargs)
580
+ else:
581
+ env = create_env_fn
582
+ if create_env_kwargs:
583
+ if not isinstance(env, BatchedEnvBase):
584
+ raise RuntimeError(
585
+ "kwargs were passed to Collector but they can't be set "
586
+ f"on environment of type {type(create_env_fn)}."
587
+ )
588
+ env.update_kwargs(create_env_kwargs)
589
+ return env
590
+
591
+ def _init_policy(
592
+ self,
593
+ policy: TensorDictModule | Callable | None,
594
+ policy_factory: Callable[[], Callable] | None,
595
+ env: EnvBase,
596
+ trust_policy: bool | None,
597
+ ) -> TensorDictModule | Callable:
598
+ """Initialize and configure the policy before device placement / wrapping."""
599
+ if policy is None:
600
+ if policy_factory is not None:
601
+ policy = policy_factory()
602
+ else:
603
+ policy = RandomPolicy(env.full_action_spec)
604
+ elif policy_factory is not None:
605
+ raise TypeError("policy_factory cannot be used with policy argument.")
606
+
607
+ if trust_policy is None:
608
+ trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule))
609
+ self.trust_policy = trust_policy
610
+
611
+ return policy
612
+
613
+ def _setup_devices(
614
+ self,
615
+ device: DEVICE_TYPING | None,
616
+ storing_device: DEVICE_TYPING | None,
617
+ policy_device: DEVICE_TYPING | None,
618
+ env_device: DEVICE_TYPING | None,
619
+ no_cuda_sync: bool,
620
+ ) -> None:
621
+ """Set up devices and synchronization functions."""
622
+ storing_device, policy_device, env_device = self._get_devices(
623
+ storing_device=storing_device,
624
+ policy_device=policy_device,
625
+ env_device=env_device,
626
+ device=device,
627
+ )
628
+
629
+ self.storing_device = storing_device
630
+ self._sync_storage = self._get_sync_fn(storing_device)
631
+
632
+ self.env_device = env_device
633
+ self._sync_env = self._get_sync_fn(env_device)
634
+
635
+ self.policy_device = policy_device
636
+ self._sync_policy = self._get_sync_fn(policy_device)
637
+
638
+ self.device = device
639
+ self.no_cuda_sync = no_cuda_sync
640
+ self._cast_to_policy_device = self.policy_device != self.env_device
641
+
642
+ def _get_sync_fn(self, device: torch.device | None) -> Callable:
643
+ """Get the appropriate synchronization function for a device."""
644
+ if device is not None and device.type != "cuda":
645
+ # When destination is not CUDA, we may need to sync to wait for
646
+ # async GPU→CPU transfers to complete before proceeding.
647
+ if torch.cuda.is_available():
648
+ # Return a safe wrapper that only syncs if CUDA was actually
649
+ # initialized. This avoids "Cannot re-initialize CUDA in forked
650
+ # subprocess" errors when using fork start method on GPU machines
651
+ # with CPU-only collectors.
652
+ return _cuda_sync_if_initialized
653
+ elif torch.backends.mps.is_available() and hasattr(torch, "mps"):
654
+ return torch.mps.synchronize
655
+ elif hasattr(torch, "npu") and torch.npu.is_available():
656
+ return torch.npu.synchronize
657
+ elif device.type == "cpu":
658
+ return _do_nothing
659
+ else:
660
+ raise RuntimeError("Non supported device")
661
+ else:
662
+ return _do_nothing
663
+
664
+ def _setup_policy_version_tracking(
665
+ self, track_policy_version: bool | PolicyVersion
666
+ ) -> None:
667
+ """Set up policy version tracking if requested."""
668
+ self.policy_version_tracker = track_policy_version
669
+ if isinstance(track_policy_version, bool) and track_policy_version:
670
+ from torchrl.envs.batched_envs import BatchedEnvBase
671
+
672
+ if isinstance(self.env, BatchedEnvBase):
673
+ raise RuntimeError(
674
+ "BatchedEnvBase is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, "
675
+ "and pass that transform to the collector."
676
+ )
677
+ self.policy_version_tracker = PolicyVersion()
678
+ self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore
679
+ elif hasattr(track_policy_version, "increment_version"):
680
+ self.policy_version_tracker = track_policy_version
681
+ self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore
682
+ else:
683
+ self.policy_version_tracker = None
684
+
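# A minimal sketch (not part of the diff above) of the two forms of `track_policy_version`
# handled above; `env_maker` and `policy` are assumed to exist (see the class Examples):
from torchrl.envs.llm.transforms import PolicyVersion

collector = Collector(env_maker, policy, frames_per_batch=200, track_policy_version=True)

version = PolicyVersion()  # an existing transform instance is also accepted and appended to the env
collector = Collector(env_maker, policy, frames_per_batch=200, track_policy_version=version)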
685
+ def _setup_replay_buffer(
686
+ self,
687
+ replay_buffer: ReplayBuffer | None,
688
+ extend_buffer: bool,
689
+ local_init_rb: bool | None,
690
+ postproc: Callable | None,
691
+ split_trajs: bool | None,
692
+ return_same_td: bool,
693
+ use_buffers: bool | None,
694
+ ) -> None:
695
+ """Set up replay buffer configuration and validate compatibility."""
696
+ self.replay_buffer = replay_buffer
697
+ self.extend_buffer = extend_buffer
698
+
699
+ # Handle local_init_rb deprecation
700
+ if local_init_rb is None:
701
+ local_init_rb = False
702
+ if replay_buffer is not None and not local_init_rb:
703
+ warnings.warn(
704
+ "local_init_rb=False is deprecated and will be removed in v0.12. "
705
+ "The new storage-level initialization provides better performance.",
706
+ FutureWarning,
707
+ )
708
+ self.local_init_rb = local_init_rb
709
+
710
+ # Validate replay buffer compatibility
711
+ if self.replay_buffer is not None and not self._ignore_rb:
712
+ if postproc is not None and not self.extend_buffer:
713
+ raise TypeError(
714
+ "postproc must be None when a replay buffer is passed, or extend_buffer must be set to True."
715
+ )
716
+ if split_trajs not in (None, False) and not self.extend_buffer:
717
+ raise TypeError(
718
+ "split_trajs must be None/False when a replay buffer is passed, or extend_buffer must be set to True."
719
+ )
720
+ if return_same_td:
721
+ raise TypeError(
722
+ "return_same_td must be False when a replay buffer is passed, or extend_buffer must be set to True."
723
+ )
724
+ if use_buffers:
725
+ raise TypeError("replay_buffer is exclusive with use_buffers.")
726
+
727
+ if use_buffers is None:
728
+ use_buffers = not self.env._has_dynamic_specs and self.replay_buffer is None
729
+ self._use_buffers = use_buffers
730
+
731
+ def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None:
732
+ """Set up policy, wrapped policy, and extract weights."""
733
+ # Store weak reference to original policy before any transformations
734
+ # This allows update_policy_weights_ to sync from the original when no scheme is configured
735
+ if isinstance(policy, nn.Module):
736
+ self._orig_policy_ref = weakref.ref(policy)
737
+ else:
738
+ self._orig_policy_ref = None
739
+
740
+ # Check if policy has meta-device parameters (sent from weight sync schemes)
741
+ # In that case, skip device placement - weights will come from the receiver
742
+ has_meta_params = False
743
+ if isinstance(policy, nn.Module):
744
+ for p in policy.parameters():
745
+ if p.device.type == "meta":
746
+ has_meta_params = True
747
+ break
748
+
749
+ if has_meta_params:
750
+ # Policy has meta params - sent from weight sync schemes
751
+ # Skip device placement, weights will come from receiver
752
+ # Keep policy on meta device until weights are loaded
753
+ if not self.trust_policy:
754
+ self.policy = policy
755
+ env = getattr(self, "env", None)
756
+ try:
757
+ wrapped_policy = _make_compatible_policy(
758
+ policy=policy,
759
+ observation_spec=getattr(env, "observation_spec", None),
760
+ env=self.env,
761
+ )
762
+ except (TypeError, AttributeError, ValueError) as err:
763
+ raise TypeError(
764
+ "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True. Scroll up for more details."
765
+ ) from err
766
+ self._wrapped_policy = wrapped_policy
767
+ else:
768
+ self.policy = self._wrapped_policy = policy
769
+
770
+ # Auto-configure exploration modules if needed (e.g. spec=None)
771
+ if isinstance(self.policy, nn.Module):
772
+ set_exploration_modules_spec_from_env(self.policy, self.env)
773
+
774
+ # For meta-parameter policies, keep the internal (worker-side) policy
775
+ # as the reference for collector state_dict / load_state_dict.
776
+ if isinstance(self.policy, nn.Module):
777
+ self._policy_w_state_dict = self.policy
778
+
779
+ # Don't extract weights yet - they're on meta device (empty)
780
+ self.policy_weights = TensorDict()
781
+ self.get_weights_fn = None
782
+ else:
783
+ # Normal path: move policy to correct device
784
+ policy, self.get_weights_fn = self._get_policy_and_device(policy=policy)
785
+
786
+ if not self.trust_policy:
787
+ self.policy = policy
788
+ env = getattr(self, "env", None)
789
+ try:
790
+ wrapped_policy = _make_compatible_policy(
791
+ policy=policy,
792
+ observation_spec=getattr(env, "observation_spec", None),
793
+ env=self.env,
794
+ )
795
+ except (TypeError, AttributeError, ValueError) as err:
796
+ raise TypeError(
797
+ "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True. Scroll up for more details."
798
+ ) from err
799
+ self._wrapped_policy = wrapped_policy
800
+ else:
801
+ self.policy = self._wrapped_policy = policy
802
+
803
+ # Auto-configure exploration modules if needed (e.g. spec=None)
804
+ if isinstance(self.policy, nn.Module):
805
+ set_exploration_modules_spec_from_env(self.policy, self.env)
806
+
807
+ # Use the internal, unwrapped policy (cast to the correct device) as the
808
+ # reference for state_dict / load_state_dict and legacy weight extractors.
809
+ if isinstance(self.policy, nn.Module):
810
+ self._policy_w_state_dict = self.policy
811
+
812
+ # Extract policy weights from the uncompiled wrapped policy
813
+ # Access _wrapped_policy_uncompiled directly to avoid triggering compilation.
814
+ if isinstance(self._wrapped_policy_uncompiled, nn.Module):
815
+ self.policy_weights = TensorDict.from_module(
816
+ self._wrapped_policy_uncompiled, as_module=True
817
+ ).data
818
+ else:
819
+ self.policy_weights = TensorDict()
820
+
821
+ # If policy doesn't have meta params, compile immediately
822
+ # Otherwise, defer until first use (after weights are loaded)
823
+ if not has_meta_params and (self.compiled_policy or self.cudagraphed_policy):
824
+ self._wrapped_policy_maybe_compiled = self._compile_wrapped_policy(
825
+ self._wrapped_policy_uncompiled
826
+ )
827
+
828
+ def _compile_wrapped_policy(self, policy):
829
+ """Apply compilation and/or cudagraph to a policy."""
830
+ if self.compiled_policy:
831
+ policy = compile_with_warmup(policy, **self.compiled_policy_kwargs)
832
+ if self.cudagraphed_policy:
833
+ policy = CudaGraphModule(
834
+ policy,
835
+ in_keys=[],
836
+ out_keys=[],
837
+ device=self.policy_device,
838
+ **self.cudagraphed_policy_kwargs,
839
+ )
840
+ return policy
841
+
842
+ @property
843
+ def _wrapped_policy(self):
844
+ """Returns the compiled policy, compiling it lazily if needed."""
845
+ if (policy := self._wrapped_policy_maybe_compiled) is None:
846
+ if self.compiled_policy or self.cudagraphed_policy:
847
+ policy = (
848
+ self._wrapped_policy_maybe_compiled
849
+ ) = self._compile_wrapped_policy(self._wrapped_policy_uncompiled)
850
+ else:
851
+ policy = (
852
+ self._wrapped_policy_maybe_compiled
853
+ ) = self._wrapped_policy_uncompiled
854
+ return policy
855
+
856
+ @property
857
+ def _orig_policy(self):
858
+ """Returns the original policy passed to the collector, if still alive."""
859
+ if self._orig_policy_ref is not None:
860
+ return self._orig_policy_ref()
861
+ return None
862
+
863
+ @_wrapped_policy.setter
864
+ def _wrapped_policy(self, value):
865
+ """Allow setting the wrapped policy during initialization."""
866
+ self._wrapped_policy_uncompiled = value
867
+ self._wrapped_policy_maybe_compiled = None
868
+
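# Editorial sketch of the lazy-compilation pattern implemented above: the
# uncompiled module is cached and torch.compile only runs the first time the
# wrapped policy is requested. Assumes torch.compile is usable on the platform;
# `LazyCompiled` is a hypothetical stand-in, not part of torchrl.
import torch
from torch import nn

class LazyCompiled:
    def __init__(self, module: nn.Module, compile_module: bool = True):
        self._uncompiled = module
        self._maybe_compiled = None
        self._compile_module = compile_module

    @property
    def module(self) -> nn.Module:
        # compile on first access only, mirroring `_wrapped_policy` above
        if self._maybe_compiled is None:
            self._maybe_compiled = (
                torch.compile(self._uncompiled)
                if self._compile_module
                else self._uncompiled
            )
        return self._maybe_compiled

lazy = LazyCompiled(nn.Linear(3, 1))
out = lazy.module(torch.randn(2, 3))  # compilation is triggered here, on first use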
869
+ def _apply_env_device(self) -> None:
870
+ """Apply device to environment if specified."""
871
+ if self.env_device:
872
+ self.env: EnvBase = self.env.to(self.env_device)
873
+ elif self.env.device is not None:
874
+ # Use the device of the env if none was provided
875
+ self.env_device = self.env.device
876
+
877
+ # Check if we need to cast to env device
878
+ self._cast_to_env_device = self._cast_to_policy_device or (
879
+ self.env.device != self.storing_device
880
+ )
881
+
882
+ def _setup_max_frames_per_traj(self, max_frames_per_traj: int | None) -> None:
883
+ """Set up maximum frames per trajectory and add StepCounter if needed."""
884
+ self.max_frames_per_traj = (
885
+ int(max_frames_per_traj) if max_frames_per_traj is not None else 0
886
+ )
887
+ if self.max_frames_per_traj is not None and self.max_frames_per_traj > 0:
888
+ # Check that there is no StepCounter yet
889
+ for key in self.env.output_spec.keys(True, True):
890
+ if isinstance(key, str):
891
+ key = (key,)
892
+ if "step_count" in key:
893
+ raise ValueError(
894
+ "A 'step_count' key is already present in the environment "
895
+ "and the 'max_frames_per_traj' argument may conflict with "
896
+ "a 'StepCounter' that has already been set. "
897
+ "Possible solutions: Set max_frames_per_traj to 0 or "
898
+ "remove the StepCounter limit from the environment transforms."
899
+ )
900
+ self.env = TransformedEnv(
901
+ self.env, StepCounter(max_steps=self.max_frames_per_traj)
902
+ )
903
+
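# Editorial sketch: `max_frames_per_traj` wraps the collector's env in a
# StepCounter, so an env that already carries one should keep
# max_frames_per_traj at its default (None/0) to avoid the ValueError above.
# Assumes gymnasium's Pendulum-v1 is installed.
from torchrl.envs import GymEnv, StepCounter, TransformedEnv

# roughly equivalent to what the collector builds when max_frames_per_traj=200 is passed:
capped_env = TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(max_steps=200))
assert "step_count" in capped_env.observation_spec.keys()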
904
+ def _setup_total_frames(self, total_frames: int, frames_per_batch: int) -> None:
905
+ """Validate and set total frames."""
906
+ if total_frames is None or total_frames < 0:
907
+ total_frames = float("inf")
908
+ else:
909
+ remainder = total_frames % frames_per_batch
910
+ if remainder != 0 and RL_WARNINGS:
911
+ warnings.warn(
912
+ f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). "
913
+ f"This means {frames_per_batch - remainder} additional frames will be collected."
914
+ " To silence this message, set the environment variable RL_WARNINGS to False."
915
+ )
916
+ self.total_frames = (
917
+ int(total_frames) if total_frames != float("inf") else total_frames
918
+ )
919
+
920
+ def _setup_init_random_frames(
921
+ self, init_random_frames: int | None, frames_per_batch: int
922
+ ) -> None:
923
+ """Set up initial random frames."""
924
+ self.init_random_frames = (
925
+ int(init_random_frames) if init_random_frames not in (None, -1) else 0
926
+ )
927
+ if (
928
+ init_random_frames not in (-1, None, 0)
929
+ and init_random_frames % frames_per_batch != 0
930
+ and RL_WARNINGS
931
+ ):
932
+ warnings.warn(
933
+ f"init_random_frames ({init_random_frames}) is not exactly a multiple of frames_per_batch ({frames_per_batch}), "
934
+ f" this results in more init_random_frames than requested"
935
+ f" ({-(-init_random_frames // frames_per_batch) * frames_per_batch})."
936
+ " To silence this message, set the environment variable RL_WARNINGS to False."
937
+ )
938
+
939
+ def _setup_postproc(self, postproc: Callable | None) -> None:
940
+ """Set up post-processing transform."""
941
+ self.postproc = postproc
942
+ if (
943
+ self.postproc is not None
944
+ and hasattr(self.postproc, "to")
945
+ and self.storing_device
946
+ ):
947
+ postproc = self.postproc.to(self.storing_device)
948
+ if postproc is not self.postproc and postproc is not None:
949
+ self.postproc = postproc
950
+
951
+ def _setup_frames_per_batch(self, frames_per_batch: int) -> None:
952
+ """Calculate and validate frames per batch."""
953
+ if frames_per_batch % self.n_env != 0 and RL_WARNINGS:
954
+ warnings.warn(
955
+ f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), "
956
+ f" this results in more frames_per_batch per iteration than requested"
957
+ f" ({-(-frames_per_batch // self.n_env) * self.n_env}). "
958
+ "To silence this message, set the environment variable RL_WARNINGS to False."
959
+ )
960
+ self.frames_per_batch = -(-frames_per_batch // self.n_env)
961
+ self.requested_frames_per_batch = self.frames_per_batch * self.n_env
962
+
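# Worked example (editorial) of the ceil division used above: with
# frames_per_batch=100 spread over n_env=3 batched envs, each env contributes
# ceil(100 / 3) = 34 steps, so 34 * 3 = 102 frames are returned per iteration,
# which is what the warning reports.
frames_per_batch, n_env = 100, 3
per_env = -(-frames_per_batch // n_env)  # ceil division without math.ceil
assert per_env == 34 and per_env * n_env == 102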
963
+ def _setup_weight_sync(
964
+ self,
965
+ weight_updater: WeightUpdaterBase | Callable | None,
966
+ weight_sync_schemes: dict[str, WeightSyncScheme] | None,
967
+ ) -> None:
968
+ """Set up weight synchronization system."""
969
+ if weight_sync_schemes is not None:
970
+ # Use new simplified weight synchronization system
971
+ self._weight_sync_schemes = weight_sync_schemes
972
+ # Initialize and synchronize schemes that need sender-side setup
973
+ # (e.g., RayModuleTransformScheme for updating transforms in the env)
974
+ for model_id, scheme in weight_sync_schemes.items():
975
+ if not scheme.initialized_on_sender:
976
+ scheme.init_on_sender(model_id=model_id, context=self)
977
+ if not scheme.synchronized_on_sender:
978
+ scheme.connect()
979
+ self.weight_updater = None # Don't use legacy system
980
+ elif weight_updater is not None:
981
+ # Use legacy weight updater system if explicitly provided
982
+ if not isinstance(weight_updater, WeightUpdaterBase):
983
+ if callable(weight_updater):
984
+ weight_updater = weight_updater()
985
+ else:
986
+ raise TypeError(
987
+ f"weight_updater must be a subclass of WeightUpdaterBase. Got {type(weight_updater)} instead."
988
+ )
989
+ warnings.warn(
990
+ "Using WeightUpdaterBase is deprecated. Please use weight_sync_schemes instead. "
991
+ "This will be removed in a future version.",
992
+ DeprecationWarning,
993
+ stacklevel=2,
994
+ )
995
+ self.weight_updater = weight_updater
996
+ self._weight_sync_schemes = None
997
+ else:
998
+ # No weight sync needed for single-process collectors
999
+ self.weight_updater = None
1000
+ self._weight_sync_schemes = None
1001
+
1002
+ @property
1003
+ def _traj_pool(self):
1004
+ pool = getattr(self, "_traj_pool_val", None)
1005
+ if pool is None:
1006
+ pool = self._traj_pool_val = _TrajectoryPool()
1007
+ return pool
1008
+
1009
+ def _make_shuttle(self):
1010
+ # The shuttle is a deviceless tensordict that just carries data from env to policy and back from policy to env
1011
+ with torch.no_grad():
1012
+ self._carrier = self.env.reset()
1013
+ if self.policy_device != self.env_device or self.env_device is None:
1014
+ self._shuttle_has_no_device = True
1015
+ self._carrier.clear_device_()
1016
+ else:
1017
+ self._shuttle_has_no_device = False
1018
+
1019
+ traj_ids = self._traj_pool.get_traj_and_increment(
1020
+ self.n_env, device=self.storing_device
1021
+ ).view(self.env.batch_size)
1022
+ self._carrier.set(
1023
+ ("collector", "traj_ids"),
1024
+ traj_ids,
1025
+ )
1026
+
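# Editorial sketch of the "deviceless" carrier built above: clearing the device
# metadata lets the same TensorDict be handed to a policy and an env that live
# on different devices without an implicit cast.
import torch
from tensordict import TensorDict

td = TensorDict({"observation": torch.zeros(3)}, batch_size=[], device="cpu")
td.clear_device_()
assert td.device is None  # tensors keep their storage; only the metadata is dropped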
1027
+ def _maybe_make_final_rollout(self, make_rollout: bool):
1028
+ if make_rollout:
1029
+ with torch.no_grad():
1030
+ self._final_rollout = self.env.fake_tensordict()
1031
+
1032
+ # If storing device is not None, we use this to cast the storage.
1033
+ # If it is None and the env and policy are on the same device,
1034
+ # the storing device is already the same as those, so we don't need
1035
+ # to consider this use case.
1036
+ # In all other cases, we can't really put a device on the storage,
1037
+ # since at least one data source has a device that is not clear.
1038
+ if self.storing_device:
1039
+ self._final_rollout = self._final_rollout.to(
1040
+ self.storing_device, non_blocking=True
1041
+ )
1042
+ else:
1043
+ # erase all devices
1044
+ self._final_rollout.clear_device_()
1045
+
1046
+ # Check if policy has meta-device parameters (not yet initialized)
1047
+ has_meta_params = False
1048
+ if hasattr(self, "_wrapped_policy_uncompiled") and isinstance(
1049
+ self._wrapped_policy_uncompiled, nn.Module
1050
+ ):
1051
+ for p in self._wrapped_policy_uncompiled.parameters():
1052
+ if p.device.type == "meta":
1053
+ has_meta_params = True
1054
+ break
1055
+
1056
+ # If the policy has a valid spec, we use it
1057
+ self._policy_output_keys = set()
1058
+ _policy_to_check = (
1059
+ self._wrapped_policy_uncompiled if has_meta_params else self._wrapped_policy
1060
+ )
1061
+ _has_spec = hasattr(_policy_to_check, "spec")
1062
+ _spec_not_none = False
1063
+ _all_values_not_none = False
1064
+ if _has_spec:
1065
+ _spec = _policy_to_check.spec
1066
+ _spec_not_none = _spec is not None
1067
+ if _spec_not_none:
1068
+ _all_values_not_none = all(
1069
+ v is not None for v in _spec.values(True, True)
1070
+ )
1071
+ _condition = (
1072
+ make_rollout and _has_spec and _spec_not_none and _all_values_not_none
1073
+ )
1074
+ if _condition:
1075
+ if any(
1076
+ key not in self._final_rollout.keys(isinstance(key, tuple))
1077
+ for key in (
1078
+ self._wrapped_policy_uncompiled
1079
+ if has_meta_params
1080
+ else self._wrapped_policy
1081
+ ).spec.keys(True, True)
1082
+ ):
1083
+ # if policy spec is non-empty, all the values are not None and the keys
1084
+ # match the out_keys, we assume the user has given all relevant information;
1085
+ # the policy could have more keys than the env:
1086
+ policy_spec = (
1087
+ self._wrapped_policy_uncompiled
1088
+ if has_meta_params
1089
+ else self._wrapped_policy
1090
+ ).spec
1091
+ if policy_spec.ndim < self._final_rollout.ndim:
1092
+ policy_spec = policy_spec.expand(self._final_rollout.shape)
1093
+ for key, spec in policy_spec.items(True, True):
1094
+ self._policy_output_keys.add(key)
1095
+ if key in self._final_rollout.keys(True):
1096
+ continue
1097
+ self._final_rollout.set(key, spec.zero())
1098
+ elif (
1099
+ not make_rollout
1100
+ and hasattr(
1101
+ self._wrapped_policy_uncompiled
1102
+ if has_meta_params
1103
+ else self._wrapped_policy,
1104
+ "out_keys",
1105
+ )
1106
+ and (
1107
+ self._wrapped_policy_uncompiled
1108
+ if has_meta_params
1109
+ else self._wrapped_policy
1110
+ ).out_keys
1111
+ ):
1112
+ self._policy_output_keys = list(
1113
+ (
1114
+ self._wrapped_policy_uncompiled
1115
+ if has_meta_params
1116
+ else self._wrapped_policy
1117
+ ).out_keys
1118
+ )
1119
+ elif has_meta_params:
1120
+ # Policy has meta params and no spec/out_keys - defer initialization
1121
+ # Mark that we need to initialize later when weights are loaded
1122
+ self._policy_output_keys = set()
1123
+ if make_rollout:
1124
+ # We'll populate keys on first actual rollout after weights are loaded
1125
+ self._final_rollout_needs_init = True
1126
+ else:
1127
+ if make_rollout:
1128
+ # otherwise, we perform a small number of steps with the policy to
1129
+ # determine the relevant keys with which to pre-populate _final_rollout.
1130
+ # This is the safest thing to do if the spec has None fields or if there is
1131
+ # no spec at all.
1132
+ # See #505 for additional context.
1133
+ self._final_rollout.update(self._carrier.copy())
1134
+ with torch.no_grad():
1135
+ policy_input = self._carrier.copy()
1136
+ if self.policy_device:
1137
+ policy_input = policy_input.to(self.policy_device)
1138
+ # we cast to policy device, we'll deal with the device later
1139
+ policy_input_copy = policy_input.copy()
1140
+ policy_input_clone = (
1141
+ policy_input.clone()
1142
+ ) # to test if values have changed in-place
1143
+ if self.compiled_policy:
1144
+ cudagraph_mark_step_begin()
1145
+ policy_output = self._wrapped_policy(policy_input)
1146
+
1147
+ # check that we don't have exclusive keys, because they don't appear in keys
1148
+ def check_exclusive(val):
1149
+ if (
1150
+ isinstance(val, LazyStackedTensorDict)
1151
+ and val._has_exclusive_keys
1152
+ ):
1153
+ raise RuntimeError(
1154
+ "LazyStackedTensorDict with exclusive keys are not permitted in collectors. "
1155
+ "Consider using a placeholder for missing keys."
1156
+ )
1157
+
1158
+ policy_output._fast_apply(
1159
+ check_exclusive, call_on_nested=True, filter_empty=True
1160
+ )
1161
+
1162
+ # Use apply, because it works well with lazy stacks
1163
+ # Edge-case of this approach: the policy may change the values in-place and only by a tiny bit
1164
+ # or occasionally. In these cases, the keys will be missed (we can't detect if the policy has
1165
+ # changed them here).
1166
+ # This will cause a failure to update entries when policy and env device mismatch and
1167
+ # casting is necessary.
1168
+ def filter_policy(name, value_output, value_input, value_input_clone):
1169
+ if (value_input is None) or (
1170
+ (value_output is not value_input)
1171
+ and (
1172
+ value_output.device != value_input_clone.device
1173
+ or ~torch.isclose(value_output, value_input_clone).any()
1174
+ )
1175
+ ):
1176
+ return value_output
1177
+
1178
+ filtered_policy_output = policy_output.apply(
1179
+ filter_policy,
1180
+ policy_input_copy,
1181
+ policy_input_clone,
1182
+ default=None,
1183
+ filter_empty=True,
1184
+ named=True,
1185
+ )
1186
+ self._policy_output_keys = list(
1187
+ self._policy_output_keys.union(
1188
+ set(filtered_policy_output.keys(True, True))
1189
+ )
1190
+ )
1191
+ if make_rollout:
1192
+ self._final_rollout.update(
1193
+ policy_output.select(*self._policy_output_keys)
1194
+ )
1195
+ del filtered_policy_output, policy_output, policy_input
1196
+
1197
+ _env_output_keys = []
1198
+ for spec in ["full_observation_spec", "full_done_spec", "full_reward_spec"]:
1199
+ _env_output_keys += list(self.env.output_spec[spec].keys(True, True))
1200
+ self._env_output_keys = _env_output_keys
1201
+ if make_rollout:
1202
+ self._final_rollout = (
1203
+ self._final_rollout.unsqueeze(-1)
1204
+ .expand(*self.env.batch_size, self.frames_per_batch)
1205
+ .clone()
1206
+ .zero_()
1207
+ )
1208
+
1209
+ # in addition to outputs of the policy, we add traj_ids to
1210
+ # _final_rollout which will be collected during rollout
1211
+ self._final_rollout.set(
1212
+ ("collector", "traj_ids"),
1213
+ torch.zeros(
1214
+ *self._final_rollout.batch_size,
1215
+ dtype=torch.int64,
1216
+ device=self.storing_device,
1217
+ ),
1218
+ )
1219
+ self._final_rollout.refine_names(..., "time")
1220
+
1221
+ def _set_truncated_keys(self):
1222
+ self._truncated_keys = []
1223
+ if self.set_truncated:
1224
+ if not any(_ends_with(key, "truncated") for key in self.env.done_keys):
1225
+ raise RuntimeError(
1226
+ "set_truncated was set to True but no truncated key could be found "
1227
+ "in the environment. Make sure the truncated keys are properly set using "
1228
+ "`env.add_truncated_keys()` before passing the env to the collector."
1229
+ )
1230
+ self._truncated_keys = [
1231
+ key for key in self.env.done_keys if _ends_with(key, "truncated")
1232
+ ]
1233
+
1234
+ @classmethod
1235
+ def _get_devices(
1236
+ cls,
1237
+ *,
1238
+ storing_device: torch.device,
1239
+ policy_device: torch.device,
1240
+ env_device: torch.device,
1241
+ device: torch.device,
1242
+ ):
1243
+ device = _make_ordinal_device(torch.device(device) if device else device)
1244
+ storing_device = _make_ordinal_device(
1245
+ torch.device(storing_device) if storing_device else device
1246
+ )
1247
+ policy_device = _make_ordinal_device(
1248
+ torch.device(policy_device) if policy_device else device
1249
+ )
1250
+ env_device = _make_ordinal_device(
1251
+ torch.device(env_device) if env_device else device
1252
+ )
1253
+ if storing_device is None and (env_device == policy_device):
1254
+ storing_device = env_device
1255
+ return storing_device, policy_device, env_device
1256
+
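# Worked example (editorial) of the device resolution above: the `device`
# argument back-fills whichever of the three specific devices were left unset.
# `_get_devices` is the private classmethod defined above.
import torch
from torchrl.collectors import Collector

s, p, e = Collector._get_devices(
    storing_device=None, policy_device=None, env_device=None, device="cpu"
)
assert s == p == e == torch.device("cpu")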
1257
+ # for RPC
1258
+ def next(self):
1259
+ return super().next()
1260
+
1261
+ # for RPC
1262
+ def update_policy_weights_(
1263
+ self,
1264
+ policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
1265
+ *,
1266
+ worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
1267
+ **kwargs,
1268
+ ) -> None:
1269
+ if "policy_weights" in kwargs:
1270
+ warnings.warn(
1271
+ "`policy_weights` is deprecated. Use `policy_or_weights` instead.",
1272
+ DeprecationWarning,
1273
+ )
1274
+ policy_or_weights = kwargs.pop("policy_weights")
1275
+
1276
+ super().update_policy_weights_(
1277
+ policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
1278
+ )
1279
+
1280
+ def _maybe_fallback_update(
1281
+ self,
1282
+ policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
1283
+ *,
1284
+ model_id: str | None = None,
1285
+ ) -> None:
1286
+ """Copy weights from original policy to internal policy when no scheme configured."""
1287
+ if model_id is not None and model_id != "policy":
1288
+ return
1289
+
1290
+ # Get source weights - either from argument or from original policy
1291
+ if policy_or_weights is not None:
1292
+ weights = self._extract_weights_if_needed(policy_or_weights, "policy")
1293
+ elif self._orig_policy is not None:
1294
+ weights = TensorDict.from_module(self._orig_policy)
1295
+ else:
1296
+ return
1297
+
1298
+ # Apply to internal policy
1299
+ if (
1300
+ hasattr(self, "_policy_w_state_dict")
1301
+ and self._policy_w_state_dict is not None
1302
+ ):
1303
+ TensorDict.from_module(self._policy_w_state_dict).data.update_(weights.data)
1304
+
1305
+ def set_seed(self, seed: int, static_seed: bool = False) -> int:
1306
+ """Sets the seeds of the environments stored in the DataCollector.
1307
+
1308
+ Args:
1309
+ seed (int): integer representing the seed to be used for the environment.
1310
+ static_seed (bool, optional): if ``True``, the seed is not incremented.
1311
+ Defaults to ``False``.
1312
+
1313
+ Returns:
1314
+ Output seed. This is useful when more than one environment is contained in the DataCollector, as the
1315
+ seed will be incremented for each of these. The resulting seed is the seed of the last environment.
1316
+
1317
+ Examples:
1318
+ >>> from torchrl.envs import ParallelEnv
1319
+ >>> from torchrl.envs.libs.gym import GymEnv
1320
+ >>> from tensordict.nn import TensorDictModule
1321
+ >>> from torch import nn
1322
+ >>> env_fn = lambda: GymEnv("Pendulum-v1")
1323
+ >>> env_fn_parallel = ParallelEnv(6, env_fn)
1324
+ >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
1325
+ >>> collector = Collector(env_fn_parallel, policy, total_frames=300, frames_per_batch=100)
1326
+ >>> out_seed = collector.set_seed(1) # out_seed = 6
1327
+
1328
+ """
1329
+ out = self.env.set_seed(seed, static_seed=static_seed)
1330
+ return out
1331
+
1332
+ def _increment_frames(self, numel):
1333
+ self._frames += numel
1334
+ completed = self._frames >= self.total_frames
1335
+ if completed:
1336
+ self.env.close()
1337
+ return completed
1338
+
1339
+ def iterator(self) -> Iterator[TensorDictBase]:
1340
+ """Iterates through the DataCollector.
1341
+
1342
+ Yields: TensorDictBase objects containing (chunks of) trajectories
1343
+
1344
+ """
1345
+ if (
1346
+ not self.no_cuda_sync
1347
+ and self.storing_device
1348
+ and self.storing_device.type == "cuda"
1349
+ ):
1350
+ stream = torch.cuda.Stream(self.storing_device, priority=-1)
1351
+ event = stream.record_event()
1352
+ streams = [stream]
1353
+ events = [event]
1354
+ elif not self.no_cuda_sync and self.storing_device is None:
1355
+ streams = []
1356
+ events = []
1357
+ # this way of checking cuda is robust to lazy stacks with mismatching shapes
1358
+ cuda_devices = set()
1359
+
1360
+ def cuda_check(tensor: torch.Tensor):
1361
+ if tensor.is_cuda:
1362
+ cuda_devices.add(tensor.device)
1363
+
1364
+ if not self._use_buffers:
1365
+ # This may be a bit dangerous as `torch.device("cuda")` may not have a precise
1366
+ # device associated, whereas `tensor.device` always does
1367
+ for spec in self.env.specs.values(True, True):
1368
+ if spec.device is not None and spec.device.type == "cuda":
1369
+ if ":" not in str(spec.device):
1370
+ raise RuntimeError(
1371
+ "A cuda spec did not have a device associated. Make sure to "
1372
+ "pass `'cuda:device_num'` to each spec device."
1373
+ )
1374
+ cuda_devices.add(spec.device)
1375
+ else:
1376
+ self._final_rollout.apply(cuda_check, filter_empty=True)
1377
+ for device in cuda_devices:
1378
+ streams.append(torch.cuda.Stream(device, priority=-1))
1379
+ events.append(streams[-1].record_event())
1380
+ else:
1381
+ streams = []
1382
+ events = []
1383
+
1384
+ # Set up profiler if configured
1385
+ profiler = None
1386
+ if self._profile_config is not None:
1387
+ profiler = _CollectorProfiler(self._profile_config)
1388
+ if profiler.is_active:
1389
+ profiler.start()
1390
+
1391
+ with contextlib.ExitStack() as stack:
1392
+ for stream in streams:
1393
+ stack.enter_context(torch.cuda.stream(stream))
1394
+
1395
+ while self._frames < self.total_frames:
1396
+ self._iter += 1
1397
+
1398
+ # Use profiler context if profiling is active
1399
+ profile_ctx = (
1400
+ profiler.profile_rollout()
1401
+ if profiler is not None and profiler.is_active
1402
+ else contextlib.nullcontext()
1403
+ )
1404
+
1405
+ with profile_ctx:
1406
+ tensordict_out = self.rollout()
1407
+
1408
+ # Step the profiler after each rollout
1409
+ if profiler is not None and profiler.is_active:
1410
+ profiler.step()
1411
+
1412
+ if tensordict_out is None:
1413
+ # if a replay buffer is passed and self.extend_buffer=False, there is no tensordict_out
1414
+ # frames are updated within the rollout function
1415
+ yield
1416
+ continue
1417
+ self._increment_frames(tensordict_out.numel())
1418
+ tensordict_out = self._postproc(tensordict_out)
1419
+ if self.return_same_td:
1420
+ # This is used with multiprocessed collectors to use the buffers
1421
+ # stored in the tensordict.
1422
+ if events:
1423
+ for event in events:
1424
+ event.record()
1425
+ event.synchronize()
1426
+ yield tensordict_out
1427
+ elif self.replay_buffer is not None and not self._ignore_rb:
1428
+ self.replay_buffer.extend(tensordict_out)
1429
+ yield
1430
+ else:
1431
+ # we must clone the values, as the tensordict is updated in-place.
1432
+ # otherwise the following code may break:
1433
+ # >>> for i, data in enumerate(collector):
1434
+ # >>> if i == 0:
1435
+ # >>> data0 = data
1436
+ # >>> elif i == 1:
1437
+ # >>> data1 = data
1438
+ # >>> else:
1439
+ # >>> break
1440
+ # >>> assert data0["done"] is not data1["done"]
1441
+ yield tensordict_out.clone()
1442
+
1443
+ # Stop profiler if it hasn't been stopped yet
1444
+ if profiler is not None and profiler.is_active:
1445
+ profiler.stop()
1446
+
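# Usage sketch (editorial): when a ReplayBuffer is passed to the collector, the
# iterator yields None and the collected frames land directly in the buffer.
# Assumes gymnasium's Pendulum-v1; policy=None is assumed to fall back to a
# random policy in this version.
from torchrl.collectors import Collector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import GymEnv

rb = ReplayBuffer(storage=LazyTensorStorage(1_000))
collector = Collector(
    GymEnv("Pendulum-v1"),
    policy=None,
    frames_per_batch=50,
    total_frames=200,
    replay_buffer=rb,
)
for out in collector:
    assert out is None  # data is written to rb instead of being yielded
print(rb.write_count)   # 200 once collection completes
collector.shutdown()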
1447
+ def start(self):
1448
+ """Starts the collector in a separate thread for asynchronous data collection.
1449
+
1450
+ The collected data is stored in the provided replay buffer. This method is useful when you want to decouple data
1451
+ collection from training, allowing your training loop to run independently of the data collection process.
1452
+
1453
+ Raises:
1454
+ RuntimeError: If no replay buffer is defined during the collector's initialization.
1455
+
1456
+ Example:
1457
+ >>> import time
+ >>> from torchrl.modules import RandomPolicy
1458
+ >>> from functools import partial
1459
+ >>>
1460
+ >>> import tqdm
1461
+ >>>
1462
+ >>> from torchrl.collectors import Collector
1463
+ >>> from torchrl.data import LazyTensorStorage, ReplayBuffer
1464
+ >>> from torchrl.envs import GymEnv, set_gym_backend
1465
+ >>> import ale_py
1466
+ >>>
1467
+ >>> # Set the gym backend to gymnasium
1468
+ >>> set_gym_backend("gymnasium").set()
1469
+ >>>
1470
+ >>> if __name__ == "__main__":
1471
+ ... # Create a random policy for the Pong environment
1472
+ ... env = GymEnv("ALE/Pong-v5")
1473
+ ... policy = RandomPolicy(env.action_spec)
1474
+ ...
1475
+ ... # Initialize a shared replay buffer
1476
+ ... rb = ReplayBuffer(storage=LazyTensorStorage(1000), shared=True)
1477
+ ...
1478
+ ... # Create a synchronous data collector
1479
+ ... collector = Collector(
1480
+ ... env,
1481
+ ... policy=policy,
1482
+ ... replay_buffer=rb,
1483
+ ... frames_per_batch=256,
1484
+ ... total_frames=-1,
1485
+ ... )
1486
+ ...
1487
+ ... # Progress bar to track the number of collected frames
1488
+ ... pbar = tqdm.tqdm(total=100_000)
1489
+ ...
1490
+ ... # Start the collector asynchronously
1491
+ ... collector.start()
1492
+ ...
1493
+ ... # Track the write count of the replay buffer
1494
+ ... prec_wc = 0
1495
+ ... while True:
1496
+ ... wc = rb.write_count
1497
+ ... c = wc - prec_wc
1498
+ ... prec_wc = wc
1499
+ ...
1500
+ ... # Update the progress bar
1501
+ ... pbar.update(c)
1502
+ ... pbar.set_description(f"Write Count: {rb.write_count}")
1503
+ ...
1504
+ ... # Check the write count every 0.5 seconds
1505
+ ... time.sleep(0.5)
1506
+ ...
1507
+ ... # Stop when the desired number of frames is reached
1508
+ ... if rb.write_count >= 100_000:
1509
+ ... break
1510
+ ...
1511
+ ... # Shut down the collector
1512
+ ... collector.async_shutdown()
1513
+ """
1514
+ if self.replay_buffer is None:
1515
+ raise RuntimeError("Replay buffer must be defined for execution.")
1516
+ if not self.is_running():
1517
+ self._stop = False
1518
+ self._thread = threading.Thread(target=self._run_iterator)
1519
+ self._thread.daemon = (
1520
+ True # So that the thread dies when the main program exits
1521
+ )
1522
+ self._thread.start()
1523
+
1524
+ def _run_iterator(self):
1525
+ for _ in self:
1526
+ if self._stop:
1527
+ return
1528
+
1529
+ def is_running(self):
1530
+ return hasattr(self, "_thread") and self._thread.is_alive()
1531
+
1532
+ def _should_use_random_frames(self) -> bool:
1533
+ """Determine if random frames should be used instead of the policy.
1534
+
1535
+ When a replay buffer is provided, uses `replay_buffer.write_count` as the
1536
+ global step counter to support `.start()` mode where `_frames` isn't updated
1537
+ until after collection. Otherwise, uses the internal `_frames` counter.
1538
+
1539
+ Returns:
1540
+ bool: True if random frames should be used, False otherwise.
1541
+ """
1542
+ if self.init_random_frames is None or self.init_random_frames <= 0:
1543
+ return False
1544
+ # Use replay_buffer.write_count when available for accurate counting in .start() mode
1545
+ if self.replay_buffer is not None:
1546
+ return self.replay_buffer.write_count < self.init_random_frames
1547
+ return self._frames < self.init_random_frames
1548
+
1549
+ def async_shutdown(
1550
+ self, timeout: float | None = None, close_env: bool = True
1551
+ ) -> None:
1552
+ """Finishes processes started by ray.init() during async execution."""
1553
+ self._stop = True
1554
+ if hasattr(self, "_thread") and self._thread.is_alive():
1555
+ self._thread.join(timeout=timeout)
1556
+ self.shutdown(close_env=close_env)
1557
+
1558
+ def _postproc(self, tensordict_out):
1559
+ if self.split_trajs:
1560
+ tensordict_out = split_trajectories(tensordict_out, prefix="collector")
1561
+ if self.postproc is not None:
1562
+ tensordict_out = self.postproc(tensordict_out)
1563
+ if self._exclude_private_keys:
1564
+
1565
+ def is_private(key):
1566
+ if isinstance(key, str) and key.startswith("_"):
1567
+ return True
1568
+ if isinstance(key, tuple) and any(_key.startswith("_") for _key in key):
1569
+ return True
1570
+ return False
1571
+
1572
+ excluded_keys = [
1573
+ key for key in tensordict_out.keys(True) if is_private(key)
1574
+ ]
1575
+ tensordict_out = tensordict_out.exclude(*excluded_keys, inplace=True)
1576
+ return tensordict_out
1577
+
1578
+ def _update_traj_ids(self, env_output) -> None:
1579
+ # we can't use the reset keys because they're gone
1580
+ traj_sop = _aggregate_end_of_traj(
1581
+ env_output.get("next"), done_keys=self.env.done_keys
1582
+ )
1583
+ if traj_sop.any():
1584
+ device = self.storing_device
1585
+
1586
+ traj_ids = self._carrier.get(("collector", "traj_ids"))
1587
+ if device is not None:
1588
+ traj_ids = traj_ids.to(device)
1589
+ traj_sop = traj_sop.to(device)
1590
+ elif traj_sop.device != traj_ids.device:
1591
+ traj_sop = traj_sop.to(traj_ids.device)
1592
+
1593
+ pool = self._traj_pool
1594
+ new_traj = pool.get_traj_and_increment(
1595
+ traj_sop.sum(), device=traj_sop.device
1596
+ )
1597
+ traj_ids = traj_ids.masked_scatter(traj_sop, new_traj)
1598
+ self._carrier.set(("collector", "traj_ids"), traj_ids)
1599
+
1600
+ @torch.no_grad()
1601
+ def rollout(self) -> TensorDictBase:
1602
+ """Computes a rollout in the environment using the provided policy.
1603
+
1604
+ Returns:
1605
+ TensorDictBase containing the computed rollout.
1606
+
1607
+ """
1608
+ if self.reset_at_each_iter:
1609
+ self._carrier.update(self.env.reset())
1610
+
1611
+ # self._shuttle.fill_(("collector", "step_count"), 0)
1612
+ if self._use_buffers:
1613
+ self._final_rollout.fill_(("collector", "traj_ids"), -1)
1614
+ else:
1615
+ pass
1616
+ tensordicts = []
1617
+ with set_exploration_type(self.exploration_type):
1618
+ for t in range(self.frames_per_batch):
1619
+ if self._should_use_random_frames():
1620
+ self.env.rand_action(self._carrier)
1621
+ if (
1622
+ self.policy_device is not None
1623
+ and self.policy_device != self.env_device
1624
+ ):
1625
+ # TODO: This may break with exclusive / ragged lazy stacks
1626
+ self._carrier.apply(
1627
+ lambda name, val: val.to(
1628
+ device=self.policy_device, non_blocking=True
1629
+ )
1630
+ if name in self._policy_output_keys
1631
+ else val,
1632
+ out=self._carrier,
1633
+ named=True,
1634
+ nested_keys=True,
1635
+ )
1636
+ else:
1637
+ if self._cast_to_policy_device:
1638
+ if self.policy_device is not None:
1639
+ # This is unsafe if the shuttle is in pin_memory -- otherwise cuda will be happy with non_blocking
1640
+ non_blocking = (
1641
+ not self.no_cuda_sync
1642
+ or self.policy_device.type == "cuda"
1643
+ )
1644
+ policy_input = self._carrier.to(
1645
+ self.policy_device,
1646
+ non_blocking=non_blocking,
1647
+ )
1648
+ if not self.no_cuda_sync:
1649
+ self._sync_policy()
1650
+ elif self.policy_device is None:
1651
+ # we know the tensordict has a device otherwise we would not be here
1652
+ # we can pass this, clear_device_ must have been called earlier
1653
+ # policy_input = self._shuttle.clear_device_()
1654
+ policy_input = self._carrier
1655
+ else:
1656
+ policy_input = self._carrier
1657
+ # we still do the assignment for security
1658
+ if self.compiled_policy:
1659
+ cudagraph_mark_step_begin()
1660
+ policy_output = self._wrapped_policy(policy_input)
1661
+ if self.compiled_policy:
1662
+ policy_output = policy_output.clone()
1663
+ if self._carrier is not policy_output:
1664
+ # ad-hoc update shuttle
1665
+ self._carrier.update(
1666
+ policy_output, keys_to_update=self._policy_output_keys
1667
+ )
1668
+
1669
+ if self._cast_to_env_device:
1670
+ if self.env_device is not None:
1671
+ non_blocking = (
1672
+ not self.no_cuda_sync or self.env_device.type == "cuda"
1673
+ )
1674
+ env_input = self._carrier.to(
1675
+ self.env_device, non_blocking=non_blocking
1676
+ )
1677
+ if not self.no_cuda_sync:
1678
+ self._sync_env()
1679
+ elif self.env_device is None:
1680
+ # we know the tensordict has a device otherwise we would not be here
1681
+ # we can pass this, clear_device_ must have been called earlier
1682
+ # env_input = self._shuttle.clear_device_()
1683
+ env_input = self._carrier
1684
+ else:
1685
+ env_input = self._carrier
1686
+ env_output, env_next_output = self.env.step_and_maybe_reset(env_input)
1687
+
1688
+ if self._carrier is not env_output:
1689
+ # ad-hoc update shuttle
1690
+ next_data = env_output.get("next")
1691
+ if self._shuttle_has_no_device:
1692
+ # Make sure the carried "next" data stays deviceless
1693
+ next_data.clear_device_()
1694
+ self._carrier.set("next", next_data)
1695
+
1696
+ if (
1697
+ self.replay_buffer is not None
1698
+ and not self._ignore_rb
1699
+ and not self.extend_buffer
1700
+ ):
1701
+ self.replay_buffer.add(self._carrier)
1702
+ if self._increment_frames(self._carrier.numel()):
1703
+ return
1704
+ else:
1705
+ if self.storing_device is not None:
1706
+ non_blocking = (
1707
+ not self.no_cuda_sync or self.storing_device.type == "cuda"
1708
+ )
1709
+ tensordicts.append(
1710
+ self._carrier.to(
1711
+ self.storing_device, non_blocking=non_blocking
1712
+ )
1713
+ )
1714
+ if not self.no_cuda_sync:
1715
+ self._sync_storage()
1716
+ else:
1717
+ tensordicts.append(self._carrier)
1718
+
1719
+ # carry over collector data without messing up devices
1720
+ collector_data = self._carrier.get("collector").copy()
1721
+ self._carrier = env_next_output
1722
+ if self._shuttle_has_no_device:
1723
+ self._carrier.clear_device_()
1724
+ self._carrier.set("collector", collector_data)
1725
+ self._update_traj_ids(env_output)
1726
+
1727
+ if (
1728
+ self.interruptor is not None
1729
+ and self.interruptor.collection_stopped()
1730
+ ):
1731
+ if (
1732
+ self.replay_buffer is not None
1733
+ and not self._ignore_rb
1734
+ and not self.extend_buffer
1735
+ ):
1736
+ return
1737
+ result = self._final_rollout
1738
+ if self._use_buffers:
1739
+ try:
1740
+ torch.stack(
1741
+ tensordicts,
1742
+ self._final_rollout.ndim - 1,
1743
+ out=self._final_rollout[..., : t + 1],
1744
+ )
1745
+ except RuntimeError:
1746
+ with self._final_rollout.unlock_():
1747
+ torch.stack(
1748
+ tensordicts,
1749
+ self._final_rollout.ndim - 1,
1750
+ out=self._final_rollout[..., : t + 1],
1751
+ )
1752
+ else:
1753
+ result = TensorDict.maybe_dense_stack(tensordicts, dim=-1)
1754
+ break
1755
+ else:
1756
+ if self._use_buffers:
1757
+ result = self._final_rollout
1758
+ try:
1759
+ result = torch.stack(
1760
+ tensordicts,
1761
+ self._final_rollout.ndim - 1,
1762
+ out=self._final_rollout,
1763
+ )
1764
+
1765
+ except RuntimeError:
1766
+ with self._final_rollout.unlock_():
1767
+ result = torch.stack(
1768
+ tensordicts,
1769
+ self._final_rollout.ndim - 1,
1770
+ out=self._final_rollout,
1771
+ )
1772
+ elif (
1773
+ self.replay_buffer is not None
1774
+ and not self._ignore_rb
1775
+ and not self.extend_buffer
1776
+ ):
1777
+ return
1778
+ else:
1779
+ result = TensorDict.maybe_dense_stack(tensordicts, dim=-1)
1780
+ result.refine_names(..., "time")
1781
+
1782
+ return self._maybe_set_truncated(result)
1783
+
1784
+ def _maybe_set_truncated(self, final_rollout):
1785
+ last_step = (slice(None),) * (final_rollout.ndim - 1) + (-1,)
1786
+ for truncated_key in self._truncated_keys:
1787
+ truncated = final_rollout["next", truncated_key]
1788
+ truncated[last_step] = True
1789
+ final_rollout["next", truncated_key] = truncated
1790
+ done = final_rollout["next", _replace_last(truncated_key, "done")]
1791
+ final_rollout["next", _replace_last(truncated_key, "done")] = (
1792
+ done | truncated
1793
+ )
1794
+ return final_rollout
1795
+
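# Editorial sketch of the truncation marking above, on a toy "next" tensordict
# of 5 time steps: the last step is flagged truncated and folded into "done".
import torch
from tensordict import TensorDict

next_td = TensorDict(
    {
        "truncated": torch.zeros(5, 1, dtype=torch.bool),
        "done": torch.zeros(5, 1, dtype=torch.bool),
    },
    batch_size=[5],
)
next_td["truncated"][-1] = True
next_td["done"] = next_td["done"] | next_td["truncated"]
assert bool(next_td["done"][-1])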
1796
+ @torch.no_grad()
1797
+ def reset(self, index=None, **kwargs) -> None:
1798
+ """Resets the environments to a new initial state."""
1799
+ # metadata
1800
+ collector_metadata = self._carrier.get("collector").clone()
1801
+ if index is not None:
1802
+ # check that the env supports partial reset
1803
+ if prod(self.env.batch_size) == 0:
1804
+ raise RuntimeError("resetting unique env with index is not permitted.")
1805
+ for reset_key, done_keys in zip(
1806
+ self.env.reset_keys, self.env.done_keys_groups
1807
+ ):
1808
+ _reset = torch.zeros(
1809
+ self.env.full_done_spec[done_keys[0]].shape,
1810
+ dtype=torch.bool,
1811
+ device=self.env.device,
1812
+ )
1813
+ _reset[index] = 1
1814
+ self._carrier.set(reset_key, _reset)
1815
+ else:
1816
+ _reset = None
1817
+ self._carrier.zero_()
1818
+
1819
+ self._carrier.update(self.env.reset(**kwargs), inplace=True)
1820
+ collector_metadata["traj_ids"] = (
1821
+ collector_metadata["traj_ids"] - collector_metadata["traj_ids"].min()
1822
+ )
1823
+ self._carrier["collector"] = collector_metadata
1824
+
1825
+ def shutdown(
1826
+ self,
1827
+ timeout: float | None = None,
1828
+ close_env: bool = True,
1829
+ raise_on_error: bool = True,
1830
+ ) -> None:
1831
+ """Shuts down all workers and/or closes the local environment.
1832
+
1833
+ Args:
1834
+ timeout (float, optional): The timeout for closing pipes between workers.
1835
+ No effect for this class.
1836
+ close_env (bool, optional): Whether to close the environment. Defaults to `True`.
1837
+ raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`.
1838
+ """
1839
+ try:
1840
+ if not self.closed:
1841
+ self.closed = True
1842
+ del self._carrier
1843
+ if self._use_buffers:
1844
+ del self._final_rollout
1845
+ if close_env and not self.env.is_closed:
1846
+ self.env.close(raise_if_closed=raise_on_error)
1847
+ del self.env
1848
+ return
1849
+ except Exception as e:
1850
+ if raise_on_error:
1851
+ raise e
1852
+ else:
1853
+ pass
1854
+
1855
+ def __del__(self):
1856
+ try:
1857
+ self.shutdown()
1858
+ except Exception:
1859
+ # an AttributeError will typically be raised if the collector is deleted when the program ends.
1860
+ # In the future, insignificant changes to the close method may change the error type.
1861
+ # We explicitly assume that any error raised during closure in
1862
+ # __del__ will not affect the program.
1863
+ pass
1864
+
1865
+ def state_dict(self) -> OrderedDict:
1866
+ """Returns the local state_dict of the data collector (environment and policy).
1867
+
1868
+ Returns:
1869
+ an ordered dictionary with fields :obj:`"policy_state_dict"` and
1870
+ :obj:`"env_state_dict"`.
1871
+
1872
+ """
1873
+ from torchrl.envs.batched_envs import BatchedEnvBase
1874
+
1875
+ if isinstance(self.env, TransformedEnv):
1876
+ env_state_dict = self.env.transform.state_dict()
1877
+ elif isinstance(self.env, BatchedEnvBase):
1878
+ env_state_dict = self.env.state_dict()
1879
+ else:
1880
+ env_state_dict = OrderedDict()
1881
+
1882
+ if hasattr(self, "_policy_w_state_dict"):
1883
+ policy_state_dict = self._policy_w_state_dict.state_dict()
1884
+ state_dict = OrderedDict(
1885
+ policy_state_dict=policy_state_dict,
1886
+ env_state_dict=env_state_dict,
1887
+ )
1888
+ else:
1889
+ state_dict = OrderedDict(env_state_dict=env_state_dict)
1890
+
1891
+ state_dict.update({"frames": self._frames, "iter": self._iter})
1892
+
1893
+ return state_dict
1894
+
1895
+ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None:
1896
+ """Loads a state_dict on the environment and policy.
1897
+
1898
+ Args:
1899
+ state_dict (OrderedDict): ordered dictionary containing the fields
1900
+ :obj:`"policy_state_dict"` and :obj:`"env_state_dict"`.
1901
+
1902
+ """
1903
+ strict = kwargs.get("strict", True)
1904
+ if strict or "env_state_dict" in state_dict:
1905
+ self.env.load_state_dict(state_dict["env_state_dict"], **kwargs)
1906
+ if strict or "policy_state_dict" in state_dict:
1907
+ if not hasattr(self, "_policy_w_state_dict"):
1908
+ raise ValueError(
1909
+ "Underlying policy does not have state_dict to load policy_state_dict into."
1910
+ )
1911
+ self._policy_w_state_dict.load_state_dict(
1912
+ state_dict["policy_state_dict"], **kwargs
1913
+ )
1914
+ self._frames = state_dict["frames"]
1915
+ self._iter = state_dict["iter"]
1916
+
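# Usage sketch (editorial): the collector checkpoint is a plain OrderedDict with
# the policy/env state plus the "frames" and "iter" counters, so it can be saved
# and restored around a training run. Assumes gymnasium's Pendulum-v1.
import torch
from torch import nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import Collector
from torchrl.envs import GymEnv

policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
collector = Collector(GymEnv("Pendulum-v1"), policy, frames_per_batch=50, total_frames=100)
next(iter(collector))        # collect one batch so the counters are non-zero
sd = collector.state_dict()  # policy_state_dict, env_state_dict, frames, iter
collector.load_state_dict(sd)
collector.shutdown()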
1917
+ def __repr__(self) -> str:
1918
+ try:
1919
+ env_str = indent(f"env={self.env}", 4 * " ")
1920
+ policy_str = indent(f"policy={self._wrapped_policy}", 4 * " ")
1921
+ td_out_str = repr(getattr(self, "_final_rollout", None))
1922
+ if len(td_out_str) > 50:
1923
+ td_out_str = td_out_str[:50] + "..."
1924
+ td_out_str = indent(f"td_out={td_out_str}", 4 * " ")
1925
+ string = (
1926
+ f"{self.__class__.__name__}("
1927
+ f"\n{env_str},"
1928
+ f"\n{policy_str},"
1929
+ f"\n{td_out_str},"
1930
+ f"\nexploration={self.exploration_type})"
1931
+ )
1932
+ return string
1933
+ except Exception:
1934
+ return f"{type(self).__name__}(not_init)"
1935
+
1936
+ def increment_version(self):
1937
+ """Increment the policy version."""
1938
+ if self.policy_version_tracker is not None:
1939
+ if not hasattr(self.policy_version_tracker, "increment_version"):
1940
+ raise RuntimeError(
1941
+ "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector."
1942
+ )
1943
+ self.policy_version_tracker.increment_version()
1944
+
1945
+ @property
1946
+ def policy_version(self) -> str | int | None:
1947
+ """The current policy version."""
1948
+ if not hasattr(self.policy_version_tracker, "version"):
1949
+ return None
1950
+ return self.policy_version_tracker.version
1951
+
1952
+ def get_policy_version(self) -> str | int | None:
1953
+ """Get the current policy version.
1954
+
1955
+ This method exists to support remote calls in Ray actors, since properties
1956
+ cannot be accessed directly through Ray's RPC mechanism.
1957
+
1958
+ Returns:
1959
+ The current version number (int) or UUID (str), or None if version tracking is disabled.
1960
+ """
1961
+ return self.policy_version
1962
+
1963
+ def getattr_policy(self, attr):
1964
+ """Get an attribute from the policy."""
1965
+ # send command to policy to return the attr
1966
+ return getattr(self._wrapped_policy, attr)
1967
+
1968
+ def getattr_env(self, attr):
1969
+ """Get an attribute from the environment."""
1970
+ # send command to env to return the attr
1971
+ return getattr(self.env, attr)
1972
+
1973
+ def getattr_rb(self, attr):
1974
+ """Get an attribute from the replay buffer."""
1975
+ # send command to rb to return the attr
1976
+ return getattr(self.replay_buffer, attr)
1977
+
1978
+ def get_model(self, model_id: str):
1979
+ """Get model instance by ID (for weight sync schemes).
1980
+
1981
+ Args:
1982
+ model_id: Model identifier (e.g., "policy", "value_net")
1983
+
1984
+ Returns:
1985
+ The model instance
1986
+
1987
+ Raises:
1988
+ ValueError: If model_id is not recognized
1989
+ """
1990
+ if model_id == "policy":
1991
+ # Return the unwrapped policy instance for weight synchronization
1992
+ # The unwrapped policy has the same parameter structure as what's
1993
+ # extracted in the main process, avoiding key mismatches when
1994
+ # the policy is auto-wrapped (e.g., WrappablePolicy -> TensorDictModule)
1995
+ if hasattr(self, "policy") and self.policy is not None:
1996
+ return self.policy
1997
+ else:
1998
+ raise ValueError(f"No policy found for model_id '{model_id}'")
1999
+ else:
2000
+ return _resolve_model(self, model_id)
2001
+
2002
+ def _receive_weights_scheme(self):
2003
+ return super()._receive_weights_scheme()
2004
+
2005
+
2006
+ class SyncDataCollector(Collector, metaclass=_LegacyCollectorMeta):
2007
+ """Deprecated version of :class:`~torchrl.collectors.Collector`."""
2008
+
2009
+ ...